#include "mex.h"
#include "Mates.h"
#include <stdio.h>
#include <math.h>
#include <float.h>
#include <string.h>
#include <stdlib.h>

/* 
E. López-Rubio, M.N. Florentín-Núñez, Kernel regression based feature extraction for 3D MR image denoising,
Medical Image Analysis. DOI:10.1016/j.media.2011.02.006

Use the following commands to compile this MEX file at the Matlab prompt:
32-bit Windows:
mex SteeringKernelRegression3D.c MatesLap.c Debugging.c lapack.a blas.a libf2c.a
64-bit Windows:
mex LINKFLAGS="$LINKFLAGS /NODEFAULTLIB:LIBCMT" SteeringKernelRegression3D.c MatesLap.c Debugging.c BLAS_nowrap.lib libf2c.lib clapack_nowrap.lib

Usage:
[RecImg, GradX, GradY, GradZ] = SteeringKernelRegression3D(NoisyImg, IsClean, h, C, ksize);

Notes:
RecImg=Reconstructed image
GradX=Gradient in the X direction
GradY=Gradient in the Y direction
GradZ=Gradient in the Z direction
NoisyImg=Input (noisy) image
IsClean=Values of the membership function which indicates whether a given voxel is not corrupted
h=Kernel regression global smoothing parameter
C=steering matrices map
ksize=Kernel size


*/


/* Extend a 2D matrix (mirror extension at the edges) */
void Extend2D(double *Dest,double *Orig,int M, int N,int radius);

/* Extend a 3D matrix (mirror extension at the edges) */
void Extend3D(double *Dest,double *Orig,int M,int N,int P,int radius);

