#include "mex.h"
#include <math.h>
#include <string.h>

/* Number of bins of the hash tables */
#define NUM_HASH_BINS 1024
/* Mask that reduces a hash value to a bin index; it must consist of log2(NUM_HASH_BINS) low-order one bits, i.e. NUM_HASH_BINS-1 */
#define HASH_MASK 0x3FFu

/* 

Compute leave-one-out (LOO) probability densities for the Discrete Histogram Transform.
Coded by Ezequiel Lopez-Rubio. November 2013.

In order to compile this MEX function, type the following at Matlab prompt:
>> mex TestHTLOOMEX.c

[TestDensityLOO]=TestHTLOOMEX(Model,TrainSamples);

Inputs:
	Model			DHT model
	TrainSamples	DxN matrix with N training samples of dimension D
Output:
	TestDensityLOO	1xN vector with the LOO probability densities

*/

/* Element-wise matrix difference: Result = A - B.
   A, B and Result each hold NumRows*NumCols doubles; because the operation
   is element-wise, the storage order of the matrices is irrelevant. */
void MatrixDifference(double *A,double *B,double *Result,int NumRows,int NumCols)
{
    int Total;
    int Pos;

    Total=NumRows*NumCols;
    Pos=0;
    while (Pos<Total)
    {
        Result[Pos]=A[Pos]-B[Pos];
        Pos++;
    }
}

/* Element-wise matrix sum: Result = A + B.
   All three arrays hold NumRows*NumCols doubles. Result may be the very same
   array as A or B (in-place accumulation), since every output element only
   depends on the input elements at the same position. */
void MatrixSum(double *A,double *B,double *Result,int NumRows,int NumCols)
{
    int Count;
    int Idx;

    Count=NumRows*NumCols;
    for(Idx=0;Idx<Count;Idx++)
    {
        Result[Idx]=A[Idx]+B[Idx];
    }
}

/* Dense matrix product: Result = A * B, all matrices in column-major order.
   A is NumRowsA x NumColsA, B is NumColsA x NumColsB and Result is
   NumRowsA x NumColsB. Result must not overlap A or B, because the inputs
   are re-read while the output is being written. */
void MatrixProduct(double *A,double *B,double *Result,int NumRowsA,
    int NumColsA,int NumColsB)
{
    int Row;
    int Col;
    int Inner;

    for(Col=0;Col<NumColsB;Col++)
    {
        /* Column Col of B and of the result */
        double *ColB=B+NumColsA*Col;
        double *ColRes=Result+NumRowsA*Col;

        for(Row=0;Row<NumRowsA;Row++)
        {
            /* Dot product of row Row of A with column Col of B */
            double Acc=0.0;

            for(Inner=0;Inner<NumColsA;Inner++)
            {
                Acc+=A[Row+Inner*NumRowsA]*ColB[Inner];
            }
            ColRes[Row]=Acc;
        }
    }
}

/* Read a mandatory numeric scalar field of the model struct. The MEX call is
   aborted with a descriptive error when the field is absent, empty or not
   numeric. */
static double RequireScalarField(const mxArray *Model,const char *FieldName)
{
	const mxArray *Field;

	Field=mxGetField(Model,0,FieldName);
	if ((Field==NULL) || mxIsEmpty(Field) || !mxIsNumeric(Field))
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidField",
			"Model field '%s' must be a numeric scalar.",FieldName);
	}
	return mxGetScalar(Field);
}

/* Return the data pointer of a mandatory real, full, double field of the
   model struct which must hold at least MinElems elements. The MEX call is
   aborted with a descriptive error otherwise. */
static double *RequireDoubleField(const mxArray *Model,const char *FieldName,
	size_t MinElems)
{
	const mxArray *Field;

	Field=mxGetField(Model,0,FieldName);
	if ((Field==NULL) || !mxIsDouble(Field) || mxIsComplex(Field) ||
		mxIsSparse(Field) || (mxGetNumberOfElements(Field)<MinElems))
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidField",
			"Model field '%s' must be a real full double array "
			"whose size matches the dimension of the samples.",FieldName);
	}
	return mxGetPr(Field);
}

