#include "mex.h"
#include "Mates.h"
#include "Debugging.h"

#include "gmp.h"
#include "mpfr.h"
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <float.h>
#include <string.h>

/* In order to compile this MEX function, type the following at the MATLAB prompt:
32-bit Windows:
mex PPCASOMANLLMEX.c MatesLap.c lapack.a blas.a libf2c.a libmpfr.a libgmp.a Debugging.c
64-bit Windows:
mex LINKFLAGS="$LINKFLAGS /NODEFAULTLIB:LIBCMT" PPCASOMANLLMEX.c MatesLap.c Debugging.c clapack_nowrap.lib BLAS_nowrap.lib libf2c.lib mpir.lib mpfr.lib

Usage:
[ANLL,LogDensitiesProb] = PPCASOMANLLMEX(Samples,Model);


*/

/* Writes to *LogDensityProb the log of the mixture density
   sum_i Pi(i) * p(t | neuron i) of the PPCA-SOM stored in the struct Model,
   evaluated at the single sample t pointed to by ptrSample. The number of
   coordinates read from ptrSample is the number of rows of Model.Samples. */
void FindLogDensPPCASOMMEX(const mxArray* Model,double *ptrSample,
    double *LogDensityProb);



/*
 * MEX gateway for
 *     [ANLL,LogDensitiesProb] = PPCASOMANLLMEX(Samples,Model);
 *
 * Samples          real, full double matrix with one sample per column
 *                  (D rows, N columns).
 * Model            non-empty struct; element 1 supplies the fields read by
 *                  FindLogDensPPCASOMMEX, and Model.Samples must have the
 *                  same number of rows D as Samples (checked when N > 0).
 *
 * ANLL             -sum(LogDensitiesProb)/N, the average negative
 *                  log-likelihood (NaN when N == 0, since 0/0).
 * LogDensitiesProb 1 x N row vector; entry n is the log density of sample n
 *                  as computed by FindLogDensPPCASOMMEX. Only returned when
 *                  the caller asks for a second output.
 */
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{
    const mxArray *ModelSamples;
    mxArray *LogDensitiesArray;
    double *ptrSamples,*ptrLogDensitiesProb;
    double SumLogDensitiesProb;
    int SpaceDimension,NumSamples,IndexPat;

    /* Validate the call before dereferencing prhs[]: a wrong argument
       count would otherwise read past the end of the prhs array. */
    if(nrhs!=2)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:nrhs",
            "Usage: [ANLL,LogDensitiesProb] = PPCASOMANLLMEX(Samples,Model).");
    if(nlhs>2)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:nlhs",
            "PPCASOMANLLMEX returns at most two outputs.");
    if(!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badSamples",
            "Samples must be a real, full double matrix.");
    if(!mxIsStruct(prhs[1]) || mxIsEmpty(prhs[1]))
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model must be a non-empty struct.");

    /* Get input data. mxGetM/mxGetN are used instead of reading
       mxGetDimensions() through an int pointer: the dimensions are mwSize,
       which is size_t under -largeArrayDims, so an int* would mis-read
       them on 64-bit builds. */
    SpaceDimension=(int)mxGetM(prhs[0]);
    NumSamples=(int)mxGetN(prhs[0]);
    ptrSamples=mxGetPr(prhs[0]);

    /* FindLogDensPPCASOMMEX reads as many coordinates per sample as
       Model.Samples has rows; a mismatch would read outside Samples. */
    if(NumSamples>0)
    {
        ModelSamples=mxGetField(prhs[1],0,"Samples");
        if(ModelSamples==NULL || (int)mxGetM(ModelSamples)!=SpaceDimension)
            mexErrMsgIdAndTxt("PPCASOMANLLMEX:dimMismatch",
                "Samples must have as many rows as Model.Samples.");
    }

    /* The per-sample log densities are always needed to form ANLL, but
       are only handed back to MATLAB when a second output was requested. */
    LogDensitiesArray=mxCreateNumericMatrix(1, NumSamples, mxDOUBLE_CLASS, mxREAL);
    ptrLogDensitiesProb=mxGetPr(LogDensitiesArray);

    /* Main loop */
    SumLogDensitiesProb=0.0;
    for(IndexPat=0;IndexPat<NumSamples;IndexPat++)
    {
        /* Pattern = Samples(:,IndexPat); columns are contiguous (column-major).
           The size_t cast keeps the offset from overflowing int. */
        FindLogDensPPCASOMMEX(prhs[1],
                ptrSamples+(size_t)IndexPat*SpaceDimension,
                ptrLogDensitiesProb+IndexPat);
        SumLogDensitiesProb+=ptrLogDensitiesProb[IndexPat];
    }

    /* ANLL = -mean(LogDensitiesProb) */
    plhs[0]=mxCreateDoubleScalar(-SumLogDensitiesProb/(double)NumSamples);

    if(nlhs>=2)
        plhs[1]=LogDensitiesArray;
    else
        mxDestroyArray(LogDensitiesArray);
}