/*
 * MEX gateway for 3D steering kernel regression.
 *
 * MATLAB call:
 *   [RecImg, GradX, GradY, GradZ] = SteeringKernelRegression3D(NoisyImg, IsClean, h, C, ksize);
 *
 * Inputs (prhs):
 *   prhs[0] NoisyImg  real double array with at least 3 dimensions; only the
 *                     leading M x N x P volume is processed
 *   prhs[1] IsClean   real double array with at least M*N*P elements
 *                     (membership value of every voxel)
 *   prhs[2] h         numeric scalar, global smoothing parameter (nonzero)
 *   prhs[3] C         real double array with at least 9*M*N*P elements; the nine
 *                     entries belonging to voxel i are read from C[9*i .. 9*i+8]
 *   prhs[4] ksize     numeric scalar, odd positive side length of the cubic window
 *
 * Outputs (plhs), each a real double M x N x P array:
 *   plhs[0] RecImg, plhs[1] GradX, plhs[2] GradY, plhs[3] GradZ.
 *   Gradient outputs are handed back only when the caller requested them.
 *
 * Invalid arguments raise a MATLAB error (mexErrMsgIdAndTxt does not return).
 * Memory obtained with mxMalloc is released explicitly on the normal path.
 *
 * NOTE(review): Traspose, MatrixProduct, SumDiagonalConstant and Inverse come
 * from Mates.h, which is not visible here; the comments on those calls describe
 * the presumed semantics (column-major dense matrices) - confirm against Mates.h.
 */
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{
	double *Memberships,*NoisyImg,*RecImg,*GradX,*GradY,*GradZ,*C;
	double *Xx,*XxT,*Xw,*XwT,*XxTXw,*Inv,*A,*AT,*XCoord,*YCoord,*ZCoord,*W;
	double *NoisyImgExt,*MembershipsExt,*Window,*WindowMemb;
	double *WindowC11,*WindowC12,*WindowC13,*WindowC22,*WindowC23,*WindowC33;
	double *WindowSqrtDetC;
	double *C11,*C12,*C13,*C22,*C23,*C33;
	double *C11Ext,*C12Ext,*C13Ext,*C22Ext,*C23Ext,*C33Ext,*SqrtDetC,*SqrtDetCExt;
	double h,ArgExp;
	int radius,ksize,ksize3,M,N,P,NdxWin,Ndx,Cnt,NdxX,NdxY,NdxZ,OrigOffset,DestOffset;
	int NdxArg,NdxOut;
	size_t NumVox,NumVoxExt,VolBytes,ExtBytes,WinBytes,DesignBytes;
	/* mxGetDimensions returns mwSize, whose width depends on the MEX API in use */
	const mwSize *DimsInput;
	mxArray *GradOut[3];
	static const int ArrayArgs[3]={0,1,3};

	/* ---- Validate the call before touching any argument ---- */
	if(nrhs!=5)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:nrhs",
			"Five inputs are required: NoisyImg, IsClean, h, C, ksize.");
	}
	if(nlhs>4)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:nlhs",
			"At most four outputs are produced: RecImg, GradX, GradY, GradZ.");
	}
	for(NdxArg=0;NdxArg<3;NdxArg++)
	{
		/* The array inputs are accessed through mxGetPr as real doubles */
		if(!mxIsDouble(prhs[ArrayArgs[NdxArg]]) || mxIsComplex(prhs[ArrayArgs[NdxArg]]))
		{
			mexErrMsgIdAndTxt("SteeringKernelRegression3D:notRealDouble",
				"NoisyImg, IsClean and C must be real double arrays.");
		}
	}
	if(mxIsEmpty(prhs[2]) || !mxIsNumeric(prhs[2]) ||
		mxIsEmpty(prhs[4]) || !mxIsNumeric(prhs[4]))
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:notScalar",
			"h and ksize must be numeric scalars.");
	}
	/* MATLAB drops trailing singleton dimensions, so a 2-D input has no third size */
	if(mxGetNumberOfDimensions(prhs[0])<3)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:not3D",
			"NoisyImg must be a three-dimensional array.");
	}

	/* Get input data */
	h=mxGetScalar(prhs[2]);
	ksize=(int)mxGetScalar(prhs[4]);
	NoisyImg=mxGetPr(prhs[0]);
	Memberships=mxGetPr(prhs[1]);
	C=mxGetPr(prhs[3]);

	DimsInput=mxGetDimensions(prhs[0]);
	M=(int)DimsInput[0];  /* Size in the X dimension of the input image */
	N=(int)DimsInput[1];  /* Size in the Y dimension of the input image */
	P=(int)DimsInput[2];  /* Size in the Z dimension of the input image */

	/* h appears as 0.5/(h*h) in the kernel exponent */
	if(h==0.0 || mxIsNaN(h))
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:badH",
			"h must be a nonzero number.");
	}
	/* The padding is 2*radius with radius=(ksize-1)/2, so an even ksize would
	   make every window row run one sample past the padded volume */
	if(ksize<1 || (ksize%2)==0)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:badKsize",
			"ksize must be a positive odd integer.");
	}
	ksize3=ksize*ksize*ksize;
	radius = (ksize-1)/2;

	/* The mirror extension takes its border from rows/columns 1..radius and
	   from the first/last radius slices of the original volume */
	if(radius>=M || radius>=N || radius>P)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:imageTooSmall",
			"NoisyImg is too small for the requested kernel size.");
	}

	NumVox=(size_t)M*(size_t)N*(size_t)P;
	NumVoxExt=(size_t)(N+2*radius)*(size_t)(M+2*radius)*(size_t)(P+2*radius);
	if((size_t)mxGetNumberOfElements(prhs[1])<NumVox)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:badIsClean",
			"IsClean must have one element per voxel of NoisyImg.");
	}
	if((size_t)mxGetNumberOfElements(prhs[3])<9*NumVox)
	{
		mexErrMsgIdAndTxt("SteeringKernelRegression3D:badC",
			"C must hold nine elements per voxel of NoisyImg.");
	}

	/* Create output arrays; the gradients are always computed, but they are
	   only stored in plhs when the caller asked for them */
	plhs[0]=mxCreateNumericArray(3,DimsInput,mxDOUBLE_CLASS,mxREAL);
	RecImg=mxGetPr(plhs[0]);
	for(NdxOut=0;NdxOut<3;NdxOut++)
	{
		GradOut[NdxOut]=mxCreateNumericArray(3,DimsInput,mxDOUBLE_CLASS,mxREAL);
		if(NdxOut+1<nlhs)
		{
			plhs[NdxOut+1]=GradOut[NdxOut];
		}
	}
	GradX=mxGetPr(GradOut[0]);
	GradY=mxGetPr(GradOut[1]);
	GradZ=mxGetPr(GradOut[2]);

	/* Prepare auxiliary arrays */
	VolBytes=NumVox*sizeof(double);            /* one value per voxel */
	ExtBytes=NumVoxExt*sizeof(double);         /* one value per padded voxel */
	WinBytes=(size_t)ksize3*sizeof(double);    /* one value per window sample */
	DesignBytes=10*WinBytes;                   /* ksize3 x 10 (or 10 x ksize3) */

	Xx=mxMalloc(DesignBytes);
	XxT=mxMalloc(DesignBytes);
	Xw=mxMalloc(DesignBytes);
	XwT=mxMalloc(DesignBytes);
	XxTXw=mxMalloc(10*10*sizeof(double));
	Inv=mxMalloc(10*10*sizeof(double));
	A=mxMalloc(DesignBytes);
	AT=mxMalloc(DesignBytes);
	XCoord=mxMalloc(WinBytes);
	YCoord=mxMalloc(WinBytes);
	ZCoord=mxMalloc(WinBytes);
	W=mxMalloc(WinBytes);
	NoisyImgExt=mxMalloc(ExtBytes);
	MembershipsExt=mxMalloc(ExtBytes);
	C11=mxMalloc(VolBytes);
	C12=mxMalloc(VolBytes);
	C13=mxMalloc(VolBytes);
	C22=mxMalloc(VolBytes);
	C23=mxMalloc(VolBytes);
	C33=mxMalloc(VolBytes);
	SqrtDetC=mxMalloc(VolBytes);
	C11Ext=mxMalloc(ExtBytes);
	C12Ext=mxMalloc(ExtBytes);
	C13Ext=mxMalloc(ExtBytes);
	C22Ext=mxMalloc(ExtBytes);
	C23Ext=mxMalloc(ExtBytes);
	C33Ext=mxMalloc(ExtBytes);
	SqrtDetCExt=mxMalloc(ExtBytes);
	Window=mxMalloc(WinBytes);
	WindowMemb=mxMalloc(WinBytes);
	WindowC11=mxMalloc(WinBytes);
	WindowC12=mxMalloc(WinBytes);
	WindowC13=mxMalloc(WinBytes);
	WindowC22=mxMalloc(WinBytes);
	WindowC23=mxMalloc(WinBytes);
	WindowC33=mxMalloc(WinBytes);
	WindowSqrtDetC=mxMalloc(WinBytes);


	/* Prepare window coordinates: offsets in [-radius, radius] relative to the
	   window centre, with X varying fastest, then Y, then Z */
	for(NdxWin=0;NdxWin<ksize3;NdxWin++)
	{
		XCoord[NdxWin]=(NdxWin%ksize)-radius;
		YCoord[NdxWin]=((NdxWin/ksize)%ksize)-radius;
		ZCoord[NdxWin]=(NdxWin/(ksize*ksize))-radius;
	}



	/* Prepare feature matrix and its transpose. Each block of ksize3 entries
	   holds one monomial of the second-order expansion:
	   1, x, y, z, x^2, xy, xz, y^2, yz, z^2 */
	for(NdxWin=0;NdxWin<ksize3;NdxWin++)
	{
		Xx[NdxWin]=1.0;
		Xx[ksize3+NdxWin]=XCoord[NdxWin];
		Xx[2*ksize3+NdxWin]=YCoord[NdxWin];
		Xx[3*ksize3+NdxWin]=ZCoord[NdxWin];
		Xx[4*ksize3+NdxWin]=XCoord[NdxWin]*XCoord[NdxWin];
		Xx[5*ksize3+NdxWin]=XCoord[NdxWin]*YCoord[NdxWin];
		Xx[6*ksize3+NdxWin]=XCoord[NdxWin]*ZCoord[NdxWin];
		Xx[7*ksize3+NdxWin]=YCoord[NdxWin]*YCoord[NdxWin];
		Xx[8*ksize3+NdxWin]=YCoord[NdxWin]*ZCoord[NdxWin];
		Xx[9*ksize3+NdxWin]=ZCoord[NdxWin]*ZCoord[NdxWin];
	}
	Traspose(Xx,XxT,ksize3,10);

	/* Prepare steering matrices: split the nine stored entries of every voxel
	   into the six distinct ones that are used (indices 0,1,2,4,5,8) and
	   precompute the square root of the determinant of the symmetric 3x3
	   matrix they describe */
	for(Ndx=0;Ndx<N*M*P;Ndx++)
	{
		C11[Ndx]=C[9*Ndx];
		C12[Ndx]=C[9*Ndx+1];
		C13[Ndx]=C[9*Ndx+2];
		C22[Ndx]=C[9*Ndx+4];
		C23[Ndx]=C[9*Ndx+5];
		C33[Ndx]=C[9*Ndx+8];
		SqrtDetC[Ndx]=sqrt(C[9*Ndx]*C[9*Ndx+4]*C[9*Ndx+8]+
			2.0*C[9*Ndx+1]*C[9*Ndx+5]*C[9*Ndx+2]-
			C[9*Ndx+2]*C[9*Ndx+4]*C[9*Ndx+2]-
			C[9*Ndx]*C[9*Ndx+5]*C[9*Ndx+5]-
			C[9*Ndx+8]*C[9*Ndx+1]*C[9*Ndx+1]);
	}

	/* Extend the noisy image, the memberships map and the maps for steering matrices */
	Extend3D(NoisyImgExt,NoisyImg,M,N,P,radius);
	Extend3D(MembershipsExt,Memberships,M,N,P,radius);
	Extend3D(C11Ext,C11,M,N,P,radius);
	Extend3D(C12Ext,C12,M,N,P,radius);
	Extend3D(C13Ext,C13,M,N,P,radius);
	Extend3D(C22Ext,C22,M,N,P,radius);
	Extend3D(C23Ext,C23,M,N,P,radius);
	Extend3D(C33Ext,C33,M,N,P,radius);
	Extend3D(SqrtDetCExt,SqrtDetC,M,N,P,radius);

	/*---------*/

	/* Estimate the original image and the gradients. The loops visit the voxels
	   in storage order (X fastest), matching the output pointers that are
	   advanced by one element per voxel */
	for(NdxZ=0;NdxZ<P;NdxZ++)
	{
		for(NdxY=0;NdxY<N;NdxY++)
		{
			for(NdxX=0;NdxX<M;NdxX++)
			{
				/* Get the window with the relevant noisy samples and the rest of
				   parameters. Each iteration copies one run of ksize samples
				   along X; NdxWin%ksize selects the Y offset and NdxWin/ksize
				   the Z offset inside the window */
				for(NdxWin=0;NdxWin<ksize*ksize;NdxWin++)
				{
					DestOffset=NdxWin*ksize;
					OrigOffset=(NdxZ+(NdxWin/ksize))*(M+2*radius)*(N+2*radius)
								+(NdxY+(NdxWin%ksize))*(M+2*radius)+NdxX;
					memcpy(Window+DestOffset,NoisyImgExt+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowC11+DestOffset,C11Ext+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowC12+DestOffset,C12Ext+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowC13+DestOffset,C13Ext+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowC22+DestOffset,C22Ext+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowC23+DestOffset,C23Ext+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowC33+DestOffset,C33Ext+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowSqrtDetC+DestOffset,SqrtDetCExt+OrigOffset,
							ksize*sizeof(double));
					memcpy(WindowMemb+DestOffset,MembershipsExt+OrigOffset,
							ksize*sizeof(double));
				}

				/* Prepare weight matrix: W first receives the quadratic form
				   d^T * C * d, where d=(x,y,z) is the sample offset and C the
				   steering matrix stored for that sample */
				for(NdxWin=0;NdxWin<ksize3;NdxWin++)
				{
					W[NdxWin]=
						Xx[ksize3+NdxWin]*
						(WindowC11[NdxWin]*Xx[ksize3+NdxWin]+WindowC12[NdxWin]*Xx[2*ksize3+NdxWin]+
						WindowC13[NdxWin]*Xx[3*ksize3+NdxWin])

						+Xx[2*ksize3+NdxWin]*
						(WindowC12[NdxWin]*Xx[ksize3+NdxWin]+WindowC22[NdxWin]*Xx[2*ksize3+NdxWin]+
						WindowC23[NdxWin]*Xx[3*ksize3+NdxWin])

						+Xx[3*ksize3+NdxWin]*
						(WindowC13[NdxWin]*Xx[ksize3+NdxWin]+WindowC23[NdxWin]*Xx[2*ksize3+NdxWin]+
						WindowC33[NdxWin]*Xx[3*ksize3+NdxWin]);

					/* Avoid evaluation of exp() if the result is zero */
					ArgExp=(0.5/(h*h)) * W[NdxWin];
					if (ArgExp>750.0)
					{
						W[NdxWin]=0.0;
					}
					else
					{
						W[NdxWin]=exp(-ArgExp)*WindowSqrtDetC[NdxWin];
					}
				}

				/* The weights multiplied by the responsibilities */
				for(NdxWin=0;NdxWin<ksize3;NdxWin++)
				{
					WindowMemb[NdxWin]*=W[NdxWin];
				}


				/* Equivalent kernel and its transpose: every one of the ten
				   feature blocks of Xx is scaled sample-wise by the weights */
				Ndx=0;
				for(Cnt=0;Cnt<10;Cnt++)
				{
					for(NdxWin=0;NdxWin<ksize3;NdxWin++)
					{
						Xw[Ndx]=Xx[Ndx]*WindowMemb[NdxWin];
						Ndx++;
					}
				}
				Traspose(Xw,XwT,ksize3,10);

				/* XxTXw = Xx^T * Xw */
				MatrixProduct(XxT,Xw,XxTXw,10,ksize3,10);
				/* Add 0.00001 to the elements of the diagonal */
				SumDiagonalConstant(XxTXw,0.00001,NULL,10);
				/* Invert it */
				Inverse(XxTXw,Inv,10);

				/* A=Inv*(Xw^T)*/
				MatrixProduct(Inv,XwT,A,10,10,ksize3);

				Traspose(A,AT,10,ksize3);

				/* Estimate the original image and the gradients at this position:
				   the first four blocks of AT are applied to the window samples */
				(*RecImg)=0.0;
				(*GradX)=0.0;
				(*GradY)=0.0;
				(*GradZ)=0.0;
				for(NdxWin=0;NdxWin<ksize3;NdxWin++)
				{
					(*RecImg)+=AT[NdxWin]*Window[NdxWin];
					(*GradX)+=AT[ksize3+NdxWin]*Window[NdxWin];
					(*GradY)+=AT[2*ksize3+NdxWin]*Window[NdxWin];
					(*GradZ)+=AT[3*ksize3+NdxWin]*Window[NdxWin];
				}



				/* Advance output pointers */
				RecImg++;
				GradX++;
				GradY++;
				GradZ++;
			}
		}

	}

	/* Release memory */
	mxFree(Xx);
	mxFree(XxT);
	mxFree(Xw);
	mxFree(XwT);
	mxFree(XxTXw);
	mxFree(Inv);
	mxFree(A);
	mxFree(AT);
	mxFree(XCoord);
	mxFree(YCoord);
	mxFree(ZCoord);
	mxFree(W);
	mxFree(NoisyImgExt);
	mxFree(MembershipsExt);
	mxFree(Window);
	mxFree(WindowMemb);
	mxFree(C11);
	mxFree(C12);
	mxFree(C13);
	mxFree(C22);
	mxFree(C23);
	mxFree(C33);
	mxFree(SqrtDetC);
	mxFree(C11Ext);
	mxFree(C12Ext);
	mxFree(C13Ext);
	mxFree(C22Ext);
	mxFree(C23Ext);
	mxFree(C33Ext);
	mxFree(SqrtDetCExt);
	mxFree(WindowC11);
	mxFree(WindowC12);
	mxFree(WindowC13);
	mxFree(WindowC22);
	mxFree(WindowC23);
	mxFree(WindowC33);
	mxFree(WindowSqrtDetC);

	/* Discard the gradient arrays that were not handed to the caller */
	for(NdxOut=0;NdxOut<3;NdxOut++)
	{
		if(NdxOut+1>=nlhs)
		{
			mxDestroyArray(GradOut[NdxOut]);
		}
	}

}    



