#include "mex.h"
#include "Mates.h"
#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* 

E. López-Rubio, M.N. Florentín-Núñez, Kernel regression based feature extraction for 3D MR image denoising,
Medical Image Analysis. DOI:10.1016/j.media.2011.02.006

Use the following commands to compile this MEX file at the Matlab prompt:
32-bit Windows:
mex ZerothOrderKernelRegression3D.c MatesLap.c Debugging.c lapack.a blas.a libf2c.a
64-bit Windows:
mex LINKFLAGS="$LINKFLAGS /NODEFAULTLIB:LIBCMT" ZerothOrderKernelRegression3D.c MatesLap.c Debugging.c BLAS_nowrap.lib libf2c.lib clapack_nowrap.lib

Usage:
[RecImg] = ZerothOrderKernelRegression3D(NoisyImg,SmoothImg,GradX,GradY,GradZ,h,ksize);

Notes:
RecImg=Image reconstructed by zeroth order kernel regression
NoisyImg=Input (noisy) image
SmoothImg=Smoothed version of the input (noisy) image
GradX=Gradient in the X direction
GradY=Gradient in the Y direction
GradZ=Gradient in the Z direction
h=Smoothing parameter
ksize=Side of the search window


*/


/* Extend a 2D matrix (mirror extension at the edges).
   Orig holds M x N values stored column by column (M values per column);
   Dest must have room for (M+2*radius) x (N+2*radius) values. */
void Extend2D(double *Dest,double *Orig,int M, int N,int radius);

/* Extend a 3D matrix (mirror extension at the edges).
   Orig holds P consecutive M x N slices; Dest must have room for
   (M+2*radius) x (N+2*radius) x (P+2*radius) values. */
void Extend3D(double *Dest,double *Orig,int M,int N,int P,int radius);