/*--------------------------------------------------------------------*/

/* Working precision (bits) of the MPFR accumulator used to sum the
   neuron densities. */
enum { PPCASOM_MPFR_PRECISION = 100 };

/* Model fields that are shared by every neuron of the map. */
typedef struct
{
    const mxArray *Means,*W,*Uq,*Sigma2,*Lambdaq; /* cell arrays, one cell per neuron */
    const double *Pi;                              /* Pi(i) multiplies neuron i's density */
    int SpaceDimension;                            /* number of rows of Model.Samples */
} PPCASOMModelFields;

/* Returns Model(1).(Name), aborting the MEX call if the field is absent. */
static const mxArray *GetRequiredField(const mxArray *Model,const char *Name)
{
    const mxArray *Field=mxGetField(Model,0,Name);
    if(Field==NULL)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model.%s is missing.",Name);
    return Field;
}

/* Returns Model(1).(Name) after checking that it is a cell array with at
   least MinCells cells, so that mxGetCell below stays in range. */
static const mxArray *GetRequiredCellField(const mxArray *Model,const char *Name,
    size_t MinCells)
{
    const mxArray *Field=GetRequiredField(Model,Name);
    if(!mxIsCell(Field))
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model.%s must be a cell array.",Name);
    if(mxGetNumberOfElements(Field)<MinCells)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model.%s has fewer cells than Model.Sigma2.",Name);
    return Field;
}

/* Returns the data of Cells{Index+1}, checking that the cell holds a real
   double array with at least MinElements elements so that every read made
   through the returned pointer stays in bounds. Name is only used in the
   error message. */
static double *GetCellData(const mxArray *Cells,int Index,const char *Name,
    size_t MinElements)
{
    const mxArray *Item=mxGetCell(Cells,Index);
    if(Item==NULL || !mxIsDouble(Item) || mxIsComplex(Item) ||
       mxGetNumberOfElements(Item)<MinElements)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model.%s{%d} must be a real double array with at least %d elements.",
            Name,Index+1,(int)MinElements);
    return mxGetPr(Item);
}

/* Returns log( Pi(i) * p(t | i) ) for neuron i = Index+1 and the sample t
   at ptrSample, where p(t | i) is that neuron's PPCA Gaussian density:
       VectorDif = t - Mean,  zin = Uq' * VectorDif,
       Erec2     = squared norm of VectorDif minus its projection
                   W*inv(W'*W)*W'*VectorDif onto the columns of W,
       En2       = zin' * diag(1./Lambdaq) * zin + Erec2 / Sigma2,
       LogDetC   = (D-q)*log(Sigma2) + sum(log(Lambdaq)),
       log p     = -(D/2)*log(2*pi) - LogDetC/2 - En2/2,
   with D = SpaceDimension and q = number of columns of W{i}. */
