#include "mex.h"
#include <limits.h>
#include <math.h>
#include <string.h>

/* Number of bins of the hash tables */
#define NUM_HASH_BINS 1024
/* Mask to compute hash value, must be log2(NUM_HASH_BINS) binary ones */
#define HASH_MASK 0x3FFu

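/* 

Compute a multivariate histogram for the Discrete Histogram Transform.
Coded by Ezequiel Lopez-Rubio. November 2013.

In order to compile this MEX function, type the following at the MATLAB prompt:
>> mex ComputeHistogramMEX.c

[Histogram]=ComputeHistogramMEX(Samples,A,b);

Each sample x is mapped to the point floor(A*x+b+0.5), and the histogram
counts how many samples fall on each distinct point.

Inputs:
	Samples		DxN matrix with N training samples of dimension D
	A			DxD matrix A of an affine transform
	b			Dx1 vector b of an affine transform
Output:
	Histogram	Resulting histogram: a 1x1 struct with fields Indices and
				Counts, each a NUM_HASH_BINS x 1 cell array indexed by
				hash value. Indices{k} is a DxM matrix whose columns are
				the M distinct points stored in hash bin k, and Counts{k}
				is the 1xM vector with the number of samples found at each
				of those points. Cells of unused hash bins are left empty.

*/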

/* Element-wise matrix sum: Result = A + B, where all three are
   NumRows x NumCols arrays of doubles. Result may be the same buffer as
   A or B, since each element is read before the matching one is written. */
void MatrixSum(double *A,double *B,double *Result,int NumRows,int NumCols)
{
    int Total;
    int Pos;

    Total=NumRows*NumCols;
    for(Pos=0;Pos<Total;Pos++)
    {
        Result[Pos]=A[Pos]+B[Pos];
    }
}

/* Matrix product: Result = A * B.
   A is NumRowsA x NumColsA, B is NumColsA x NumColsB and Result is
   NumRowsA x NumColsB. All matrices are stored column by column (element
   (r,c) of a matrix with R rows lives at offset r+c*R). Result must not
   overlap A or B, because it is written while the operands are still read. */
void MatrixProduct(double *A,double *B,double *Result,int NumRowsA,
    int NumColsA,int NumColsB)
{
    int Row;
    int Col;
    int Inner;
    double Acc;

    for(Col=0;Col<NumColsB;Col++)
    {
        for(Row=0;Row<NumRowsA;Row++)
        {
            /* Dot product of row Row of A with column Col of B */
            Acc=0.0;
            for(Inner=0;Inner<NumColsA;Inner++)
            {
                Acc+=A[Row+Inner*NumRowsA]*B[NumColsA*Col+Inner];
            }
            Result[Row+Col*NumRowsA]=Acc;
        }
    }
}

void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{

	mxArray *Samples,*MyCell,*Indices,*Counters;
	int NumSamples,Dimension;
	int NdxSample,NdxDim,MyNumElems;
	double *ptrSamples,*ptrMyHashBin;
	double *ptrA,*ptrb,*AuxVector,*AuxVectorRounded;
	double **ptrIndices,**ptrCounts;
	int *NumElemsHashBin;
	int HashValue,MyElement,NdxBin;
	const char *FieldNames[]={"Indices","Counts"};


	/* Get input mxArrays */
	Samples=prhs[0];
	ptrA=mxGetPr(prhs[1]);
	ptrb=mxGetPr(prhs[2]);

	/* Get working data */
	Dimension=mxGetM(Samples);
	NumSamples=mxGetN(Samples);
	ptrSamples=mxGetPr(Samples);

	
	/* Create auxiliary arrays */
	ptrIndices=mxCalloc(NUM_HASH_BINS,sizeof(double *));
	ptrCounts=mxCalloc(NUM_HASH_BINS,sizeof(double *));
	NumElemsHashBin=mxCalloc(NUM_HASH_BINS,sizeof(int));
	AuxVector=mxMalloc(Dimension*sizeof(double));
	AuxVectorRounded=mxMalloc(Dimension*sizeof(double));


	/* Process all input samples */
	for(NdxSample=0;NdxSample<NumSamples;NdxSample++)
	{
		/* Transform the sample */
		MatrixProduct(ptrA,ptrSamples+NdxSample*Dimension,
			AuxVector,Dimension,Dimension,1);
		MatrixSum(AuxVector,ptrb,AuxVector,Dimension,1);

		/* Round the result and compute hash value */
		HashValue=0;
		for(NdxDim=0;NdxDim<Dimension;NdxDim++)
		{
			AuxVectorRounded[NdxDim]=floor(AuxVector[NdxDim]+0.5);
			HashValue+=(int)AuxVectorRounded[NdxDim];
		}
		HashValue&=HASH_MASK;

		/* Look for the histogram bin corresponding to this test sample */
		MyNumElems=NumElemsHashBin[HashValue];
		ptrMyHashBin=ptrIndices[HashValue];
		MyElement=-1;
		for(NdxBin=0;NdxBin<MyNumElems;NdxBin++)
		{
			if (memcmp(AuxVectorRounded,ptrMyHashBin+NdxBin*Dimension,
				Dimension*sizeof(double))==0)
			{
				MyElement=NdxBin;
				break;
			}
		}

		/* If the histogram bin has been found, add one to the counter.
		   Otherwise, insert a new histogram bin into the hash table. */
		if (MyElement>=0)
		{
			ptrCounts[HashValue][MyElement]++;
		}
		else
		{
			MyNumElems++;
			NumElemsHashBin[HashValue]=MyNumElems;
			ptrCounts[HashValue]=mxRealloc(ptrCounts[HashValue],MyNumElems*sizeof(double));
			ptrCounts[HashValue][MyNumElems-1]=1;
			ptrIndices[HashValue]=mxRealloc(ptrIndices[HashValue],MyNumElems*Dimension*sizeof(double));
			memcpy(ptrIndices[HashValue]+(MyNumElems-1)*Dimension,
				AuxVectorRounded,Dimension*sizeof(double));
		}
	}


	/* Convert the hash table to mxArray format */
	Indices=mxCreateCellMatrix(NUM_HASH_BINS, 1);
	Counters=mxCreateCellMatrix(NUM_HASH_BINS, 1);
	for(NdxBin=0;NdxBin<NUM_HASH_BINS;NdxBin++)
	{
		if (NumElemsHashBin[NdxBin]>0)
		{
			/* Indices array */
			MyCell=mxCreateDoubleMatrix(Dimension,NumElemsHashBin[NdxBin],mxREAL);
			memcpy(mxGetPr(MyCell),ptrIndices[NdxBin],
				NumElemsHashBin[NdxBin]*Dimension*sizeof(double));
			mxSetCell(Indices,NdxBin,MyCell);
			/* Counters array */
			MyCell=mxCreateDoubleMatrix(1,NumElemsHashBin[NdxBin],mxREAL);
			memcpy(mxGetPr(MyCell),ptrCounts[NdxBin],
				NumElemsHashBin[NdxBin]*sizeof(double));
			mxSetCell(Counters,NdxBin,MyCell);
		}
	}


	/* Create output mxArray */
	plhs[0]=mxCreateStructMatrix(1, 1, 2, FieldNames);
	mxSetField(plhs[0], 0, "Indices", Indices);
	mxSetField(plhs[0], 0, "Counts", Counters);

	/* Release dynamic memory */
	mxFree(AuxVector);
	mxFree(AuxVectorRounded);
	mxFree(NumElemsHashBin);
	for(NdxBin=0;NdxBin<NUM_HASH_BINS;NdxBin++)
	{
		mxFree(ptrIndices[NdxBin]);
		mxFree(ptrCounts[NdxBin]);
	}
	mxFree(ptrIndices);
	mxFree(ptrCounts);
}