% Demo of MMED (probability density estimation and missing value estimation)
% Note: running the full script takes a long time
clear all

% Benchmark datasets from the UCI repository. Each entry is the path of a
% .mat file without its extension.
DatasetNames=strcat('UCI/Data',{'Glass','Letter','Liver','Pima','Yeast', ...
    'Vowel','Cloud','Pendigits', ...
    'BalanceScale','BreastCancerWisconsin','Contraceptive', ...
    'Segmentation','TAE','Wine', ...
    'PostOperative','Servo','HayesRoth','Haberman','Iris', ...
    'SpaceShuttle1','SpaceShuttle2'});

% Settings of the MMED experiments
MaxNumGroups=5;        % the experiments are run for 1..MaxNumGroups groups
NumUnitsPerGroup=18;
NumSteps=100000;
PruningThreshold=1e-5;

% Fraction of the test entries replaced by NaN in the missing value experiments
FractionNaNTest=0.05;



% For each dataset, run 10-fold cross-validation of MMED with 1..MaxNumGroups
% groups and store the aggregated measures in MMEDbatch.mat. Configurations
% whose ANLL is already stored (nonzero) in MMEDbatch.mat are skipped, so an
% interrupted run can be resumed.
for NdxDataBase=1:numel(DatasetNames)
    % Read the UCI dataset into a struct, so that variables left over from a
    % previously loaded dataset (Data, X, Samples) cannot be picked up by
    % mistake and the file cannot overwrite variables of this script
    DatasetFile=load(sprintf('%s.mat',DatasetNames{NdxDataBase}));
    
    % The samples are stored either as 'Data' (one cell per class), 'X' or 'Samples'
    if isfield(DatasetFile,'Data')
        % Join the samples of all the classes (one sample per column)
        AllSamples=[DatasetFile.Data.Samples{1:DatasetFile.Data.NumClasses}];
    elseif isfield(DatasetFile,'X')
        AllSamples=DatasetFile.X;
    else
        AllSamples=DatasetFile.Samples;
    end
    clear DatasetFile;
    
    % Precompute constants
    fprintf('Precomputing constants...\n')
    Dimension=size(AllSamples,1);
    Limits=[0 0.4 0.8 1.1 1.4 1.8 2.25 2.5 3 4.5 6 10 100];
    Limits=Limits.^(1/Dimension);
    Constants=PrecomputeConstants(Dimension,Limits);
    Constants(isnan(Constants))=0;
    fprintf('Constants are precomputed\n')
    
    % Assign each sample (column) at random to one of the 10 cross-validation folds
    Groups=ceil(10*rand(size(AllSamples,2),1));
    

    % Run experiments for different numbers of GMCs
    for NumGroups=1:MaxNumGroups
        
        % Skip this configuration if its results are already in MMEDbatch.mat
        if exist('MMEDbatch.mat','file')
            load('MMEDbatch.mat','DatasetNames','MSEMeanMMED','MSEStdMMED',...
                'MSEMeanGlobal','MSEStdGlobal','ANLLMeanMMED','ANLLStdMMED','TimeMeanMMED','TimeStdMMED',...
                'BICMeanMMED','BICStdMMED','AICMeanMMED','AICStdMMED','FractionNaNTest',...
                'NdxBestNumberGMCs');
            if (size(ANLLMeanMMED,1)>=NdxDataBase) && (size(ANLLMeanMMED,2)>=NumGroups)
                if ANLLMeanMMED(NdxDataBase,NumGroups)~=0
                    disp('Skipping already executed experiments')
                    continue;
                end
            end
        end
        
        % Run the 10-fold cross-validation
        ANLL=[];
        BIC=[];
        AIC=[];
        MSETestMMED=[];
        MSETestGlobal=[];
        ElapsedTime=[];
        for NdxRepetition=1:10
            % Obtain training and test samples
            SamplesTrain=AllSamples(:,Groups~=NdxRepetition);
            NumSamplesTrain=size(SamplesTrain,2);
            SamplesTest=AllSamples(:,Groups==NdxRepetition);
            if size(SamplesTest,2)==0
                % Empty fold: fall back to the first sample so that the test
                % measures below are still defined
                SamplesTest=AllSamples(:,1);
            end

            Samples=SamplesTrain;
            
            if NumSamplesTrain<100
                % Introduce noise to avoid singular covariance matrices
                % (each entry is scaled by a random factor in (1, 1.001))
                Samples=Samples.*(1000+rand(size(Samples)))/1000;
            end

            % Per-feature mean of the training samples over their finite
            % entries only; it is the baseline estimator for missing values
            ValidEntries=isfinite(Samples);
            FiniteSamples=Samples;
            FiniteSamples(~ValidEntries)=0;
            GlobalMean=sum(FiniteSamples,2)./sum(ValidEntries,2);

            try
                [Model,MyTime]=TrainMMED(Samples,NumGroups,NumUnitsPerGroup,NumSteps,PruningThreshold,...
                    Constants,Limits);
                ElapsedTime(end+1)=MyTime;

                % Measure ANLL, AIC and BIC on test data without missing values
                % (only the ANLL output is used here)
                [MyANLL,~,~,~] = MMEDANLLMEX(SamplesTest,Model);

                NumSamplesTest=size(SamplesTest,2);
                ANLL(end+1)=MyANLL;
                LogLikelihood=-NumSamplesTest*MyANLL;
                BIC(end+1)=-2*LogLikelihood+...
                    Model.NumParameters*log(NumSamplesTest);
                AIC(end+1)=2*Model.NumParameters-2*LogLikelihood;

                % Test the effectiveness of missing value imputation: set random
                % test entries to NaN (positions are drawn with replacement, so
                % a position may be drawn more than once) and estimate them both
                % with MMED and with the per-feature global mean
                BadTestSamples=SamplesTest;
                SamplesTestMeanEstimation=SamplesTest;
                NumFeatures=size(SamplesTest,1);

                for NdxNaN=1:FractionNaNTest*numel(BadTestSamples)
                    MyIndex=ceil(rand(1)*numel(BadTestSamples));
                    BadTestSamples(MyIndex)=NaN;
                    % Row (feature) of the 1-based, column-major linear index MyIndex
                    MyFeature=mod(MyIndex-1,NumFeatures)+1;
                    SamplesTestMeanEstimation(MyIndex)=GlobalMean(MyFeature);
                end

                [EstimatedSamples,~]=EstimateMissingDataMMED(Model,BadTestSamples);

                Errors=SamplesTest-EstimatedSamples;
                MSETestMMED(end+1)=(sum(sum(Errors.^2)))/numel(SamplesTest);

                % Errors using the global mean as an estimator (test set)
                ErrorsGlobalMean=SamplesTest-SamplesTestMeanEstimation;
                MSETestGlobal(end+1)=(sum(sum(ErrorsGlobalMean.^2)))/numel(SamplesTest);
            catch ME
                % Best effort: report the failure and go on with the next fold
                fprintf('Error while running experiments: %s\n',ME.message);
            end
        end

        
        % Find the mean and the standard deviation of the performance measures
        % obtained by 10-fold cross-validation, merging them with the results
        % already stored in MMEDbatch.mat
        if exist('MMEDbatch.mat','file')
            load('MMEDbatch.mat','DatasetNames','MSEMeanMMED','MSEStdMMED',...
                'MSEMeanGlobal','MSEStdGlobal','ANLLMeanMMED','ANLLStdMMED','TimeMeanMMED','TimeStdMMED',...
                'BICMeanMMED','BICStdMMED','AICMeanMMED','AICStdMMED','FractionNaNTest',...
                'NdxBestNumberGMCs');
        end
        TimeMeanMMED(NdxDataBase,NumGroups)=mean(ElapsedTime);
        TimeStdMMED(NdxDataBase,NumGroups)=std(ElapsedTime);
        MSEMeanMMED(NdxDataBase,NumGroups)=mean(MSETestMMED);
        MSEStdMMED(NdxDataBase,NumGroups)=std(MSETestMMED);
        MSEMeanGlobal(NdxDataBase,NumGroups)=mean(MSETestGlobal);
        MSEStdGlobal(NdxDataBase,NumGroups)=std(MSETestGlobal);
        ANLLMeanMMED(NdxDataBase,NumGroups)=mean(ANLL);
        % Best configuration so far: the number of GMCs with the lowest mean ANLL
        [~,Index]=min(ANLLMeanMMED(NdxDataBase,1:NumGroups));
        NdxBestNumberGMCs(NdxDataBase)=Index;
        
        ANLLStdMMED(NdxDataBase,NumGroups)=std(ANLL);
        BICMeanMMED(NdxDataBase,NumGroups)=mean(BIC);
        BICStdMMED(NdxDataBase,NumGroups)=std(BIC);
        AICMeanMMED(NdxDataBase,NumGroups)=mean(AIC);
        AICStdMMED(NdxDataBase,NumGroups)=std(AIC);
        save('MMEDbatch.mat','DatasetNames','MSEMeanMMED','MSEStdMMED',...
            'MSEMeanGlobal','MSEStdGlobal','ANLLMeanMMED','ANLLStdMMED','TimeMeanMMED','TimeStdMMED',...
            'BICMeanMMED','BICStdMMED','AICMeanMMED','AICStdMMED','FractionNaNTest',...
            'NdxBestNumberGMCs');
    end

end



        