static double NeuronLogDensity(const PPCASOMModelFields *Fields,int Index,
    double *ptrSample)
{
    const double HalfLog2Pi=0.91893853320467274178; /* 0.5*log(2*pi) */
    const int SpaceDimension=Fields->SpaceDimension;
    const mxArray *MyW=mxGetCell(Fields->W,Index);
    double *ptrW,*ptrUq,*ptrLambdaq,*ptrSigma2,*ptrMean;
    double *ptrWT,*ptrWTW,*ptrInvWTW,*ptrWInvWTW,*ptrUqT,*ptrMatrixDiagonal;
    double *ptrVectorDif,*ptrTn,*ptrVectorErrRec;
    double *ptrWTVectorDif,*ptrzin,*ptrMatrixDiagzin;
    double Erec2,Ein2,En2,LogDetC,MyLogDensity;
    int NumVecBasis,NdxVecBasis;

    /* Prepare working variables: W{i} must be D x q (column-major) */
    if(MyW==NULL || (int)mxGetM(MyW)!=SpaceDimension)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model.W{%d} must have %d rows.",Index+1,SpaceDimension);
    NumVecBasis=(int)mxGetN(MyW);
    ptrW=GetCellData(Fields->W,Index,"W",(size_t)SpaceDimension*NumVecBasis);
    ptrUq=GetCellData(Fields->Uq,Index,"Uq",(size_t)SpaceDimension*NumVecBasis);
    ptrLambdaq=GetCellData(Fields->Lambdaq,Index,"Lambdaq",(size_t)NumVecBasis);
    ptrSigma2=GetCellData(Fields->Sigma2,Index,"Sigma2",1);
    ptrMean=GetCellData(Fields->Means,Index,"Means",(size_t)SpaceDimension);

    /* WInvWTW = W * inv(W'*W) */
    ptrWT=mxMalloc((size_t)NumVecBasis*SpaceDimension*sizeof(double));
    Traspose(ptrW,ptrWT,SpaceDimension,NumVecBasis);
    ptrWTW=mxMalloc((size_t)NumVecBasis*NumVecBasis*sizeof(double));
    MatrixProduct(ptrWT,ptrW,ptrWTW,NumVecBasis,SpaceDimension,NumVecBasis);
    ptrInvWTW=mxMalloc((size_t)NumVecBasis*NumVecBasis*sizeof(double));
    Inverse(ptrWTW,ptrInvWTW,NumVecBasis);
    ptrWInvWTW=mxMalloc((size_t)SpaceDimension*NumVecBasis*sizeof(double));
    MatrixProduct(ptrW,ptrInvWTW,ptrWInvWTW,SpaceDimension,NumVecBasis,NumVecBasis);

    ptrUqT=mxMalloc((size_t)NumVecBasis*SpaceDimension*sizeof(double));
    Traspose(ptrUq,ptrUqT,SpaceDimension,NumVecBasis);

    /* MatrixDiagonal = diag(1./Lambdaq), accumulating log(det(C)) alongside */
    ptrMatrixDiagonal=mxCalloc((size_t)NumVecBasis*NumVecBasis,sizeof(double));
    LogDetC=(double)(SpaceDimension-NumVecBasis)*log(*ptrSigma2);
    for(NdxVecBasis=0;NdxVecBasis<NumVecBasis;NdxVecBasis++)
    {
        ptrMatrixDiagonal[NdxVecBasis+NdxVecBasis*NumVecBasis]=1.0/ptrLambdaq[NdxVecBasis];
        LogDetC+=log(ptrLambdaq[NdxVecBasis]);
    }

    /* D-length vectors live in data space; q-length ones (W'*VectorDif, zin,
       MatrixDiagonal*zin) are sized by NumVecBasis so they cannot overflow
       when q exceeds D. */
    ptrVectorDif=mxMalloc((size_t)SpaceDimension*sizeof(double));
    ptrTn=mxMalloc((size_t)SpaceDimension*sizeof(double));
    ptrVectorErrRec=mxMalloc((size_t)SpaceDimension*sizeof(double));
    ptrWTVectorDif=mxMalloc((size_t)NumVecBasis*sizeof(double));
    ptrzin=mxMalloc((size_t)NumVecBasis*sizeof(double));
    ptrMatrixDiagzin=mxMalloc((size_t)NumVecBasis*sizeof(double));

    /* VectorDif = Sample - Mean */
    MatrixDifference(ptrSample,ptrMean,ptrVectorDif,SpaceDimension,1);

    /* Tn = WInvWTW * (W' * VectorDif) */
    MatrixProduct(ptrWT,ptrVectorDif,ptrWTVectorDif,NumVecBasis,SpaceDimension,1);
    MatrixProduct(ptrWInvWTW,ptrWTVectorDif,ptrTn,SpaceDimension,NumVecBasis,1);

    /* Erec2 = sum((VectorDif - Tn).^2) */
    MatrixDifference(ptrVectorDif,ptrTn,ptrVectorErrRec,SpaceDimension,1);
    SquaredNorm(ptrVectorErrRec,&Erec2,SpaceDimension);

    /* zin = Uq' * VectorDif;  Ein2 = zin' * MatrixDiagonal * zin */
    MatrixProduct(ptrUqT,ptrVectorDif,ptrzin,NumVecBasis,SpaceDimension,1);
    MatrixProduct(ptrMatrixDiagonal,ptrzin,ptrMatrixDiagzin,NumVecBasis,NumVecBasis,1);
    MatrixProduct(ptrzin,ptrMatrixDiagzin,&Ein2,1,NumVecBasis,1);

    /* En2 = Ein2 + Erec2 / Sigma2 */
    En2=Ein2+Erec2/(*ptrSigma2);

    /* log( Pi(i) * p(t | i) ) */
    MyLogDensity=log(Fields->Pi[Index])-HalfLog2Pi*SpaceDimension-0.5*LogDetC-0.5*En2;

    /* Release memory */
    mxFree(ptrWT);
    mxFree(ptrWTW);
    mxFree(ptrInvWTW);
    mxFree(ptrWInvWTW);
    mxFree(ptrUqT);
    mxFree(ptrMatrixDiagonal);
    mxFree(ptrVectorDif);
    mxFree(ptrTn);
    mxFree(ptrVectorErrRec);
    mxFree(ptrWTVectorDif);
    mxFree(ptrzin);
    mxFree(ptrMatrixDiagzin);

    return MyLogDensity;
}