/*
 * MEX gateway: zeroth order kernel regression of a 3D image.
 *
 * Matlab call:
 *   RecImg = ZerothOrderKernelRegression3D(NoisyImg,SmoothImg,GradX,GradY,GradZ,h,ksize);
 *
 * Inputs (prhs):
 *   prhs[0..4]  NoisyImg, SmoothImg, GradX, GradY, GradZ: real double 3D
 *               arrays, all of the same size.
 *   prhs[5]     h: smoothing parameter, numeric scalar, must be positive.
 *   prhs[6]     ksize: side of the cubic search window, numeric scalar, must
 *               be an odd positive integer small enough for the mirror
 *               extension (radius=(ksize-1)/2 must satisfy radius<M,
 *               radius<N and radius<=P).
 * Output (plhs):
 *   plhs[0]     RecImg: real double array with the size of NoisyImg.
 *
 * For every voxel the feature vector (GradX,GradY,GradZ,SmoothImg) is compared
 * with the feature vectors of the ksize^3 voxels of the search window taken
 * from the mirror-extended volumes. Each neighbour gets the weight
 * exp(-d2/h), where d2 is the value returned by SquaredNorm() for the feature
 * difference; the central voxel gets the largest weight found among its
 * neighbours. The output is the weighted mean of the noisy values of the
 * window. If every weight is zero the noisy value of the voxel is copied
 * unchanged instead of dividing by zero.
 *
 * Invalid arguments are reported through mexErrMsgIdAndTxt(), which does not
 * return to this function.
 */
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{  
	double *SmoothImg,*NoisyImg,*RecImg,*GradX,*GradY,*GradZ;
	double *Window,*WindowT,*WindowNoisy,*Weights;
	double *NoisyImgExt,*SmoothImgExt,*GradXExt,*GradYExt,*GradZExt;
	double ArgExp,MaxWeight,SumWeights,h,MySquaredNorm,ThisPixel[4],VectorDiff[4];
	double ksizeArg,ExtElems;
	int radius,ksize,ksize3,M,N,P,NdxZ,NdxY,NdxX,OrigOffset,DestOffset;
	int NdxCenter,ThisNdx,NdxWin,NdxArg;
	size_t ExtBytes;
	const mwSize *Dims,*ArgDims;

	/* Check the number of arguments */
	if (nrhs!=7)
	{
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:nrhs",
			"Seven inputs required: NoisyImg,SmoothImg,GradX,GradY,GradZ,h,ksize.");
	}
	if (nlhs>1)
	{
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:nlhs",
			"Too many output arguments.");
	}

	/* The five image arguments are read through mxGetPr() and indexed as 3D
	   volumes, so they must be real double arrays with three dimensions */
	for(NdxArg=0;NdxArg<5;NdxArg++)
	{
		if (!mxIsDouble(prhs[NdxArg]) || mxIsComplex(prhs[NdxArg])
			|| mxGetNumberOfDimensions(prhs[NdxArg])!=3)
		{
			mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:invalidImage",
				"NoisyImg, SmoothImg, GradX, GradY and GradZ must be real double 3D arrays.");
		}
	}

	/* All the images must have the size of the noisy image */
	Dims=mxGetDimensions(prhs[0]);
	for(NdxArg=1;NdxArg<5;NdxArg++)
	{
		ArgDims=mxGetDimensions(prhs[NdxArg]);
		if (ArgDims[0]!=Dims[0] || ArgDims[1]!=Dims[1] || ArgDims[2]!=Dims[2])
		{
			mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:sizeMismatch",
				"NoisyImg, SmoothImg, GradX, GradY and GradZ must have the same size.");
		}
	}

	/* The dimensions are stored as mwSize, which may be wider than int */
	if (Dims[0]>(mwSize)INT_MAX || Dims[1]>(mwSize)INT_MAX || Dims[2]>(mwSize)INT_MAX)
	{
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:tooLarge",
			"The input image is too large.");
	}
	M=(int)Dims[0];  /* Size in the X dimension of the input image */
	N=(int)Dims[1];  /* Size in the Y dimension of the input image */
	P=(int)Dims[2];  /* Size in the Z dimension of the input image */

	/* Check the scalar parameters */
	for(NdxArg=5;NdxArg<7;NdxArg++)
	{
		if (!mxIsNumeric(prhs[NdxArg]) || mxGetNumberOfElements(prhs[NdxArg])!=1)
		{
			mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:invalidScalar",
				"h and ksize must be numeric scalars.");
		}
	}
	h=mxGetScalar(prhs[5]);
	if (!(h>0.0))
	{
		/* The weights are exp(-d2/h): h<=0 (or NaN) makes them meaningless */
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:invalidH",
			"h must be positive.");
	}
	ksizeArg=mxGetScalar(prhs[6]);
	if (!(ksizeArg>=1.0) || ksizeArg>(double)INT_MAX || fmod(ksizeArg,2.0)!=1.0)
	{
		/* An even side would make the window one sample wider than the
		   extended volumes allow (radius is (ksize-1)/2) */
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:invalidKsize",
			"ksize must be an odd positive integer.");
	}

	/* Create output array */
	plhs[0]=mxCreateNumericArray(3,Dims,mxDOUBLE_CLASS,mxREAL); 
	RecImg=mxGetPr(plhs[0]);	
	if (M==0 || N==0 || P==0)
	{
		/* Empty input: the empty output is already complete */
		return;
	}

	/* The mirror extension reads up to radius samples away from the edge
	   sample inside each slice, and up to radius slices including the edge
	   slice across slices */
	if (ksizeArg>2.0*M-1.0 || ksizeArg>2.0*N-1.0 || ksizeArg>2.0*P+1.0)
	{
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:ksizeTooLarge",
			"ksize is too large for the size of the input image.");
	}

	/* All the offsets below are computed with int arithmetic, so the number
	   of elements of the extended volumes must fit in an int */
	ExtElems=((double)M+ksizeArg-1.0)*((double)N+ksizeArg-1.0)*((double)P+ksizeArg-1.0);
	if (ExtElems>(double)INT_MAX)
	{
		mexErrMsgIdAndTxt("ZerothOrderKernelRegression3D:tooLarge",
			"The input image is too large.");
	}

	/* Get working variables */
	ksize=(int)ksizeArg;
	ksize3=ksize*ksize*ksize;
	radius = (ksize-1)/2;

	/* Get input data */
	NoisyImg=mxGetPr(prhs[0]);
	SmoothImg=mxGetPr(prhs[1]);
	GradX=mxGetPr(prhs[2]);
	GradY=mxGetPr(prhs[3]);
	GradZ=mxGetPr(prhs[4]);

	/* Prepare auxiliary arrays (sizes computed as size_t) */
	ExtBytes=(size_t)(M+2*radius)*(N+2*radius)*(P+2*radius)*sizeof(double);
	NoisyImgExt=mxMalloc(ExtBytes);
	SmoothImgExt=mxMalloc(ExtBytes);
	GradXExt=mxMalloc(ExtBytes);
	GradYExt=mxMalloc(ExtBytes);
	GradZExt=mxMalloc(ExtBytes);
	WindowT=mxMalloc((size_t)ksize3*4*sizeof(double));
	Window=mxMalloc((size_t)ksize3*4*sizeof(double));
	WindowNoisy=mxMalloc((size_t)ksize3*sizeof(double));
	Weights=mxMalloc((size_t)ksize3*sizeof(double));


	/* Extend the noisy image, the smoothed image and the gradient images */
	Extend3D(NoisyImgExt,NoisyImg,M,N,P,radius);	
	Extend3D(SmoothImgExt,SmoothImg,M,N,P,radius);
	Extend3D(GradXExt,GradX,M,N,P,radius);
	Extend3D(GradYExt,GradY,M,N,P,radius);
	Extend3D(GradZExt,GradZ,M,N,P,radius);

	/* Estimate the original image */
	for(NdxZ=0;NdxZ<P;NdxZ++)
	{
		for(NdxY=0;NdxY<N;NdxY++)
		{
			for(NdxX=0;NdxX<M;NdxX++)
			{
				/* Get the features of the current pixel */
				ThisNdx=NdxZ*M*N+NdxY*M+NdxX;
				ThisPixel[0]=GradX[ThisNdx];
				ThisPixel[1]=GradY[ThisNdx];
				ThisPixel[2]=GradZ[ThisNdx];
				ThisPixel[3]=SmoothImg[ThisNdx];

				/* Get the window with the features relevant for weight computation (Window), and
				 the window with the noisy values from the input image (WindowNoisy).
				 Each iteration copies one run of ksize consecutive X samples; NdxWin/ksize is
				 the Z offset and NdxWin%ksize the Y offset inside the window. WindowT stores
				 the four features one after another, ksize3 values each. */
				for(NdxWin=0;NdxWin<ksize*ksize;NdxWin++)
				{
					OrigOffset=(NdxZ+(NdxWin/ksize))*(M+2*radius)*(N+2*radius)
							+(NdxY+(NdxWin%ksize))*(M+2*radius)+NdxX;
					DestOffset=NdxWin*ksize;
					memcpy(WindowT+DestOffset,GradXExt+OrigOffset,
						ksize*sizeof(double));
					memcpy(WindowT+DestOffset+ksize3,GradYExt+OrigOffset,
						ksize*sizeof(double));
					memcpy(WindowT+DestOffset+2*ksize3,GradZExt+OrigOffset,
						ksize*sizeof(double));
					memcpy(WindowT+DestOffset+3*ksize3,SmoothImgExt+OrigOffset,
						ksize*sizeof(double));
					memcpy(WindowNoisy+DestOffset,NoisyImgExt+OrigOffset,
						ksize*sizeof(double));
				}
				/* Traspose() is declared elsewhere (Mates.h); presumably it rearranges
				 WindowT so that the four features of window pixel i are stored
				 contiguously at Window+4*i, which is how Window is read below */
				Traspose(WindowT,Window,ksize3,4);
				
				/* Compute the squared Euclidean distances among the feature vectors of the current pixel and
				those of the search window pixels, and the corresponding weights */
				MaxWeight=0.0;
				NdxCenter=ksize3/2;
				SumWeights=0.0;
				for(NdxWin=0;NdxWin<ksize3;NdxWin++)
				{
					if (NdxWin!=NdxCenter)
					{
						/* Compute squared norm */
						Difference(ThisPixel,Window+4*NdxWin,VectorDiff,4);
						SquaredNorm(VectorDiff,&MySquaredNorm,4);
						/* Compute weight. Avoid evaluation of exp() if the result is zero */
						ArgExp=MySquaredNorm/h;
						if (ArgExp>750.0)
						{
							Weights[NdxWin]=0.0;
						}
						else
						{
							Weights[NdxWin]=exp(-ArgExp);
							SumWeights+=Weights[NdxWin];
							if (Weights[NdxWin]>MaxWeight)
							{
								MaxWeight=Weights[NdxWin];
							}
						}
					}
				}
				/* The central pixel gets the largest weight of its neighbours */
				Weights[NdxCenter]=MaxWeight;
				SumWeights+=MaxWeight;
				
				if (SumWeights==0.0)
				{
					/* Every weight is zero, so there is nothing to average and the
					 normalization would divide by zero: keep the noisy value */
					RecImg[ThisNdx]=NoisyImg[ThisNdx];
				}
				else
				{
					/* Normalize weights */
					for(NdxWin=0;NdxWin<ksize3;NdxWin++)
					{
						Weights[NdxWin]/=SumWeights;
					}

					/* Multiply the noisy values from the input image by the normalized weights */
					RecImg[ThisNdx]=0.0;
					for(NdxWin=0;NdxWin<ksize3;NdxWin++)
					{
						RecImg[ThisNdx]+=Weights[NdxWin]*WindowNoisy[NdxWin];
					}
				}

			}

		}

	}

	/* Release memory */
	mxFree(NoisyImgExt);
	mxFree(SmoothImgExt);
	mxFree(GradXExt);
	mxFree(GradYExt);
	mxFree(GradZExt);
	mxFree(WindowT);
	mxFree(Window);
	mxFree(WindowNoisy);
	mxFree(Weights);
    
}    



