#include "mex.h"
#include "Mates.h"
#include "Debugging.h"

#include "gmp.h"
#include "mpfr.h"
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <float.h>
#include <string.h>

/* In order to compile this MEX function, type the following at the MATLAB prompt:
32-bit Windows:
mex TestPPCASOMMEX.c MatesLap.c lapack.a blas.a libf2c.a libmpfr.a libgmp.a Debugging.c
64-bit Windows:
mex LINKFLAGS="$LINKFLAGS /NODEFAULTLIB:LIBCMT" TestPPCASOMMEX.c MatesLap.c Debugging.c clapack_nowrap.lib BLAS_nowrap.lib libf2c.lib mpir.lib mpfr.lib

Usage:
[ANLL,LogDensitiesProb,ProjVectors,RepresVectors,Respon] = TestPPCASOMMEX(Samples,Model);

Samples=Test samples, a D x N real double matrix with one test sample t sub n per column
        (D must equal the number of rows of Model.Samples)
Model=PPCASOM model (struct with fields Means, W, MInv, Uq, Sigma2, Lambdaq, Samples and Pi)
ANLL=Average negative log-likelihood of the test samples
LogDensitiesProb(n)=Log-density log(p(t sub n)) of test sample t sub n (1 x N)
ProjVectors(:,i,n)=Projection vector t hat sub n on the subspace of unit i (D x M x N, M = number of map units)
RepresVectors(:,i,n)=Representation vector x sub n on the subspace of unit i, padded with zeros up to length D (D x M x N)
Respon(i,n)=Posterior responsibility P( i | t sub n) that unit i has generated test sample t sub n (M x N)

*/

/* Evaluates the PPCASOM model at one sample (defined below).
   Stores log p(sample) in *LogDensityProb, one projection vector per map
   unit in ptrTn, one representation vector per map unit in ptrXn (both laid
   out in consecutive slots of as many entries as Model.Samples has rows),
   and the posterior responsibility of every unit in ptrRespon. */
void FindLogDensPPCASOMMEX(const mxArray *Model,
                           double *ptrSample,
                           double *LogDensityProb,
                           double *ptrTn,
                           double *ptrXn,
                           double *ptrRespon);



