#include "mex.h"
#include "Mates.h"
#include <stdio.h>
#include <math.h>
#include <float.h>
#include <string.h>
#include <stdlib.h>

/* 
E. Lopez-Rubio, M.N. Florentin-Nunez, Kernel regression based feature extraction for 3D MR image denoising,
Medical Image Analysis. DOI:10.1016/j.media.2011.02.006

Use the following commands to compile this MEX file at the Matlab prompt:
32-bit Windows:
mex SteeringMatrix3D.c MatesLap.c Debugging.c lapack.a blas.a libf2c.a
64-bit Windows:
mex LINKFLAGS="$LINKFLAGS /NODEFAULTLIB:LIBCMT" SteeringMatrix3D.c MatesLap.c Debugging.c BLAS_nowrap.lib libf2c.lib clapack_nowrap.lib


Usage:
C = SteeringMatrix3D(GradX, GradY, GradZ, IsClean, wsize, lambda);

Notes:
GradX=Gradient in the X direction
GradY=Gradient in the Y direction
GradZ=Gradient in the Z direction
IsClean=Values of the membership function which indicates whether a given voxel is not corrupted
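wsize=size of the local analysis window (an even value is rounded up to the next odd value)
lambda=regularization parameter (added to each of the three singular values)
C=map of steering matrices (a 3x3xMxNxP array holding one 3x3 matrix per voxel)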


*/

/* Extend a 2D matrix (mirror extension at the edges).
   Writes into Dest the (M+2*radius) x (N+2*radius) padded copy of the
   M x N column-major matrix Orig. */
void Extend2D(double *Dest,double *Orig,int M, int N,int radius);

/* Extend a 3D matrix (mirror extension at the edges).
   Writes into Dest the (M+2*radius) x (N+2*radius) x (P+2*radius) padded
   copy of the M x N x P column-major volume Orig, slice by slice. */
void Extend3D(double *Dest,double *Orig,int M,int N,int P,int radius);

/*
 * MEX gateway: C = SteeringMatrix3D(GradX, GradY, GradZ, IsClean, wsize, lambda)
 *
 * For every voxel the wsize x wsize x wsize window of (GradX, GradY, GradZ)
 * centred on it is gathered from mirror-padded copies of the gradient maps,
 * multiplied element-wise by the kernel returned by the MATLAB function
 * diskfilter3d(radius), and decomposed with EconomySVD as a wsize^3 x 3
 * matrix. The steering matrix is
 *     (1/IsClean) * sum_k (S[k] + lambda) * MatrixProduct(V+3k, V+3k)
 * (presumably the outer products of the three right singular vectors; this
 * relies on the Mates.h conventions of EconomySVD/MatrixProduct - confirm).
 *
 * Inputs (prhs):
 *   [0..2] GradX, GradY, GradZ  real double arrays, M x N x P (or M x N)
 *   [3]    IsClean              real double array, same number of elements
 *   [4]    wsize                window size; even values are rounded up to odd
 *   [5]    lambda               regularisation added to every singular value
 * Output (plhs):
 *   [0]    C                    3 x 3 x M x N x P array of steering matrices
 *
 * Errors (via mexErrMsgIdAndTxt) on a wrong argument count, non-real-double
 * inputs, mismatched sizes, more than 3 dimensions, a window radius too large
 * for the mirror padding, or a filter whose size is not wsize^3.
 */
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{
	double *Memberships,*GradX,*GradY,*GradZ,*C,*K;
	double *GradXExt,*GradYExt,*GradZExt,*G,*U;
	double S[3];
	double V[9];
	double lambda,wsizeArg,S1,S2,S3;
	double AuxMat[9];
	double AuxMat2[9];
	int radius,wsize,wsize3,NdxWin,NdxIn;
	mwSize M,N,P,NdxX,NdxY,NdxZ,ExtM,ExtN,ExtP,NumDims;
	size_t NumVoxels,NdxVoxel,Start,ExtSlice,ExtBytes;
	mwSize Dims[5];          /* mxCreateNumericArray expects mwSize, not int */
	const mwSize *DimsInput; /* mxGetDimensions returns const mwSize* */
	mxArray *My_plhs[1];
	mxArray *My_prhs[1];

	(void)nlhs;  /* exactly one output is always created */

	/* Validate the inputs before touching any data pointer */
	if (nrhs!=6)
	{
		mexErrMsgIdAndTxt("SteeringMatrix3D:nrhs",
			"Usage: C = SteeringMatrix3D(GradX, GradY, GradZ, IsClean, wsize, lambda)");
	}
	for(NdxIn=0;NdxIn<4;NdxIn++)
	{
		/* mxGetPr only yields valid doubles for real double arrays */
		if (!mxIsDouble(prhs[NdxIn]) || mxIsComplex(prhs[NdxIn]))
		{
			mexErrMsgIdAndTxt("SteeringMatrix3D:type",
				"GradX, GradY, GradZ and IsClean must be real double arrays.");
		}
		/* Every voxel index is used to read all four arrays */
		if (mxGetNumberOfElements(prhs[NdxIn])!=mxGetNumberOfElements(prhs[0]))
		{
			mexErrMsgIdAndTxt("SteeringMatrix3D:size",
				"GradX, GradY, GradZ and IsClean must have the same number of elements.");
		}
	}
	NumDims=mxGetNumberOfDimensions(prhs[0]);
	if (NumDims>3)
	{
		mexErrMsgIdAndTxt("SteeringMatrix3D:dims",
			"The gradient maps must be 3D (or 2D) arrays.");
	}

	/* Get input data */
	lambda=mxGetScalar(prhs[5]);
	GradX=mxGetPr(prhs[0]);
	GradY=mxGetPr(prhs[1]);
	GradZ=mxGetPr(prhs[2]);
	Memberships=mxGetPr(prhs[3]);

	/* Create output array: one 3x3 matrix per voxel */
	DimsInput=mxGetDimensions(prhs[0]);
	M=DimsInput[0];  /* Size in the X dimension of the input image */
	N=DimsInput[1];  /* Size in the Y dimension of the input image */
	/* MATLAB drops trailing singleton dimensions, so a single slice is 2D */
	P=(NumDims==3) ? DimsInput[2] : 1;  /* Size in the Z dimension */
	Dims[0]=3;
	Dims[1]=3;
	Dims[2]=M;
	Dims[3]=N;
	Dims[4]=P;
	plhs[0]=mxCreateNumericArray(5,Dims,mxDOUBLE_CLASS,mxREAL);
	C=mxGetPr(plhs[0]);

	NumVoxels=(size_t)M*N*P;
	if (NumVoxels==0)
	{
		return;  /* empty image: empty output, nothing to estimate */
	}

	/* Get working variables. Reject NaN/non-positive sizes and bound the
	   value before the int conversion (anything above 2*M+1 would fail the
	   radius check below anyway). */
	wsizeArg=mxGetScalar(prhs[4]);
	if (!(wsizeArg>=1.0) || wsizeArg>2.0*(double)M+1.0)
	{
		mexErrMsgIdAndTxt("SteeringMatrix3D:wsize",
			"wsize must be at least 1 and its half must be smaller than the image size.");
	}
	wsize=(int)wsizeArg;
	if (wsize%2==0)
	{
		wsize++;  /* the window needs a centre voxel */
	}
	radius = wsize/2;
	/* Must be computed AFTER the odd adjustment: it sizes G/U and must match
	   the number of values the window loop writes and the filter size. */
	wsize3=wsize*wsize*wsize;

	/* Mirror padding (Extend2D/Extend3D) reads valid data only when
	   radius < M, radius < N and radius <= P */
	if ((mwSize)radius>=M || (mwSize)radius>=N || (mwSize)radius>P)
	{
		mexErrMsgIdAndTxt("SteeringMatrix3D:wsize",
			"The window radius (wsize/2) is too large for the image size.");
	}

	/* Get the filter */
	My_prhs[0]=mxCreateDoubleMatrix(1,1,mxREAL);
	(*mxGetPr(My_prhs[0]))=(double)radius;
	mexCallMATLAB(1,My_plhs,1,My_prhs,"diskfilter3d");
	/* K is read with wsize3 entries (assumed laid out X fastest, like the
	   gathered window) - make sure it really has them */
	if (!mxIsDouble(My_plhs[0]) || mxIsComplex(My_plhs[0]) ||
		mxGetNumberOfElements(My_plhs[0])!=(size_t)wsize3)
	{
		mexErrMsgIdAndTxt("SteeringMatrix3D:filter",
			"diskfilter3d(radius) must return a real double array with wsize^3 elements.");
	}
	K=mxGetPr(My_plhs[0]);

	/* Prepare auxiliary arrays (sizes computed in size_t to avoid int overflow) */
	ExtM=M+2*(mwSize)radius;
	ExtN=N+2*(mwSize)radius;
	ExtP=P+2*(mwSize)radius;
	ExtSlice=(size_t)ExtM*ExtN;
	ExtBytes=ExtSlice*ExtP*sizeof(double);
	GradXExt=mxMalloc(ExtBytes);
	GradYExt=mxMalloc(ExtBytes);
	GradZExt=mxMalloc(ExtBytes);
	G=mxMalloc((size_t)wsize3*3*sizeof(double));  /* wsize3 x 3 window matrix */
	U=mxMalloc((size_t)wsize3*3*sizeof(double));  /* left SVD factor, unused below */

	/* Extend the gradient maps so border voxels get a full window */
	Extend3D(GradXExt,GradX,(int)M,(int)N,(int)P,radius);
	Extend3D(GradYExt,GradY,(int)M,(int)N,(int)P,radius);
	Extend3D(GradZExt,GradZ,(int)M,(int)N,(int)P,radius);

	/* Estimate the steering matrices */
	for(NdxZ=0;NdxZ<P;NdxZ++)
	{
		for(NdxY=0;NdxY<N;NdxY++)
		{
			for(NdxX=0;NdxX<M;NdxX++)
			{
				NdxVoxel=((size_t)NdxZ*N+NdxY)*M+NdxX;

				/* Get the window with the relevant gradient values. Each
				   memcpy copies one X-run of wsize values; NdxWin enumerates
				   the (y,z) offsets with y fastest, so G holds three
				   consecutive wsize^3 column-major blocks (X, Y, Z grads). */
				for(NdxWin=0;NdxWin<wsize*wsize;NdxWin++)
				{
					Start=(NdxZ+(size_t)(NdxWin/wsize))*ExtSlice
						+(NdxY+(size_t)(NdxWin%wsize))*ExtM+NdxX;
					memcpy(G+(size_t)NdxWin*wsize,
						GradXExt+Start,wsize*sizeof(double));
					memcpy(G+(size_t)NdxWin*wsize+wsize3,
						GradYExt+Start,wsize*sizeof(double));
					memcpy(G+(size_t)NdxWin*wsize+2*(size_t)wsize3,
						GradZExt+Start,wsize*sizeof(double));
				}

				/* Multiply by the filter */
				for(NdxWin=0;NdxWin<wsize3;NdxWin++)
				{
					G[NdxWin]*=K[NdxWin];
					G[wsize3+NdxWin]*=K[NdxWin];
					G[2*wsize3+NdxWin]*=K[NdxWin];
				}

				/* Compute Singular Value Decomposition of the wsize3 x 3 matrix */
				EconomySVD(G,S,U,V,wsize3,3);

				/* Regularize singular values */
				S1 = S[0] + lambda;
				S2 = S[1] + lambda;
				S3 = S[2] + lambda;

				/* Compute the steering matrix: sum of the three weighted
				   products of V's 3-element slices with themselves */
				MatrixProduct(V,V,AuxMat,3,1,3);
				ScalarMatrixProduct(S1,AuxMat,AuxMat,3,3);
				MatrixProduct(V+3,V+3,AuxMat2,3,1,3);
				ScalarMatrixProduct(S2,AuxMat2,AuxMat2,3,3);
				MatrixSum(AuxMat,AuxMat2,AuxMat,3,3);
				MatrixProduct(V+6,V+6,AuxMat2,3,1,3);
				ScalarMatrixProduct(S3,AuxMat2,AuxMat2,3,3);
				MatrixSum(AuxMat,AuxMat2,AuxMat,3,3);

				/* Apply adaptive weighting and write on output.
				   NOTE(review): a zero IsClean value yields Inf entries. */
				ScalarMatrixProduct(1.0/Memberships[NdxVoxel],AuxMat,
					C+9*NdxVoxel,3,3);
			}
		}
	}

	/* Release memory */
	mxDestroyArray(My_plhs[0]);
	mxDestroyArray(My_prhs[0]);
	mxFree(GradXExt);
	mxFree(GradYExt);
	mxFree(GradZExt);
	mxFree(G);
	mxFree(U);
}



/* Extend a 2D matrix (mirror extension at the edges)

   Copies the M x N column-major matrix Orig into the centre of Dest, an
   (M+2*radius) x (N+2*radius) column-major buffer, and fills a border of
   width `radius` on every side by reflecting about the outermost
   row/column without repeating it (padded index -k takes original index k,
   padded index L-1+k takes original index L-1-k).
   Valid results require radius <= M-1 and radius <= N-1. */
void Extend2D(double *Dest,double *Orig,int M, int N,int radius)
{
	const int Rows = M + 2*radius;   /* leading dimension of Dest */
	const int Cols = N + 2*radius;
	const size_t ColBytes = (size_t)M * sizeof(double);
	int j, k;

	/* Centre band: padded column j+radius receives original column j */
	for (j = 0; j < N; j++) {
		memcpy(Dest + (j + radius)*Rows + radius, Orig + j*M, ColBytes);
	}

	/* Left and right bands, mirrored about the first/last column */
	for (k = 1; k <= radius; k++) {
		memcpy(Dest + (radius - k)*Rows + radius, Orig + k*M, ColBytes);
		memcpy(Dest + (N - 1 + radius + k)*Rows + radius,
			Orig + (N - 1 - k)*M, ColBytes);
	}

	/* Top and bottom bands of every padded column; running over the
	   left/right bands as well fills the four corners */
	for (j = 0; j < Cols; j++) {
		double *col = Dest + (size_t)j*Rows;

		for (k = 1; k <= radius; k++) {
			col[radius - k] = col[radius + k];
		}
		for (k = 1; k <= radius; k++) {
			col[radius + M - 1 + k] = col[radius + M - 1 - k];
		}
	}
}


/* Extend a 3D matrix (mirror extension at the edges)

   Dest: buffer of (M+2*radius)*(N+2*radius)*(P+2*radius) doubles.
   Orig: M x N x P column-major volume (must not overlap Dest).

   Each original slice is padded in X/Y by Extend2D into padded slices
   radius..radius+P-1 of Dest; the first and last `radius` padded slices are
   then copied from the interior ones.

   NOTE(review): along Z the border slice itself is repeated (padded slice
   radius-1-k copies padded slice radius+k, i.e. index -1 -> 0), whereas
   Extend2D reflects X/Y without repeating the border (index -1 -> 1). Kept
   as-is to preserve existing results; confirm which convention is intended.
   Valid results require radius <= P in addition to Extend2D's limits.

   Slice offsets are computed in size_t so large volumes do not overflow
   int arithmetic. */
void Extend3D(double *Dest,double *Orig,int M,int N,int P,int radius)
{
	const size_t Slice=(size_t)(N+2*radius)*(size_t)(M+2*radius);
	const size_t SliceBytes=Slice*sizeof(double);
	int ndx;

	/* Extend the 2D slices of the original matrix */
	for(ndx=0;ndx<P;ndx++)
	{
		Extend2D(Dest+Slice*(size_t)(ndx+radius),
			Orig+(size_t)N*(size_t)M*(size_t)ndx,M,N,radius);
	}

	/* Top of the 3D matrix: padded slice ndx copies slice 2*radius-ndx-1 */
	for(ndx=0;ndx<radius;ndx++)
	{
		memcpy(Dest+Slice*(size_t)ndx,
			Dest+Slice*(size_t)(2*radius-ndx-1),
			SliceBytes);
	}

	/* Bottom of the 3D matrix: padded slice P+radius+ndx copies slice P+radius-ndx-1 */
	for(ndx=0;ndx<radius;ndx++)
	{
		memcpy(Dest+Slice*(size_t)(P+radius+ndx),
			Dest+Slice*(size_t)(P+radius-ndx-1),
			SliceBytes);
	}
}