/*
 * Extend a 2D matrix (mirror extension at the edges).
 *
 * Orig holds an M x N matrix stored column by column (M consecutive values
 * per column). Dest receives the (M+2*radius) x (N+2*radius) extended matrix
 * in the same layout: the original values in the middle, surrounded by a
 * border of width radius. The border is a reflection around the edge sample,
 * i.e. the edge sample itself is not repeated.
 */
void Extend2D(double *Dest,double *Orig,int M, int N,int radius)
{
	const int Stride=M+2*radius;   /* length of one column of Dest */
	int Col,Row,Back;
	double *DestCol;

	/* Central part: every original column goes radius columns to the right
	   and radius rows down */
	for(Col=0;Col<N;Col++)
	{
		memcpy(Dest+(Col+radius)*Stride+radius,Orig+Col*M,M*sizeof(double));
	}

	/* Left and right borders: whole columns mirrored from the original.
	   Left border column Col takes original column radius-Col; right border
	   column N+radius+Col takes original column N-2-Col */
	for(Col=0;Col<radius;Col++)
	{
		memcpy(Dest+Col*Stride+radius,Orig+(radius-Col)*M,M*sizeof(double));
		memcpy(Dest+(N+radius+Col)*Stride+radius,Orig+(N-2-Col)*M,M*sizeof(double));
	}

	/* Top border: done on Dest so that the left and right borders are
	   extended as well */
	for(Col=0;Col<N+2*radius;Col++)
	{
		DestCol=Dest+Col*Stride;
		for(Row=0;Row<radius;Row++)
		{
			DestCol[Row]=DestCol[2*radius-Row];
		}
	}

	/* Bottom border: row radius+M+Back mirrors row M+radius-2-Back */
	for(Col=0;Col<N+2*radius;Col++)
	{
		DestCol=Dest+Col*Stride;
		for(Back=0;Back<radius;Back++)
		{
			DestCol[radius+M+Back]=DestCol[M+radius-2-Back];
		}
	}

}