/* MEX gateway:
   [ANLL,LogDensitiesProb,ProjVectors,RepresVectors,Respon] = TestPPCASOMMEX(Samples,Model)

   Samples is a D x N real double matrix with one test sample per column.
   D must equal the number of rows of Model.Samples, because
   FindLogDensPPCASOMMEX takes D from Model.Samples and writes D entries per
   unit; a mismatch would otherwise overrun the output buffers.
   Each sample n is evaluated by FindLogDensPPCASOMMEX, which fills entry n of
   LogDensitiesProb, column n of Respon and page n of ProjVectors and
   RepresVectors. ANLL is minus the mean of LogDensitiesProb (NaN when N = 0,
   like MATLAB's mean of an empty array).
   Invalid arguments raise MATLAB errors instead of dereferencing missing data.
   Dimensions are handled as mwSize, which is what mxGetDimensions returns and
   mxCreateNumericArray expects (size_t in large-array-dims builds). */
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{
    /* Model fields read by FindLogDensPPCASOMMEX */
    static const char *const RequiredFields[]={"Means","W","MInv","Uq","Sigma2","Lambdaq","Samples","Pi"};
    enum { NumOutputs=5 };
    const mwSize *DimMap;
    mwSize SpaceDimension,NumSamples,NumUnits,IndexPat,SampleStride;
    mwSize Dims[3];
    mxArray *Outputs[NumOutputs];
    double *ptrSamples,*ptrANLL,*ptrLogDensitiesProb,*ptrTn,*ptrXn,*ptrRespon;
    double SumLogDensitiesProb;
    int NdxField,NdxOut,NumReturned;

    /* Validate the arguments before any of them is dereferenced */
    if(nrhs<2)
    {
        mexErrMsgIdAndTxt("TestPPCASOMMEX:nrhs","Two inputs are required: Samples and Model.");
    }
    if(nlhs>NumOutputs)
    {
        mexErrMsgIdAndTxt("TestPPCASOMMEX:nlhs","At most five outputs can be requested.");
    }
    if(!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfDimensions(prhs[0])!=2)
    {
        mexErrMsgIdAndTxt("TestPPCASOMMEX:Samples","Samples must be a real double matrix.");
    }
    if(!mxIsStruct(prhs[1]))
    {
        mexErrMsgIdAndTxt("TestPPCASOMMEX:Model","Model must be a struct.");
    }
    for(NdxField=0;NdxField<(int)(sizeof(RequiredFields)/sizeof(RequiredFields[0]));NdxField++)
    {
        if(mxGetField(prhs[1],0,RequiredFields[NdxField])==NULL)
        {
            mexErrMsgIdAndTxt("TestPPCASOMMEX:Model","Model.%s is missing.",RequiredFields[NdxField]);
        }
    }

    /* Get input data */
    SpaceDimension=(mwSize)mxGetM(prhs[0]);
    NumSamples=(mwSize)mxGetN(prhs[0]);
    ptrSamples=mxGetPr(prhs[0]);
    if((mwSize)mxGetM(mxGetField(prhs[1],0,"Samples"))!=SpaceDimension)
    {
        mexErrMsgIdAndTxt("TestPPCASOMMEX:Dimension","Samples must have the same number of rows as Model.Samples.");
    }
    /* The map has one unit per element of the NumRowsMap x NumColsMap Sigma2 cell array */
    DimMap=mxGetDimensions(mxGetField(prhs[1],0,"Sigma2"));
    NumUnits=DimMap[0]*DimMap[1];

    /* Create every output; FindLogDensPPCASOMMEX writes into all of them */
    Outputs[0]=mxCreateNumericMatrix(1, 1, mxDOUBLE_CLASS, mxREAL);
    Outputs[1]=mxCreateNumericMatrix(1, NumSamples, mxDOUBLE_CLASS, mxREAL);
    Dims[0]=SpaceDimension;
    Dims[1]=NumUnits;
    Dims[2]=NumSamples;
    Outputs[2]=mxCreateNumericArray(3,Dims,mxDOUBLE_CLASS,mxREAL);
    Outputs[3]=mxCreateNumericArray(3,Dims,mxDOUBLE_CLASS,mxREAL);
    Outputs[4]=mxCreateNumericMatrix(NumUnits, NumSamples, mxDOUBLE_CLASS, mxREAL);
    ptrANLL=mxGetPr(Outputs[0]);
    ptrLogDensitiesProb=mxGetPr(Outputs[1]);
    ptrTn=mxGetPr(Outputs[2]);
    ptrXn=mxGetPr(Outputs[3]);
    ptrRespon=mxGetPr(Outputs[4]);

    /* Main loop. Each sample owns NumUnits*SpaceDimension entries of ProjVectors
       and RepresVectors; offsets are mwSize so they do not overflow int. */
    SampleStride=NumUnits*SpaceDimension;
    SumLogDensitiesProb=0.0;
    for(IndexPat=0;IndexPat<NumSamples;IndexPat++)
    {
        /* Pattern = Samples(:,IndexPat) */
        FindLogDensPPCASOMMEX(prhs[1],ptrSamples+IndexPat*SpaceDimension,
                ptrLogDensitiesProb+IndexPat,
                ptrTn+IndexPat*SampleStride,
                ptrXn+IndexPat*SampleStride,
                ptrRespon+IndexPat*NumUnits);
        SumLogDensitiesProb+=ptrLogDensitiesProb[IndexPat];
    }
    (*ptrANLL)=-SumLogDensitiesProb/(double)NumSamples;

    /* plhs only has room for max(nlhs,1) outputs; release the rest */
    NumReturned=(nlhs<1)?1:nlhs;
    for(NdxOut=0;NdxOut<NumOutputs;NdxOut++)
    {
        if(NdxOut<NumReturned)
        {
            plhs[NdxOut]=Outputs[NdxOut];
        }
        else
        {
            mxDestroyArray(Outputs[NdxOut]);
        }
    }
}





/*--------------------------------------------------------------------*/