/*
 * Extend a 2D matrix (mirror extension at the edges).
 *
 * Orig holds N consecutive columns of M doubles each. Dest receives the padded
 * matrix: N+2*radius consecutive columns of M+2*radius doubles, with the
 * original data placed radius rows and radius columns away from the origin.
 * The border is filled by reflecting the data about the first/last row and
 * column; the boundary row/column itself is not duplicated.
 *
 * The caller provides Dest with room for (M+2*radius)*(N+2*radius) doubles.
 * The reflection reads columns 1..radius and N-1-radius..N-2 of Orig and the
 * matching rows, so radius must be smaller than both M and N.
 */
void Extend2D(double *Dest,double *Orig,int M, int N,int radius)
{
	int Stride,Col,Row;
	double *Column;

	/* Number of rows (column length) of the extended matrix */
	Stride=M+2*radius;

	/* Interior: original column Col lands in extended column Col+radius */
	for(Col=0;Col<N;Col++)
	{
		memcpy(Dest+radius*Stride+radius+Col*Stride,Orig+Col*M,
			M*sizeof(double));
	}

	/* Side borders: rows radius..radius+M-1 of the mirrored columns */
	for(Col=0;Col<radius;Col++)
	{
		/* Left: extended column Col mirrors original column radius-Col */
		memcpy(Dest+radius+Col*Stride,Orig+(radius-Col)*M,
			M*sizeof(double));
	}
	for(Col=0;Col<radius;Col++)
	{
		/* Right: extended column N+radius+Col mirrors original column N-2-Col */
		memcpy(Dest+(N+radius)*Stride+radius+Col*Stride,Orig+(N-2-Col)*M,
			M*sizeof(double));
	}

	/* Top and bottom borders, one extended column at a time. The side columns
	   were filled above, so this pass also completes the four corners. */
	for(Col=0;Col<N+2*radius;Col++)
	{
		Column=Dest+Col*Stride;

		/* Top: row Row mirrors row 2*radius-Row */
		for(Row=0;Row<radius;Row++)
		{
			Column[Row]=Column[2*radius-Row];
		}

		/* Bottom: row radius+M+Row mirrors row M+radius-2-Row */
		for(Row=0;Row<radius;Row++)
		{
			Column[radius+M+Row]=Column[M+radius-2-Row];
		}
	}

}