/*
 * Extend a 3D matrix (mirror extension at the edges).
 *
 * Orig holds P consecutive M x N slices. Dest receives P+2*radius slices of
 * (M+2*radius) x (N+2*radius) values each: the original slices, extended
 * in-plane by Extend2D(), occupy slices radius..radius+P-1, and radius
 * mirrored slices are added before and after them.
 *
 * NOTE(review): across slices the edge slice itself is repeated (the slice
 * next to the volume is a copy of the first/last slice), whereas Extend2D()
 * does not repeat the edge sample. Confirm that this difference is intended.
 */
void Extend3D(double *Dest,double *Orig,int M,int N,int P,int radius)
{
	const int SliceLen=(N+2*radius)*(M+2*radius);   /* values per extended slice */
	int Slice;

	/* Extend every original slice in-plane and store it at its final place */
	for(Slice=0;Slice<P;Slice++)
	{
		Extend2D(Dest+SliceLen*(Slice+radius),Orig+N*M*Slice,M,N,radius);
	}

	/* Slices before the volume: slice Slice is a copy of the already
	   extended slice 2*radius-Slice-1 */
	for(Slice=0;Slice<radius;Slice++)
	{
		double *Target=Dest+SliceLen*Slice;
		double *Source=Dest+SliceLen*(2*radius-Slice-1);
		memcpy(Target,Source,SliceLen*sizeof(double));
	}

	/* Slices after the volume: slice P+radius+Slice is a copy of the
	   already extended slice P+radius-Slice-1 */
	for(Slice=0;Slice<radius;Slice++)
	{
		double *Target=Dest+SliceLen*(P+radius+Slice);
		double *Source=Dest+SliceLen*(P+radius-Slice-1);
		memcpy(Target,Source,SliceLen*sizeof(double));
	}
}