/* Writes to *LogDensityProb the value log p(t) = log( sum_i Pi(i)*p(t | i) )
   of the PPCA-SOM mixture stored in Model, for the sample t at ptrSample.
   Neurons are visited in linear (column-major) order of the Model.Sigma2
   cell array. The result is -Inf when the sum is zero, i.e. when there are
   no neurons or every term underflows even in MPFR. Malformed models abort
   the MEX call with an error. */
void FindLogDensPPCASOMMEX(const mxArray* Model,double *ptrSample,
    double *LogDensityProb)
{
    PPCASOMModelFields Fields;
    const mxArray *Pi;
    double *ptrNeuronLogDens;
    int NumNeurons,Index;
    mpfr_t SumDensities,LogDens,Dens,LogSum;

    /* Phase 1: validate the model and compute every per-neuron term.
       All error exits happen here, before any MPFR variable exists:
       mexErrMsgIdAndTxt releases mxMalloc'd memory but would leak the
       GMP-allocated storage of initialised mpfr_t variables. */
    Fields.Sigma2=GetRequiredCellField(Model,"Sigma2",0);
    NumNeurons=(int)mxGetNumberOfElements(Fields.Sigma2);
    Fields.Means=GetRequiredCellField(Model,"Means",(size_t)NumNeurons);
    Fields.W=GetRequiredCellField(Model,"W",(size_t)NumNeurons);
    Fields.Uq=GetRequiredCellField(Model,"Uq",(size_t)NumNeurons);
    Fields.Lambdaq=GetRequiredCellField(Model,"Lambdaq",(size_t)NumNeurons);
    Fields.SpaceDimension=(int)mxGetM(GetRequiredField(Model,"Samples"));
    Pi=GetRequiredField(Model,"Pi");
    if(!mxIsDouble(Pi) || mxIsComplex(Pi) ||
       mxGetNumberOfElements(Pi)<(size_t)NumNeurons)
        mexErrMsgIdAndTxt("PPCASOMANLLMEX:badModel",
            "Model.Pi must be a real double array with at least %d elements.",
            NumNeurons);
    Fields.Pi=mxGetPr(Pi);

    ptrNeuronLogDens=mxMalloc((size_t)NumNeurons*sizeof(double));
    for(Index=0;Index<NumNeurons;Index++)
        ptrNeuronLogDens[Index]=NeuronLogDensity(&Fields,Index,ptrSample);

    /* Phase 2: p(t) = sum_i exp(log(Pi(i)*p(t|i))) is accumulated in MPFR.
       MPFR's exponent range is far wider than a double's, so very negative
       log densities do not underflow to zero before the final log.
       mpfr_init2 fixes the precision per variable instead of changing the
       process-wide default precision. */
    mpfr_init2(SumDensities,PPCASOM_MPFR_PRECISION);
    mpfr_init2(LogDens,PPCASOM_MPFR_PRECISION);
    mpfr_init2(Dens,PPCASOM_MPFR_PRECISION);
    mpfr_init2(LogSum,PPCASOM_MPFR_PRECISION);

    mpfr_set_si(SumDensities,0,GMP_RNDN);
    for(Index=0;Index<NumNeurons;Index++)
    {
        mpfr_set_d(LogDens,ptrNeuronLogDens[Index],GMP_RNDN);
        mpfr_exp(Dens,LogDens,GMP_RNDN);
        mpfr_add(SumDensities,SumDensities,Dens,GMP_RNDN);
    }

    /* log( p(t) ); MPFR returns -Inf for log(0) */
    mpfr_log(LogSum,SumDensities,GMP_RNDN);
    (*LogDensityProb)=mpfr_get_d(LogSum,GMP_RNDN);

    /* Release memory */
    mpfr_clear(SumDensities);
    mpfr_clear(LogDens);
    mpfr_clear(Dens);
    mpfr_clear(LogSum);
    mxFree(ptrNeuronLogDens);
}