/*
MEX gateway: [TestDensityLOO]=TestHTLOOMEX(Model,TrainSamples)

For every histogram of the model, each sample is mapped through the affine
transform A*x+b of that histogram, rounded to the nearest integer vector and
looked up in the hash table of that histogram. When the bin is found, its
count minus one (the sample itself is left out) contributes to the density.
The averaged histogram density is finally mixed with a Gaussian density built
from Model.GlobalMean, Model.InvGlobalCovarianceMatrix and
Model.GaussianConstant.

prhs[0]	Model struct. Fields read here: NumHistograms, NumSamples, A, b,
		VolumeBin, InvGlobalCovarianceMatrix, GlobalMean, GaussianConstant
		and Histogram (cell array of structs with cell fields Indices and
		Counts, NUM_HASH_BINS cells each).
prhs[1]	DxN real double matrix, one sample per column.
plhs[0]	1xN real double row vector with the LOO densities.

Invalid inputs abort the call through mexErrMsgIdAndTxt, which does not
return; per the MEX documentation, MATLAB reclaims the mxMalloc'ed work
buffers of a MEX function that terminates this way.
*/
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{

	const mxArray *Model,*TestSamples,*HistogramField,*MyHistogram;
	const mxArray *MyIndices,*MyCounts,*MyIndexCell,*MyCountCell;
	size_t NumTestSamples,NdxSample,NdxBin,MyNumElems,Dim;
	size_t *NumElemsHashBin;
	int Dimension,NumHistograms;
	int NdxHistogram,NdxHash,NdxDim;
	int HashValue,Found;
	double *ptrTestSamples,*ptrTestDensity,*ptrMyHashBin,*ptrVolumeBin;
	double *ptrA,*ptrb,*ptrMyA,*ptrMyb,*ptrMySample,*AuxVector,*AuxVector2;
	double **ptrIndices,**ptrCounts;
	double *ptrInvC,*ptrMu;
	double Factor,NumTrainSamples,GaussianConstant,GaussianDensity;
	double GaussianWeight,HistogramWeight;


	/* Check the calling convention before touching any argument */
	if (nrhs<2)
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:nrhs",
			"Two inputs are required: Model and TrainSamples.");
	}
	if (nlhs>1)
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:nlhs",
			"Too many output arguments.");
	}

	/* Get input mxArrays */
	Model=prhs[0];
	TestSamples=prhs[1];
	if (!mxIsStruct(Model) || mxIsEmpty(Model))
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidModel",
			"Model must be a non-empty struct.");
	}
	if (!mxIsDouble(TestSamples) || mxIsComplex(TestSamples) ||
		mxIsSparse(TestSamples))
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidSamples",
			"TrainSamples must be a real full double matrix.");
	}

	/* Get working data. The dimension is also kept as an int because the
	matrix helpers take int sizes; all pointer offsets use size_t so that
	large sample sets cannot overflow int arithmetic. */
	Dim=mxGetM(TestSamples);
	Dimension=(int)Dim;
	if ((Dimension<0) || ((size_t)Dimension!=Dim))
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidSamples",
			"The dimension of the samples is too large.");
	}
	NumTestSamples=mxGetN(TestSamples);
	ptrTestSamples=mxGetPr(TestSamples);

	NumHistograms=(int)RequireScalarField(Model,"NumHistograms");
	if (NumHistograms<1)
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidField",
			"Model field 'NumHistograms' must be at least 1.");
	}
	/* Subtract 1.0 because we leave one out */
	NumTrainSamples=RequireScalarField(Model,"NumSamples")-1.0;
	if (!(NumTrainSamples>=1.0))
	{
		/* Otherwise Factor below would divide by zero */
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidField",
			"Model field 'NumSamples' must be at least 2 for leave-one-out.");
	}
	ptrA=RequireDoubleField(Model,"A",Dim*Dim*(size_t)NumHistograms);
	ptrb=RequireDoubleField(Model,"b",Dim*(size_t)NumHistograms);
	ptrVolumeBin=RequireDoubleField(Model,"VolumeBin",(size_t)NumHistograms);
	ptrInvC=RequireDoubleField(Model,"InvGlobalCovarianceMatrix",Dim*Dim);
	ptrMu=RequireDoubleField(Model,"GlobalMean",Dim);
	GaussianConstant=RequireScalarField(Model,"GaussianConstant");

	HistogramField=mxGetField(Model,0,"Histogram");
	if ((HistogramField==NULL) || !mxIsCell(HistogramField) ||
		(mxGetNumberOfElements(HistogramField)<(size_t)NumHistograms))
	{
		mexErrMsgIdAndTxt("TestHTLOOMEX:invalidField",
			"Model field 'Histogram' must be a cell array with "
			"NumHistograms elements.");
	}

	/* Create output mxArray. Note that Matlab initializes its elements to zero. */
	plhs[0]=mxCreateDoubleMatrix(1,NumTestSamples,mxREAL);
	ptrTestDensity=mxGetPr(plhs[0]);

	/* Create auxiliary arrays */
	ptrIndices=mxMalloc(NUM_HASH_BINS*sizeof(*ptrIndices));
	ptrCounts=mxMalloc(NUM_HASH_BINS*sizeof(*ptrCounts));
	NumElemsHashBin=mxMalloc(NUM_HASH_BINS*sizeof(*NumElemsHashBin));
	AuxVector=mxMalloc(Dim*sizeof(*AuxVector));
	AuxVector2=mxMalloc(Dim*sizeof(*AuxVector2));

	/* For each histogram of the model */
	for(NdxHistogram=0;NdxHistogram<NumHistograms;NdxHistogram++)
	{
		Factor=1.0/(NumTrainSamples*ptrVolumeBin[NdxHistogram]);
		ptrMyA=ptrA+Dim*Dim*(size_t)NdxHistogram;
		ptrMyb=ptrb+Dim*(size_t)NdxHistogram;

		/* Obtain the pointers to the elements of the hash table */
		MyHistogram=mxGetCell(HistogramField,NdxHistogram);
		if ((MyHistogram==NULL) || !mxIsStruct(MyHistogram) ||
			mxIsEmpty(MyHistogram))
		{
			mexErrMsgIdAndTxt("TestHTLOOMEX:invalidHistogram",
				"Model.Histogram{%d} must be a non-empty struct.",
				NdxHistogram+1);
		}
		MyIndices=mxGetField(MyHistogram,0,"Indices");
		MyCounts=mxGetField(MyHistogram,0,"Counts");
		if ((MyIndices==NULL) || !mxIsCell(MyIndices) ||
			(mxGetNumberOfElements(MyIndices)<NUM_HASH_BINS) ||
			(MyCounts==NULL) || !mxIsCell(MyCounts) ||
			(mxGetNumberOfElements(MyCounts)<NUM_HASH_BINS))
		{
			mexErrMsgIdAndTxt("TestHTLOOMEX:invalidHistogram",
				"Model.Histogram{%d} must have cell fields 'Indices' and "
				"'Counts' with %d elements each.",
				NdxHistogram+1,(int)NUM_HASH_BINS);
		}
		for(NdxHash=0;NdxHash<NUM_HASH_BINS;NdxHash++)
		{
			MyIndexCell=mxGetCell(MyIndices,NdxHash);
			/* Check whether this element of the hash table is empty */
			if ((MyIndexCell==NULL) || (mxGetN(MyIndexCell)==0))
			{
				ptrIndices[NdxHash]=NULL;
				ptrCounts[NdxHash]=NULL;
				NumElemsHashBin[NdxHash]=0;
			}
			else
			{
				/* One column of Dim coordinates per stored histogram bin,
				and one count per column */
				MyCountCell=mxGetCell(MyCounts,NdxHash);
				if (!mxIsDouble(MyIndexCell) ||
					(mxGetM(MyIndexCell)!=Dim) ||
					(MyCountCell==NULL) || !mxIsDouble(MyCountCell) ||
					(mxGetNumberOfElements(MyCountCell)<mxGetN(MyIndexCell)))
				{
					mexErrMsgIdAndTxt("TestHTLOOMEX:invalidHistogram",
						"Hash bin %d of Model.Histogram{%d} is inconsistent "
						"with the dimension of the samples.",
						NdxHash+1,NdxHistogram+1);
				}
				ptrIndices[NdxHash]=mxGetPr(MyIndexCell);
				NumElemsHashBin[NdxHash]=mxGetN(MyIndexCell);
				ptrCounts[NdxHash]=mxGetPr(MyCountCell);
			}
		}
		/* Process all test samples */
		for(NdxSample=0;NdxSample<NumTestSamples;NdxSample++)
		{
			/* Transform the sample */
			ptrMySample=ptrTestSamples+NdxSample*Dim;
			MatrixProduct(ptrMyA,ptrMySample,AuxVector,Dimension,Dimension,1);
			MatrixSum(AuxVector,ptrMyb,AuxVector,Dimension,1);

			/* Round the result and compute hash value.
			NOTE(review): the double-to-int cast and the int sum are kept
			exactly as they were; presumably they must match the hash used
			when the model was built - confirm before changing. */
			HashValue=0;
			for(NdxDim=0;NdxDim<Dimension;NdxDim++)
			{
				AuxVector2[NdxDim]=floor(AuxVector[NdxDim]+0.5);
				HashValue+=(int)AuxVector2[NdxDim];
			}
			HashValue&=HASH_MASK;

			/* Look for the histogram bin corresponding to this test sample.
			Bins are compared bytewise against the rounded coordinates. */
			MyNumElems=NumElemsHashBin[HashValue];
			ptrMyHashBin=ptrIndices[HashValue];
			Found=0;
			for(NdxBin=0;NdxBin<MyNumElems;NdxBin++)
			{
				if (memcmp(AuxVector2,ptrMyHashBin+NdxBin*Dim,
					Dim*sizeof(double))==0)
				{
					Found=1;
					break;
				}
			}

			/* If the histogram bin has been found, add the corresponding
			probability density. The count is decreased by one because the
			sample itself is left out. */
			if (Found)
			{
				ptrTestDensity[NdxSample]+=(Factor*(ptrCounts[HashValue][NdxBin]-1.0));
			}
		}
	}

	/* Mixture weights of the Gaussian and histogram components; they do not
	depend on the sample */
	GaussianWeight=1.0/(NumTrainSamples+1);
	HistogramWeight=NumTrainSamples/(NumTrainSamples+1);

	/* Compute Gaussian component and normalize the probability density of all test samples */
	for(NdxSample=0;NdxSample<NumTestSamples;NdxSample++)
	{
		/* Compute Gaussian density: (x-mu)' * InvC * (x-mu) goes through
		exp(-0.5*.) and is scaled by GaussianConstant */
		MatrixDifference(ptrTestSamples+NdxSample*Dim,
			ptrMu,AuxVector,Dimension,1);
		MatrixProduct(ptrInvC,AuxVector,AuxVector2,
			Dimension,Dimension,1);
		MatrixProduct(AuxVector,AuxVector2,&GaussianDensity,
			1,Dimension,1);
		GaussianDensity=GaussianConstant*exp(-0.5*GaussianDensity);

		/* Normalize and add Gaussian density */
		ptrTestDensity[NdxSample]=
			GaussianWeight*GaussianDensity+
			HistogramWeight*
			(ptrTestDensity[NdxSample]/NumHistograms);
	}

	/* Release dynamic memory */
	mxFree(ptrIndices);
	mxFree(ptrCounts);
	mxFree(AuxVector);
	mxFree(AuxVector2);
	mxFree(NumElemsHashBin);

}