/*
 * Extend a 3D matrix (mirror extension at the edges).
 *
 * Orig holds P consecutive slices of N*M doubles. Every slice is padded with
 * Extend2D and stored radius slices into Dest; the first and last radius
 * slices of Dest are then filled with copies of already padded slices.
 *
 * The caller provides Dest with room for
 * (M+2*radius)*(N+2*radius)*(P+2*radius) doubles.
 *
 * NOTE(review): along this third axis the boundary slice itself is repeated
 * (slice radius-1 is a copy of slice radius, slice P+radius a copy of slice
 * P+radius-1), whereas Extend2D reflects without repeating the boundary
 * row/column - confirm that the differing convention is intended.
 */
void Extend3D(double *Dest,double *Orig,int M,int N,int P,int radius)
{
	int SliceLen,Slice;
	size_t SliceBytes;

	/* Number of doubles in one padded slice, and its size in bytes */
	SliceLen=(N+2*radius)*(M+2*radius);
	SliceBytes=SliceLen*sizeof(double);

	/* Pad every original slice in place at its final position */
	for(Slice=0;Slice<P;Slice++)
	{
		Extend2D(Dest+SliceLen*(Slice+radius),Orig+N*M*Slice,M,N,radius);
	}

	/* Leading border: slice Slice is a copy of slice 2*radius-Slice-1 */
	for(Slice=0;Slice<radius;Slice++)
	{
		memcpy(Dest+SliceLen*Slice,
			Dest+SliceLen*(2*radius-Slice-1),
			SliceBytes);
	}

	/* Trailing border: slice P+radius+Slice is a copy of slice P+radius-Slice-1 */
	for(Slice=0;Slice<radius;Slice++)
	{
		memcpy(Dest+SliceLen*(P+radius+Slice),
			Dest+SliceLen*(P+radius-Slice-1),
			SliceBytes);
	}
}