void FindLogDensPPCASOMMEX(const mxArray* Model,double *ptrSample,
    double *LogDensityProb,double *ptrTn,double *ptrXn,double *ptrRespon)
{   
    double *ptrW,*ptrWT,*ptrWTW,*ptrInvWTW,*ptrWInvWTW,*ptrMatrixDiagonal,*ptrUqT,*ptrSigma2,*ptrMInv;
    double *ptrLambdaq,*ptrUq,*ptrPi;
    const int *DimMap;
    const int *DimSamples;
    const int *DimW;
    int SpaceDimension,NumSamples,NumRowsMap,NumColsMap,NumVecBasis,NdxVecBasis,Index;
    int NdxCol,NdxRow;
    mxArray *Means,*W,*Uq,*Sigma2,*Lambdaq,*MyMean,*MyW,*Samples,*Pi,*MInv;
    double *ptrVectorDif,*ptrzin,*ptrWTVectorDif,*ptrVectorErrRec,*ptrMatrixDiagzin;
    double Erec2,Ein2,En2,LogDetC,MyLogDensity;
    mpfr_t SumDensities,LogDens,NLL;
	mpfr_t *Dens;
    
    
    
    
    
    /* Get input data */
    Means=mxGetField(Model,0,"Means");
    W=mxGetField(Model,0,"W");
	MInv=mxGetField(Model,0,"MInv");
    Uq=mxGetField(Model,0,"Uq");
    Sigma2=mxGetField(Model,0,"Sigma2");
    DimMap=mxGetDimensions(Sigma2);
    NumRowsMap=DimMap[0];
    NumColsMap=DimMap[1];
    Lambdaq=mxGetField(Model,0,"Lambdaq");
    Samples=mxGetField(Model,0,"Samples");
    DimSamples=mxGetDimensions(Samples);
    SpaceDimension=DimSamples[0];
    Pi=mxGetField(Model,0,"Pi");
    ptrPi=mxGetPr(Pi);
    
    
    /* Create auxiliary matrices */
    ptrVectorDif=mxMalloc(SpaceDimension*1*sizeof(double));
    ptrzin=mxMalloc(SpaceDimension*1*sizeof(double));
    ptrWTVectorDif=mxMalloc(SpaceDimension*1*sizeof(double));
    ptrVectorErrRec=mxMalloc(SpaceDimension*1*sizeof(double));
    ptrMatrixDiagzin=mxMalloc(SpaceDimension*1*sizeof(double));
    
    mpfr_set_default_prec(100);
    
    /*  Variable initialization */
    
    mpfr_init(SumDensities); 
    mpfr_set_si(SumDensities,0,GMP_RNDN);
    mpfr_init(NLL);
    mpfr_init(LogDens); 
	Dens=mxMalloc(NumRowsMap*NumColsMap*sizeof(mpfr_t));
	for(Index=0;Index<NumRowsMap*NumColsMap;Index++)
	{
		mpfr_init(Dens[Index]); 
	}
    
    

 
    Index=0;
    /* LogDensityProb = -inf; */
    (*LogDensityProb)=-DBL_MAX;
    
    for(NdxCol=0;NdxCol<NumColsMap;NdxCol++)
    {
        for(NdxRow=0;NdxRow<NumRowsMap;NdxRow++)
        {
            /* Prepare working variables */
            MyMean=mxGetCell(Means,Index);
            MyW=mxGetCell(W,Index);
            DimW=mxGetDimensions(MyW);
            NumVecBasis=DimW[1];
            ptrW=mxGetPr(MyW);
            ptrUq=mxGetPr(mxGetCell(Uq,Index));
            ptrLambdaq=mxGetPr(mxGetCell(Lambdaq,Index));
            ptrSigma2=mxGetPr(mxGetCell(Sigma2,Index));
			ptrMInv=mxGetPr(mxGetCell(MInv,Index));
            
            
            ptrWT=mxMalloc(NumVecBasis*SpaceDimension*sizeof(double));   
            Traspose(ptrW,ptrWT,SpaceDimension,NumVecBasis);
            
            ptrWTW=mxMalloc(NumVecBasis*NumVecBasis*sizeof(double));
            MatrixProduct(ptrWT,ptrW,ptrWTW,NumVecBasis,SpaceDimension,NumVecBasis);
            
            ptrInvWTW=mxMalloc(NumVecBasis*NumVecBasis*sizeof(double));
            Inverse(ptrWTW,ptrInvWTW,NumVecBasis);
            
            ptrWInvWTW=mxMalloc(SpaceDimension*NumVecBasis*sizeof(double));
            MatrixProduct(ptrW,ptrInvWTW,ptrWInvWTW,SpaceDimension,NumVecBasis,NumVecBasis);
            
            ptrUqT=mxMalloc(NumVecBasis*SpaceDimension*sizeof(double));
            Traspose(ptrUq,ptrUqT,SpaceDimension,NumVecBasis);  
            
            ptrMatrixDiagonal=mxCalloc(NumVecBasis*NumVecBasis,sizeof(double));
            
            LogDetC=(double)(SpaceDimension-NumVecBasis)*log(*ptrSigma2);
            
            for(NdxVecBasis=0;NdxVecBasis<NumVecBasis;NdxVecBasis++)
            {
                ptrMatrixDiagonal[NdxVecBasis+NdxVecBasis*NumVecBasis]=1.0/ptrLambdaq[NdxVecBasis];
                
                LogDetC+=log(ptrLambdaq[NdxVecBasis]);
                
            }    
            
            /* VectorDif=Samples(:,NdxSample) - Model.Mu{NdxRow,NdxCol}; */
            MatrixDifference(ptrSample,mxGetPr(MyMean),
                    ptrVectorDif,SpaceDimension,1);
            
            
            /* Tn =  Mu + WInvWTW * (W' * VectorDif); 
			   Xn = MInv * (W' * VectorDif); */            
            MatrixProduct(ptrWT,ptrVectorDif,ptrWTVectorDif,NumVecBasis,SpaceDimension,1);
            MatrixProduct(ptrWInvWTW,ptrWTVectorDif,ptrTn,SpaceDimension,NumVecBasis,1);
			MatrixSum(ptrTn,mxGetPr(MyMean),ptrTn,SpaceDimension,1);
			MatrixProduct(ptrMInv,ptrWTVectorDif,ptrXn,NumVecBasis,NumVecBasis,1);
            
            /* Erec2 = sum((Samples(:,NdxSample) - Tn).^2); */
            MatrixDifference(ptrSample,ptrTn,ptrVectorErrRec,SpaceDimension,1);
            SquaredNorm(ptrVectorErrRec,&Erec2,SpaceDimension);
            
                       
            /* zin=Model.Uq{NdxNeuro}'*VectorDif; */
            MatrixProduct(ptrUqT,ptrVectorDif,ptrzin,NumVecBasis,SpaceDimension,1);
            
            /* Ein2=zin'*MatrixDiagonal*zin; */
            MatrixProduct(ptrMatrixDiagonal,ptrzin,ptrMatrixDiagzin,NumVecBasis,NumVecBasis,1);
            MatrixProduct(ptrzin,ptrMatrixDiagzin,&Ein2,1,NumVecBasis,1);
            
            /* En2=Ein2+Erec2/Model.Sigma2{NdxRow,NdxCol}; */
            En2=Ein2+Erec2/(*ptrSigma2);
                                   
            /* Find log(  (pi sub i) * p(t sub n | i) ) in the variable MyLogDensity */
            MyLogDensity=log(ptrPi[Index])-0.91893853320467*SpaceDimension-0.5*LogDetC-0.5*En2;
            
            /* Find (pi sub i) * p(t sub n | i) in the variable Dens[Index] and add it in order to 
            compute p(t sub n) in the variable SumDensities */
            mpfr_set_d(LogDens,MyLogDensity,GMP_RNDN);
            mpfr_exp (Dens[Index], LogDens, GMP_RNDN); /* Exponential e ^ datum */
            mpfr_add(SumDensities,SumDensities,Dens[Index], GMP_RNDN);
                
            
            /* Update ptrTn and ptrXn */
			ptrTn+=SpaceDimension;
			ptrXn+=SpaceDimension;

            /* Index to iterate through the cells */
            Index++;
            
            /* Release memory */
            mxFree(ptrWT);
            mxFree(ptrWTW);
            mxFree(ptrInvWTW);
            mxFree(ptrWInvWTW);
            mxFree(ptrUqT);
            mxFree(ptrMatrixDiagonal);
            
        }    
 
    }    
    
    /* Find log( p( t sub n ) ) in the variable NLL */
    mpfr_log(NLL, SumDensities, GMP_RNDN);
    (*LogDensityProb)=mpfr_get_d(NLL, GMP_RNDN);

	/* Find responsibilities P( i | t sub n) */
	for(Index=0;Index<NumRowsMap*NumColsMap;Index++)
	{
		mpfr_div(Dens[Index],Dens[Index],SumDensities, GMP_RNDN);
		ptrRespon[Index]=mpfr_get_d(Dens[Index], GMP_RNDN);
	}
    
    
    mpfr_clear(SumDensities); /* Release memory */     
    mpfr_clear(LogDens); /* Release memory */        
    mpfr_clear(NLL); /* Release memory */		
	for(Index=0;Index<NumRowsMap*NumColsMap;Index++)
	{
		mpfr_clear(Dens[Index]); 
	}
	mxFree(Dens);
    
    mxFree(ptrVectorDif);
    mxFree(ptrzin);
    mxFree(ptrWTVectorDif);
    mxFree(ptrVectorErrRec);
    mxFree(ptrMatrixDiagzin);
    
    
    
}    
