#include "mex.h"
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <float.h>
#include <limits.h>
#include <string.h>
#include <stdlib.h>

/* 

In order to compile this MEX function, type the following at the MATLAB prompt:
mex TrainQDSOMMEX.c Debugging.c

Usage:
[Model] = TrainQDSOMMEX(Model,NumSteps);

Notes:
Model is a QDSOM model
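NumSteps is the number of training steps. Each step presents one sample drawn at random from Model.Samples, so NumSteps is not a number of epochs over the data set.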

*/

#ifndef isnan
/* Fallback for compilers whose <math.h> lacks the C99 isnan() macro.
   Under IEEE 754 arithmetic a NaN is the only value that does not compare
   equal to itself. */
bool isnan(double x)
{
    return !(x == x);
}
#endif

#ifndef isinf
/* Fallback for compilers whose <math.h> lacks the C99 isinf() macro.
   For any x, x - x is 0 when x is finite and NaN when x is +/-Inf or NaN.
   NaN therefore has to be excluded explicitly (x == x is false only for NaN),
   otherwise NaN would be reported as infinite. */
bool isinf(double x)
{
    return (x == x) && ((x - x) != 0);
}
#endif

#ifndef isfinite
/* Fallback for compilers whose <math.h> lacks the C99 isfinite() macro.
   x - x is exactly 0 for every finite x, and NaN for +/-Inf and NaN, so only
   finite inputs make the difference compare equal to zero. */
bool isfinite(double x)
{
    double Difference = x - x;
    return Difference == 0;
}
#endif

/* Fallback for C libraries without the C99 log2() function.
   In C99 <math.h>, log2 is an ordinary function rather than a macro, so
   "#ifndef log2" alone never detects it. The version checks stop this file
   from redefining the library function (a reserved identifier) on C99
   compilers and on Visual C++ 2013 and later, which provide log2(). */
#if !defined(log2) && !(defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L) \
    && !(defined(_MSC_VER) && _MSC_VER >= 1800)
double log2(double x)
{
       return 1.442695040888963*log(x);   /* 1/ln(2) * ln(x) */
}
#endif



/* Find winning distribution.
   Reads the QDSOM fields from Model and evaluates the sample Pattern
   (Dimension values, used through the 0-based translation table MyTrans).
   Writes the 0-based winning map position to *NdxWinRow / *NdxWinCol, its
   log-likelihood to *LogLikelihoodProb, and the normalised neighbourhood
   weights of the winning position (NumRowsMap*NumColsMap entries) to
   ptrResponsibilities. See the definition for the case where every position
   has zero likelihood. */
void WinningQDSOMProbMEX(mxArray* Model,int *MyRanges,int *MyTrans,
    double *Pattern,double RadiusNeighbourhood,
    int *NdxWinRow,int *NdxWinCol,double *LogLikelihoodProb,double *ptrResponsibilities);    

/* Update Chow-Liu tree by Prim's spanning tree algorithm.
   Overwrites the Dimension x Dimension matrix ptrMyGraph with a maximum-weight
   spanning tree of the mutual-information matrix ptrMyMutual. The tree is grown
   from vertex RootNode (1-based; 0 means the first vertex), and each tree edge
   Parent->Child is stored as ptrMyGraph[Parent+Dimension*Child]=1.0. */
void PrimSpanningTree(double *ptrMyMutual,double *ptrMyGraph,int Dimension,int RootNode);

/* Return Model.(Name), raising a MATLAB error if the field does not exist. */
static mxArray *QDSOMRequireField(const mxArray *Model,const char *Name)
{
    mxArray *Field=mxGetField(Model,0,Name);

    if (Field==NULL)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:missingField","Model has no field '%s'.",Name);
    }
    return Field;
}

/* Return the scalar Model.(Name) truncated to int, raising a MATLAB error
   unless it lies in [MinValue, MaxValue] (NaN is rejected too). */
static int QDSOMRequireInt(const mxArray *Model,const char *Name,int MinValue,int MaxValue)
{
    double Value=mxGetScalar(QDSOMRequireField(Model,Name));

    if (!(Value>=(double)MinValue && Value<=(double)MaxValue))
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel","Model.%s must be between %d and %d.",
            Name,MinValue,MaxValue);
    }
    return (int)Value;
}

/* Return the data of the real double array Model.(Name), raising a MATLAB
   error unless it has at least MinElements elements. */
static double *QDSOMRequireDoubles(const mxArray *Model,const char *Name,size_t MinElements)
{
    mxArray *Field=QDSOMRequireField(Model,Name);

    if (!mxIsDouble(Field) || mxIsComplex(Field) || mxGetNumberOfElements(Field)<MinElements)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
            "Model.%s must be a real double array with at least %lu elements.",
            Name,(unsigned long)MinElements);
    }
    return mxGetPr(Field);
}

/* Return the cell array Model.(Name), raising a MATLAB error unless it has
   exactly NumNeurons cells, each a real double array with at least
   MinElements elements. */
static mxArray *QDSOMRequireCells(const mxArray *Model,const char *Name,int NumNeurons,size_t MinElements)
{
    mxArray *Cells=QDSOMRequireField(Model,Name);
    const mxArray *Cell;
    int IndexNeuron;

    if (!mxIsCell(Cells) || mxGetNumberOfElements(Cells)!=(size_t)NumNeurons)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
            "Model.%s must be a cell array with NumRowsMap*NumColsMap cells.",Name);
    }
    for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
    {
        Cell=mxGetCell(Cells,IndexNeuron);
        if (Cell==NULL || !mxIsDouble(Cell) || mxIsComplex(Cell) ||
            mxGetNumberOfElements(Cell)<MinElements)
        {
            mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
                "Model.%s{%d} must be a real double array with at least %lu elements.",
                Name,IndexNeuron+1,(unsigned long)MinElements);
        }
    }
    return Cells;
}

/* Return a pseudorandom index in [0, n), n >= 1.
   Two rand() draws are combined because RAND_MAX may be as small as 32767
   (e.g. MSVC), in which case rand()%n could never reach indices above 32767. */
static int QDSOMRandomIndex(int n)
{
    const double Scale=(double)RAND_MAX+1.0;
    double High=(double)rand();
    double Low=(double)rand();
    /* u lies in [0,1) in exact arithmetic; the clamp below guards rounding */
    double u=(High+Low/Scale)/Scale;
    int k=(int)(u*(double)n);

    return (k<n) ? k : n-1;
}

/*
 * MEX gateway: Model = TrainQDSOMMEX(Model, NumSteps).
 *
 * Works on a copy of Model (returned in plhs[0]). Each of the NumSteps steps:
 *  - draws one column of Model.Samples at random;
 *  - finds the winning distribution and the neuron responsibilities with
 *    WinningQDSOMProbMEX, and increments Model.NumVictories at the winner;
 *  - updates Model.Pi and, for every neuron, blends the Marginal and Intersect
 *    tables towards the indicator vectors of the sample. The weight of the
 *    sample depends on the neuron's responsibility;
 *  - recomputes each neuron's Mutual matrix and its tree (Graph) with
 *    PrimSpanningTree.
 * Finally Model.Pi is normalised to sum to one (when its sum is positive).
 *
 * The arguments and the model are validated first, so that malformed input
 * raises a MATLAB error instead of causing out-of-bounds memory accesses.
 */
void mexFunction(int nlhs, mxArray* plhs[],
                 int nrhs, const mxArray* prhs[])
{
    int Dimension,NumSamples,NumRowsMap,NumColsMap,NumNeurons,MaxRange,NumValues,RootNode;
    int NumSteps,NdxStep,IndexPat,IndexNeuron,NdxSample;
    int NdxDim,NdxDim2,NdxVal,NdxVal2,NdxDatum;
    int NdxWinRow,NdxWinCol,Index,Index2,Value,Value2;
    size_t NumTranslate;
    mxArray *Samples,*Marginal,*Intersect,*Mutual,*Graph;
    double StepsArg,Translated,SampleValue;
    double MaxRadiusNeighbourhood,RadiusNeighbourhood,LogLikelihoodProb;
    double LearningRate,OneMinusRate,NewCoef,OldCoef,AntPi,SumPis,MyProbIntersect,Logarithm2;
    double *ptrNumVictories,*ptrTranslation,*ptrRanges,*ptrSamples,*ptrPi,*Pattern;
    double *ptrMyMarginal,*ptrMyIntersect,*ptrMyMutual,*ptrMyGraph;
    double *ptrResponsibilities,*ptrUpdateMarginal,*ptrUpdateIntersect;
    int *MyTrans,*MyRanges;

    (void)nlhs;

    /* Check the arguments before touching them */
    if (nrhs<2)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:nargin","Usage: Model = TrainQDSOMMEX(Model,NumSteps)");
    }
    if (!mxIsStruct(prhs[0]) || mxGetNumberOfElements(prhs[0])<1)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel","Model must be a struct.");
    }
    if (!mxIsNumeric(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badNumSteps","NumSteps must be a numeric scalar.");
    }
    StepsArg=mxGetScalar(prhs[1]);
    if (!(StepsArg<=(double)INT_MAX))
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badNumSteps","NumSteps must be a finite number no larger than %d.",INT_MAX);
    }
    /* Non-positive values mean "no training steps" */
    NumSteps=(StepsArg>0.0) ? (int)StepsArg : 0;

    /* Setup pseudorandom number generator */
    srand((unsigned int)time(NULL));

    /* Create output array */
    plhs[0]=mxDuplicateArray(prhs[0]);

    /* Scalar parameters; the bounds keep every int index product below INT_MAX */
    NumRowsMap=QDSOMRequireInt(plhs[0],"NumRowsMap",1,INT_MAX);
    NumColsMap=QDSOMRequireInt(plhs[0],"NumColsMap",1,INT_MAX/NumRowsMap);
    Dimension=QDSOMRequireInt(plhs[0],"Dimension",1,INT_MAX);
    MaxRange=QDSOMRequireInt(plhs[0],"MaxRange",1,INT_MAX/Dimension);
    NumValues=QDSOMRequireInt(plhs[0],"NumValues",1,INT_MAX);
    RootNode=QDSOMRequireInt(plhs[0],"RootNode",0,Dimension);
    if (Dimension>INT_MAX/Dimension || NumValues>INT_MAX/NumValues)
    {
        mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel","Model.Dimension or Model.NumValues is too large.");
    }
    NumNeurons=NumRowsMap*NumColsMap;

    /* Array fields */
    ptrPi=QDSOMRequireDoubles(plhs[0],"Pi",(size_t)NumNeurons);
    ptrNumVictories=QDSOMRequireDoubles(plhs[0],"NumVictories",(size_t)NumNeurons);
    ptrRanges=QDSOMRequireDoubles(plhs[0],"Ranges",(size_t)Dimension);
    ptrTranslation=QDSOMRequireDoubles(plhs[0],"Translate",0);
    NumTranslate=mxGetNumberOfElements(QDSOMRequireField(plhs[0],"Translate"));
    Marginal=QDSOMRequireCells(plhs[0],"Marginal",NumNeurons,(size_t)NumValues);
    Intersect=QDSOMRequireCells(plhs[0],"Intersect",NumNeurons,(size_t)NumValues*(size_t)NumValues);
    Mutual=QDSOMRequireCells(plhs[0],"Mutual",NumNeurons,(size_t)Dimension*(size_t)Dimension);
    Graph=QDSOMRequireCells(plhs[0],"Graph",NumNeurons,(size_t)Dimension*(size_t)Dimension);

    /* Integer versions of the ranges table and of the translation table.
       Translate holds 1-based rows; MyTrans holds the 0-based equivalents.
       mxCalloc zeroes the entries beyond each variable's range. */
    MyRanges=mxMalloc((size_t)Dimension*sizeof(int));
    MyTrans=mxCalloc((size_t)MaxRange*(size_t)Dimension,sizeof(int));
    for(NdxDim=0;NdxDim<Dimension;NdxDim++)
    {
        if (!(ptrRanges[NdxDim]>=0.0 && ptrRanges[NdxDim]<=(double)MaxRange))
        {
            mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
                "Model.Ranges(%d) must be between 0 and Model.MaxRange.",NdxDim+1);
        }
        MyRanges[NdxDim]=(int)ptrRanges[NdxDim];
        for(NdxVal=0;NdxVal<MyRanges[NdxDim];NdxVal++)
        {
            Index=NdxVal+MaxRange*NdxDim;
            Translated=((size_t)Index<NumTranslate) ? ptrTranslation[Index]-1.0 : -1.0;
            if (!(Translated>=0.0 && Translated<(double)NumValues))
            {
                mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
                    "Model.Translate(%d,%d) must exist and be between 1 and Model.NumValues.",
                    NdxVal+1,NdxDim+1);
            }
            MyTrans[Index]=(int)Translated;
        }
    }

    /* Samples: one column per sample, one row per variable. Values are 0-based
       and are used directly as indices into MyTrans, so check their range once. */
    Samples=QDSOMRequireField(plhs[0],"Samples");
    NumSamples=0;
    ptrSamples=NULL;
    if (NumSteps>0)
    {
        if (!mxIsDouble(Samples) || mxIsComplex(Samples) ||
            mxGetM(Samples)!=(size_t)Dimension ||
            mxGetN(Samples)<1 || mxGetN(Samples)>(size_t)INT_MAX)
        {
            mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
                "Model.Samples must be a real double matrix with Model.Dimension rows and at least one column.");
        }
        NumSamples=(int)mxGetN(Samples);
        ptrSamples=mxGetPr(Samples);
        for(NdxSample=0;NdxSample<NumSamples;NdxSample++)
        {
            for(NdxDim=0;NdxDim<Dimension;NdxDim++)
            {
                SampleValue=ptrSamples[NdxDim+(size_t)Dimension*NdxSample];
                if (!(SampleValue>=0.0 && SampleValue<(double)MyRanges[NdxDim]))
                {
                    mexErrMsgIdAndTxt("TrainQDSOMMEX:badModel",
                        "Model.Samples(%d,%d) must be in the range [0, Model.Ranges(%d)).",
                        NdxDim+1,NdxSample+1,NdxDim+1);
                }
            }
        }
    }

    /* Auxiliary arrays. Responsibilities start at zero so that they are never
       read uninitialised. */
    ptrResponsibilities=mxCalloc((size_t)NumNeurons,sizeof(double));
    ptrUpdateMarginal=mxMalloc((size_t)NumValues*sizeof(double));
    ptrUpdateIntersect=mxMalloc((size_t)NumValues*(size_t)NumValues*sizeof(double));

    /* MaxRadiusNeighbourhood=mean([NumRowsMap NumColsMap])/2 */
    MaxRadiusNeighbourhood=((double)NumRowsMap+(double)NumColsMap)/4.0;

    /* Main loop */
    for(NdxStep=0;NdxStep<NumSteps;NdxStep++)
    {
        /* Choose a pattern at random: Pattern = Samples(:,IndexPat) */
        IndexPat=QDSOMRandomIndex(NumSamples);
        Pattern=ptrSamples+(size_t)IndexPat*(size_t)Dimension;

        /* The radius shrinks linearly from MaxRadiusNeighbourhood at the first
           step towards 0.75*MaxRadiusNeighbourhood at the last one */
        RadiusNeighbourhood=0.75*MaxRadiusNeighbourhood+0.25*MaxRadiusNeighbourhood*(1.0-(double)NdxStep/NumSteps);

        /* Find the winning distribution and the neuron responsibilities */
        WinningQDSOMProbMEX(plhs[0],MyRanges,MyTrans,Pattern,RadiusNeighbourhood,
                &NdxWinRow,&NdxWinCol,&LogLikelihoodProb,ptrResponsibilities);

        /* Increment Model.NumVictories(NdxWinRow,NdxWinCol) */
        ptrNumVictories[NdxWinRow+NumRowsMap*NdxWinCol]++;

        /* Indicator vectors of the observed values and value pairs
           (common to all neurons) */
        memset(ptrUpdateMarginal,0,(size_t)NumValues*sizeof(double));
        memset(ptrUpdateIntersect,0,(size_t)NumValues*(size_t)NumValues*sizeof(double));
        for(NdxDim=0;NdxDim<Dimension;NdxDim++)
        {
             Value=(int)Pattern[NdxDim];
             Index=MyTrans[Value+MaxRange*NdxDim];
             ptrUpdateMarginal[Index]=1.0;

             for(NdxDim2=NdxDim+1;NdxDim2<Dimension;NdxDim2++)
             {
                  Value2=(int)Pattern[NdxDim2];
                  Index2=MyTrans[Value2+MaxRange*NdxDim2];
                  ptrUpdateIntersect[Index+NumValues*Index2]=
                       ptrUpdateIntersect[Index2+NumValues*Index]=1.0;
             }
        }

        /* Learning rate 1/(0.001*t+100): starts at 0.01 and decays slowly.
           It is the same for every neuron, so it is computed once per step. */
        LearningRate=1.0/(0.001*(double)NdxStep+100.0);
        OneMinusRate=1.0-LearningRate;

        /* Update all the neurons (cells are indexed linearly, column-major) */
        for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
        {
            ptrMyMarginal=mxGetPr(mxGetCell(Marginal,IndexNeuron));
            ptrMyIntersect=mxGetPr(mxGetCell(Intersect,IndexNeuron));
            ptrMyMutual=mxGetPr(mxGetCell(Mutual,IndexNeuron));
            ptrMyGraph=mxGetPr(mxGetCell(Graph,IndexNeuron));

            /* Pi <- (1-rate)*Pi + rate*responsibility. The blending coefficients
               below satisfy OldCoef+NewCoef==1 whenever the new Pi is positive. */
            AntPi=ptrPi[IndexNeuron];
            ptrPi[IndexNeuron]=OneMinusRate*AntPi+
                            LearningRate*ptrResponsibilities[IndexNeuron];
            OldCoef=(OneMinusRate*AntPi)/ptrPi[IndexNeuron];
            NewCoef=(LearningRate*ptrResponsibilities[IndexNeuron])/ptrPi[IndexNeuron];
            /* A zero Pi gives NaN/Inf coefficients: leave the neuron unchanged */
            if (isnan(OldCoef) || isinf(OldCoef) || isnan(NewCoef) || isinf(NewCoef))
            {
                 OldCoef=1.0;
                 NewCoef=0.0;
            }

            /* Model.Marginal{i} = OldCoef*Model.Marginal{i} + NewCoef*UpdateMarginal */
            for(NdxDatum=0;NdxDatum<NumValues;NdxDatum++)
            {
                ptrMyMarginal[NdxDatum]=OldCoef*ptrMyMarginal[NdxDatum]+NewCoef*ptrUpdateMarginal[NdxDatum];
            }

            /* Model.Intersect{i} = OldCoef*Model.Intersect{i} + NewCoef*UpdateIntersect */
            for(NdxDatum=0;NdxDatum<NumValues*NumValues;NdxDatum++)
            {
                ptrMyIntersect[NdxDatum]=OldCoef*ptrMyIntersect[NdxDatum]+NewCoef*ptrUpdateIntersect[NdxDatum];
            }

            /* Pairwise mutual information (symmetric matrix, in bits). Terms
               whose log2 is not finite (zero probabilities) are skipped. */
            memset(ptrMyMutual,0,(size_t)Dimension*(size_t)Dimension*sizeof(double));
            for(NdxDim=0;NdxDim<Dimension;NdxDim++)
            {
                 for(NdxVal=0;NdxVal<MyRanges[NdxDim];NdxVal++)
                 {
                      Index=MyTrans[NdxVal+MaxRange*NdxDim];
                      for(NdxDim2=NdxDim+1;NdxDim2<Dimension;NdxDim2++)
                      {
                          for(NdxVal2=0;NdxVal2<MyRanges[NdxDim2];NdxVal2++)
                          {
                               Index2=MyTrans[NdxVal2+MaxRange*NdxDim2];
                               MyProbIntersect=ptrMyIntersect[Index2+NumValues*Index];
                               Logarithm2=log2(MyProbIntersect/(ptrMyMarginal[Index]*ptrMyMarginal[Index2]));
                               if (isfinite(Logarithm2))
                               {
                                    ptrMyMutual[NdxDim2+Dimension*NdxDim]=
                                         (ptrMyMutual[NdxDim+Dimension*NdxDim2]+=MyProbIntersect*
                                         Logarithm2);
                               }
                          }
                      }
                 }
            }

            /* Update Chow-Liu tree by Prim's spanning tree algorithm */
            PrimSpanningTree(ptrMyMutual,ptrMyGraph,Dimension,RootNode);
        }
    }

    /* ptrPi[i] is proportional to the a priori probability P(i); normalise it
       to sum to one. An all-zero Pi is left as is instead of becoming NaN. */
    SumPis=0.0;
    for(Index=0;Index<NumNeurons;Index++)
    {
        SumPis+=ptrPi[Index];
    }
    if (SumPis>0.0)
    {
        for(Index=0;Index<NumNeurons;Index++)
        {
            ptrPi[Index]/=SumPis;
        }
    }

    /* Free auxiliary arrays */
    mxFree(ptrResponsibilities);
    mxFree(MyRanges);
    mxFree(MyTrans);
    mxFree(ptrUpdateMarginal);
    mxFree(ptrUpdateIntersect);
}







/*--------------------------------------------------------------------*/

/*
 * Fill NeighbourhoodFunction[i] (neurons in column-major order) with the
 * unnormalised Gaussian weight exp(-0.5*(d/RadiusNeighbourhood)^2).
 * Here d is the Manhattan distance on the map grid between neuron i and the
 * position (CentreRow,CentreCol). Returns the sum of the weights.
 */
static double QDSOMNeighbourhoodWeights(int CentreRow,int CentreCol,int NumRowsMap,int NumColsMap,
    double RadiusNeighbourhood,double *NeighbourhoodFunction)
{
    int NdxRow,NdxCol,Index=0;
    double TopologicDistance,SumNeighbourhoodFunctions=0.0;

    for(NdxCol=0;NdxCol<NumColsMap;NdxCol++)
    {
        for(NdxRow=0;NdxRow<NumRowsMap;NdxRow++)
        {
            TopologicDistance=(double)(abs(NdxRow-CentreRow)+abs(NdxCol-CentreCol));
            /* This is q sub nji (unnormalized) */
            NeighbourhoodFunction[Index]=exp(-0.5*((TopologicDistance*TopologicDistance)
                            /(RadiusNeighbourhood*RadiusNeighbourhood)));
            SumNeighbourhoodFunctions+=NeighbourhoodFunction[Index];
            Index++;
        }
    }
    return SumNeighbourhoodFunctions;
}

/*
 * log P(Pattern | neuron) under the neuron's tree model (its Chow-Liu tree in
 * ptrMyGraph).
 * RootNode != 0: directed tree. The root adds log Marginal, and each edge
 * Parent->Child adds log Intersect - log Marginal(Parent).
 * RootNode == 0: undirected. Every variable adds log Marginal, and each tree
 * edge adds log Intersect - log Marginal(a) - log Marginal(b).
 * The result may be non-finite (-Inf or NaN) when the pattern contains a
 * value or value pair with zero probability.
 */
static double QDSOMNeuronLogLikelihood(const double *ptrMyMarginal,const double *ptrMyIntersect,
    const double *ptrMyGraph,const int *MyTrans,const double *Pattern,
    int Dimension,int MaxRange,int NumValues,int RootNode)
{
    int NdxDim,NdxDim2,Index,Index2;
    double MyLogLikelihood=0.0;

    if (RootNode)
    {
        /* Directed graph */
        for(NdxDim=0;NdxDim<Dimension;NdxDim++)
        {
            Index=MyTrans[(int)Pattern[NdxDim]+MaxRange*NdxDim];
            if (NdxDim==RootNode-1)
            {
                /* This is the root node */
                MyLogLikelihood+=log(ptrMyMarginal[Index]);
            }
            for(NdxDim2=NdxDim+1;NdxDim2<Dimension;NdxDim2++)
            {
                Index2=MyTrans[(int)Pattern[NdxDim2]+MaxRange*NdxDim2];
                if (ptrMyGraph[NdxDim+Dimension*NdxDim2])
                {
                    /* NdxDim is the parent node of NdxDim2 */
                    MyLogLikelihood+=(log(ptrMyIntersect[Index+NumValues*Index2])
                         -log(ptrMyMarginal[Index]));
                }
                if (ptrMyGraph[NdxDim2+Dimension*NdxDim])
                {
                    /* NdxDim2 is the parent node of NdxDim */
                    MyLogLikelihood+=(log(ptrMyIntersect[Index+NumValues*Index2])
                         -log(ptrMyMarginal[Index2]));
                }
            }
        }
    }
    else
    {
        /* Undirected graph */
        for(NdxDim=0;NdxDim<Dimension;NdxDim++)
        {
            Index=MyTrans[(int)Pattern[NdxDim]+MaxRange*NdxDim];
            MyLogLikelihood+=log(ptrMyMarginal[Index]);
            for(NdxDim2=NdxDim+1;NdxDim2<Dimension;NdxDim2++)
            {
                Index2=MyTrans[(int)Pattern[NdxDim2]+MaxRange*NdxDim2];
                /* Only if it belongs to the Chow-Liu tree of this neuron */
                if (ptrMyGraph[NdxDim+Dimension*NdxDim2] || ptrMyGraph[NdxDim2+Dimension*NdxDim])
                {
                    MyLogLikelihood+=(log(ptrMyIntersect[Index+NumValues*Index2])
                         -log(ptrMyMarginal[Index])
                         -log(ptrMyMarginal[Index2]));
                }
            }
        }
    }
    return MyLogLikelihood;
}

/*
 * Find the winning distribution of the QDSOM for one input pattern.
 *
 * Every map position j defines a mixture q_j of the neurons. Neuron i gets the
 * Gaussian neighbourhood weight q_ji of its map distance to j. The winner is
 * the j maximising P(x | q_j) = sum_i q_ji P(x | i) / sum_i q_ji.
 *
 * The per-neuron likelihoods are scaled by exp(-max_i log P(x | i)) before
 * mixing. This keeps them from underflowing to zero for high-dimensional
 * patterns, and the winner and the log-likelihood are unaffected because the
 * scale factor is added back in the log domain.
 *
 * Model               : QDSOM struct (fields Marginal, Intersect, Graph,
 *                       NumRowsMap, NumColsMap, Dimension, MaxRange,
 *                       NumValues, RootNode are read).
 * MyRanges            : not used here; kept for interface compatibility.
 * MyTrans             : 0-based translation table indexed [Value+MaxRange*Dim].
 * Pattern             : Dimension values (truncated to int).
 * RadiusNeighbourhood : width of the neighbourhood function.
 * NdxWinRow/NdxWinCol : out, 0-based winning map position.
 * LogLikelihoodProb   : out, log P(x | q_winner). It is -DBL_MAX when every
 *                       neuron has zero or non-finite likelihood.
 * ptrResponsibilities : out, NumRowsMap*NumColsMap normalised neighbourhood
 *                       weights of the winner. Always written: when no
 *                       position wins, the default winner (0,0) is used.
 */
void WinningQDSOMProbMEX(mxArray* Model,int *MyRanges,int *MyTrans,
    double *Pattern,double RadiusNeighbourhood,
    int *NdxWinRow,int *NdxWinCol,double *LogLikelihoodProb,double *ptrResponsibilities)
{
    mxArray *Marginal,*Intersect,*Graph;
    int IndexNeuron,NumNeurons,NumRowsMap,NumColsMap,NdxRowDist,NdxColDist;
    int Dimension,MaxRange,NumValues,RootNode,AnyFinite,WinnerFound;
    double MyLogLikelihood,MaxLogLikelihood,DistributionProb,SumNeighbourhoodFunctions;
    double *NeighbourhoodFunction,*ptrLikelihoods;

    (void)MyRanges;

    /* Get input data */
    Marginal=mxGetField(Model,0,"Marginal");
    Intersect=mxGetField(Model,0,"Intersect");
    Graph=mxGetField(Model,0,"Graph");

    NumRowsMap=mxGetScalar(mxGetField(Model,0,"NumRowsMap"));
    NumColsMap=mxGetScalar(mxGetField(Model,0,"NumColsMap"));
    Dimension=mxGetScalar(mxGetField(Model,0,"Dimension"));
    MaxRange=mxGetScalar(mxGetField(Model,0,"MaxRange"));
    NumValues=mxGetScalar(mxGetField(Model,0,"NumValues"));
    RootNode=mxGetScalar(mxGetField(Model,0,"RootNode"));
    NumNeurons=NumRowsMap*NumColsMap;

    /* Prepare auxiliary arrays */
    ptrLikelihoods=mxMalloc((size_t)NumNeurons*sizeof(double));
    NeighbourhoodFunction=mxMalloc((size_t)NumNeurons*sizeof(double));

    /* Pass 1: log P(x sub n | i) for every neuron; remember the largest finite one */
    AnyFinite=0;
    MaxLogLikelihood=0.0;
    for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
    {
        MyLogLikelihood=QDSOMNeuronLogLikelihood(
            mxGetPr(mxGetCell(Marginal,IndexNeuron)),
            mxGetPr(mxGetCell(Intersect,IndexNeuron)),
            mxGetPr(mxGetCell(Graph,IndexNeuron)),
            MyTrans,Pattern,Dimension,MaxRange,NumValues,RootNode);
        ptrLikelihoods[IndexNeuron]=MyLogLikelihood;
        if (isfinite(MyLogLikelihood) && (!AnyFinite || MyLogLikelihood>MaxLogLikelihood))
        {
            MaxLogLikelihood=MyLogLikelihood;
            AnyFinite=1;
        }
    }

    /* Pass 2: scaled likelihoods P(x sub n | i)*exp(-MaxLogLikelihood), in (0,1].
       Non-finite log-likelihoods (unseen value combinations) count as zero. */
    for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
    {
        if (isfinite(ptrLikelihoods[IndexNeuron]))
        {
            ptrLikelihoods[IndexNeuron]=exp(ptrLikelihoods[IndexNeuron]-MaxLogLikelihood);
        }
        else
        {
            ptrLikelihoods[IndexNeuron]=0.0;
        }
    }

    /* LogLikelihoodProb = -inf; (0,0) is the default winner */
    (*LogLikelihoodProb)=-DBL_MAX;
    (*NdxWinRow)=(*NdxWinCol)=0;
    WinnerFound=0;

    /* Compute the probability (likelihood) of each distribution, P(x sub n | q sub nj) */
    for(NdxColDist=0;NdxColDist<NumColsMap;NdxColDist++)
    {
        for(NdxRowDist=0;NdxRowDist<NumRowsMap;NdxRowDist++)
        {
            SumNeighbourhoodFunctions=QDSOMNeighbourhoodWeights(NdxRowDist,NdxColDist,
                NumRowsMap,NumColsMap,RadiusNeighbourhood,NeighbourhoodFunction);

            /* sum_i (q sub nji) * P(x sub n | i), normalised by sum_i q sub nji */
            DistributionProb=0.0;
            for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
            {
                DistributionProb+=NeighbourhoodFunction[IndexNeuron]*ptrLikelihoods[IndexNeuron];
            }
            DistributionProb/=SumNeighbourhoodFunctions;

            /* Undo the scaling in the log domain */
            MyLogLikelihood=log(DistributionProb)+MaxLogLikelihood;

            if (MyLogLikelihood>(*LogLikelihoodProb))
            {
                (*LogLikelihoodProb)=MyLogLikelihood;
                (*NdxWinRow)=NdxRowDist;
                (*NdxWinCol)=NdxColDist;
                /* Neuron responsibilities for the winning distribution */
                for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
                {
                    ptrResponsibilities[IndexNeuron]=NeighbourhoodFunction[IndexNeuron]/SumNeighbourhoodFunctions;
                }
                WinnerFound=1;
            }
        }
    }

    if (!WinnerFound)
    {
        /* Zero likelihood at every position: keep the default winner (0,0) and
           still return its responsibilities so the caller never reads stale data */
        SumNeighbourhoodFunctions=QDSOMNeighbourhoodWeights(0,0,
            NumRowsMap,NumColsMap,RadiusNeighbourhood,NeighbourhoodFunction);
        for(IndexNeuron=0;IndexNeuron<NumNeurons;IndexNeuron++)
        {
            ptrResponsibilities[IndexNeuron]=NeighbourhoodFunction[IndexNeuron]/SumNeighbourhoodFunctions;
        }
    }

    /* Release memory */
    mxFree(ptrLikelihoods);
    mxFree(NeighbourhoodFunction);
}


       
/* Edge of the complete graph over the variables, used by PrimSpanningTree to
   build the Chow-Liu tree. */
struct Edge {
       double Weight;         /* mutual information between the two variables */
       int Vertex1,Vertex2;   /* 0-based variable indices */
       };

/*
 * qsort comparator that orders edges by DESCENDING weight.
 * The weights are compared directly rather than by subtraction. With
 * subtraction, two equal infinite weights give Inf-Inf = NaN and the
 * comparison becomes inconsistent. NaN weights are placed after every number
 * (and compare equal to each other), so the comparator is always a consistent
 * ordering, as qsort requires.
 */
int CompareFunction(const void * a, const void * b)
{
    const struct Edge *MyEdge1=(const struct Edge *)a;
    const struct Edge *MyEdge2=(const struct Edge *)b;
    double Weight1=MyEdge1->Weight;
    double Weight2=MyEdge2->Weight;
    int IsNan1=(Weight1!=Weight1);
    int IsNan2=(Weight2!=Weight2);

    if (IsNan1 || IsNan2)
    {
        return IsNan1-IsNan2;
    }
    if (Weight1>Weight2)
    {
        return -1;
    }
    if (Weight1<Weight2)
    {
        return 1;
    }
    return 0;
}

/*
 * Build the Chow-Liu tree of one neuron with Prim's algorithm. The tree is a
 * maximum-weight spanning tree of the complete graph over the Dimension
 * variables, with edges weighted by mutual information.
 *
 * ptrMyMutual : Dimension x Dimension matrix (column-major); only the entries
 *               ptrMyMutual[i+Dimension*j] with i<j are read.
 * ptrMyGraph  : Dimension x Dimension output, overwritten. Every tree edge is
 *               stored as ptrMyGraph[Parent+Dimension*Child]=1.0, oriented away
 *               from the starting vertex; all other entries are 0.0.
 * Dimension   : number of variables; nothing is done when it is below 1.
 * RootNode    : 1-based starting vertex, or 0 to start from the first
 *               variable. Values outside [0, Dimension] raise a MATLAB error.
 */
void PrimSpanningTree(double *ptrMyMutual,double *ptrMyGraph,int Dimension,int RootNode)
{
     int *InTree;
     struct Edge *EdgeList;
     int Index,NdxDim,NdxDim2,ListSize,MyVert1,MyVert2;

     if (Dimension<1)
     {
          return;
     }
     if (RootNode<0 || RootNode>Dimension)
     {
          mexErrMsgIdAndTxt("TrainQDSOMMEX:badRootNode",
               "RootNode must be between 0 and %d.",Dimension);
     }

     /* Initialize the output graph */
     memset(ptrMyGraph,0,(size_t)Dimension*(size_t)Dimension*sizeof(double));

     ListSize=(Dimension*(Dimension-1))/2;
     if (ListSize==0)
     {
          /* A single variable: the tree has no edges */
          return;
     }

     /* Vertices that already belong to the tree (initially only the start) */
     InTree=mxCalloc((size_t)Dimension,sizeof(*InTree));
     InTree[RootNode ? RootNode-1 : 0]=1;

     /* List every edge i<j and sort by descending mutual information */
     EdgeList=mxMalloc((size_t)ListSize*sizeof(*EdgeList));
     Index=0;
     for(NdxDim=0;NdxDim<Dimension;NdxDim++)
     {
          for(NdxDim2=NdxDim+1;NdxDim2<Dimension;NdxDim2++)
          {
               EdgeList[Index].Weight=ptrMyMutual[NdxDim+Dimension*NdxDim2];
               EdgeList[Index].Vertex1=NdxDim;
               EdgeList[Index].Vertex2=NdxDim2;
               Index++;
          }
     }
     qsort(EdgeList,(size_t)ListSize,sizeof(*EdgeList),CompareFunction);

     /* Add Dimension-1 edges. Since the list is sorted, the first edge that
        joins a tree vertex to a non-tree vertex is the heaviest such edge. */
     for(NdxDim=0;NdxDim<(Dimension-1);NdxDim++)
     {
         for(Index=0;Index<ListSize;Index++)
         {
              MyVert1=EdgeList[Index].Vertex1;
              MyVert2=EdgeList[Index].Vertex2;
              if (InTree[MyVert1] ^ InTree[MyVert2])
              {
                   /* Orient the edge from the tree vertex (parent) to the new one */
                   if (InTree[MyVert1])
                   {
                        InTree[MyVert2]=1;
                        ptrMyGraph[MyVert1+Dimension*MyVert2]=1.0;
                   }
                   else
                   {
                        InTree[MyVert1]=1;
                        ptrMyGraph[MyVert2+Dimension*MyVert1]=1.0;
                   }
                   break;
              }
         }
     }

     /* Release memory */
     mxFree(InTree);
     mxFree(EdgeList);
}
