// Banner text handed to the command-line parser; it is shown ahead of the
// argument and option descriptions.
const char *help =
  "BayesClassifier (c) Bison Ravi et Samy Bengio 2001\n"
  "\n"
  "This program will train a BayesClassifier of GMMs\n";

#include "EMTrainer.h"
#include "DiagonalGMM.h"
#include "Kmeans.h"
#include "SeqDataSet.h"
#include "MatSeqDataSet.h"
#include "HtkSeqDataSet.h"
#include "CmdLine.h"
#include "NllMeasurer.h"
#include "BayesClassifier.h"
#include "BayesClassifierMachine.h"
#include "ClassMeasurer.h"
#include "MultiClassFormat.h"
#include "OneHotClassFormat.h"
#include "TwoClassFormat.h"

using namespace Torch;

int main(int argc, char **argv)
{
  char* train_file;
  char* test_file;
  int n_inputs;
  int n_observations;
  int n_targets;
  int max_load;
  int max_load_test;
  int seed_value;
  real accuracy;
  real threshold;
  int max_iter_kmeans;
  int max_iter_gmm;
  char *dir_name;
  int n_gaussians;
  real prior;
  char *load_model;
  char *save_model;
  bool dynamic;
  int kfold;
  bool one_hot;
  bool multi_class;
  bool equal_bayes_prior;

  //=================== The command-line ==========================

  CmdLine cmd;

  // Put the help line at the beginning
  cmd.info(help);

  // Ask for arguments
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("file", &train_file, "the train files, in double-quote");

  // Propose some options
  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-n_gaussians", &n_gaussians, 10, "number of Gaussians");
  cmd.addRCmdOption("-threshold", &threshold, 0.0001, "variance threshold");
  cmd.addRCmdOption("-prior", &prior, 0.001, "prior on the weights of the mixture");
  cmd.addBCmdOption("-equal_bayes_prior", &equal_bayes_prior, false, "give equal prior to each class");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-iterk", &max_iter_kmeans, 25, "max number of iterations of Kmeans initialization");
  cmd.addICmdOption("-iterg", &max_iter_gmm, 25, "max number of iterations of EM training for GMMs");
  cmd.addRCmdOption("-e", &accuracy, 0.0001, "end accuracy");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-Kfold", &kfold, 5, "k-fold crossValidation");
  cmd.addSCmdOption("-test_file", &test_file, "", "test data file");
  cmd.addBCmdOption("-dynamic", &dynamic, false, "dynamic problems");
  cmd.addBCmdOption("-multi_class", &multi_class, false, "data is in multi_class format");
  cmd.addBCmdOption("-one_hot", &one_hot, false, "data is in one_hot format");
  cmd.addICmdOption("-load", &max_load, -1, "max number of train examples to load");
  cmd.addICmdOption("-load_test", &max_load_test, -1, "max number of test examples to load");
  cmd.addICmdOption("-seed", &seed_value, -1, "initial seed for random generator");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
  cmd.addSCmdOption("-lm", &load_model, "", "start from given model file");
  cmd.addSCmdOption("-sm", &save_model, "", "save results into given model file");

  // Read the command line
  cmd.read(argc, argv);

  // If the user didn't give any random seed,
  // generate a random random seed...
  if (seed_value == -1)
    seed();
  else
    manual_seed((long)seed_value);

  //=================== DataSets ===================

  char* train_files[1000];
  int n;
  train_files[0] = strtok(train_file," ");
  for(n = 1;(train_files[n] = strtok(NULL," "));n++);

  SeqDataSet* data = NULL;
  SeqDataSet* tdata = NULL;

  data = new MatSeqDataSet(train_files, n, 0, -1, 1, false, max_load);
  data->init();
  if (!dynamic)
    data->toOneFramePerExample();

  if(strcmp(test_file, "") ) {
    tdata = new MatSeqDataSet(test_file, 0, -1, 1, false, max_load_test);
    tdata->init();
    if (!dynamic)
      tdata->toOneFramePerExample();
  }
  n_inputs = data->n_inputs;
  n_observations = data->n_observations;
  n_targets = data->n_targets;

  // How the dataset encodes the class format?
  ClassFormat *class_format = NULL;
  if (multi_class)
    class_format = new MultiClassFormat(data);
  else if (one_hot)
    class_format = new OneHotClassFormat(data);
  else
    class_format = new TwoClassFormat(data);

  int n_classes = class_format->getNumberOfClasses();

  // for the Gaussian mixtures, give the minimum variance value per dimension
  real* thresh = (real*)xalloc(n_observations*sizeof(real));
  for (int i=0;i<n_observations;i++)
    thresh[i] = threshold;
  
  // create the model. For each class, we need to create a DiagonalGMM. This
  // DiagonalGMM will be initialized by a Kmeans which will be trained by a
  // EMTrainer (we are talking about the Kmeans). We will also record the
  // Kmeans score during training as well as the DiagonalGMM during EM.
  Kmeans** kmeans = new Kmeans *[n_classes];
  EMTrainer** kmeans_trainer = new EMTrainer *[n_classes];
  DiagonalGMM** gmm = new DiagonalGMM *[n_classes];
  Trainer** trainer = new Trainer *[n_classes];
  NllMeasurer** nll_meas_kmeans = new NllMeasurer *[n_classes];
  NllMeasurer** nll_meas_gmm = new NllMeasurer *[n_classes];
  
  List** meas_kmeans = new List *[n_classes];
  List** meas_gmm = new List *[n_classes];
  
  for(int i = 0;i < n_classes;i++) {
    meas_kmeans[i]  = NULL;
    meas_gmm[i]  = NULL;
  }
  
  for(int classe = 0;classe < n_classes;classe++) {
    kmeans[classe] = new Kmeans(n_observations,n_gaussians,thresh,prior,data);
    kmeans[classe]->init();
    kmeans[classe]->reset();
  
    kmeans_trainer[classe] = new EMTrainer(kmeans[classe],data);
    kmeans_trainer[classe]->setROption("end accuracy", accuracy);
    kmeans_trainer[classe]->setIOption("max iter", max_iter_kmeans);
    
    char kmeans_name[100];
    sprintf(kmeans_name,"%s/kmeans_val_%d",dir_name, classe);
    nll_meas_kmeans[classe] = new NllMeasurer(kmeans[classe]->outputs, data, kmeans_name);
    nll_meas_kmeans[classe]->init();
    addToList(&meas_kmeans[classe], 1, nll_meas_kmeans[classe]);

    gmm[classe] = new DiagonalGMM(n_observations, n_gaussians, thresh, prior);
    gmm[classe]->setOption("initial kmeans trainer",&kmeans_trainer[classe]);
    gmm[classe]->setOption("initial kmeans trainer measurers",&meas_kmeans[classe]);
    gmm[classe]->init();
    
    trainer[classe] = new EMTrainer(gmm[classe],data);
    trainer[classe]->setROption("end accuracy", accuracy);
    trainer[classe]->setIOption("max iter", max_iter_gmm);

    char gmm_name[100];
    sprintf(gmm_name,"%s/gmm_val_%d",dir_name, classe);
    nll_meas_gmm[classe] = new NllMeasurer(gmm[classe]->outputs, data, gmm_name);
    nll_meas_gmm[classe]->init();
    addToList(&meas_gmm[classe], 1, nll_meas_gmm[classe]);
  }

  message(">>> GMMs initialized <<< ");

  // The BayesClassifier can be given a prior probability for each class
  // in order to weight the posterior
  real* bayes_prior = NULL;
  if (equal_bayes_prior) {
    bayes_prior = (real*)xalloc(n_classes*sizeof(real));
    for (int i=0;i<n_classes;i++)
      bayes_prior[i] = -log((real)n_classes);
  }

  BayesClassifierMachine machine(trainer, n_classes, meas_gmm,class_format,bayes_prior);
  machine.init();

  BayesClassifier bayes(&machine, data);

  
  List *measurers = NULL;
  List *test_measurers = NULL;
  char bayes_mes_name[100];
  sprintf(bayes_mes_name, "%s/bayes_train_err", dir_name);
  char bayes_tmes_name[100];
  sprintf(bayes_tmes_name, "%s/bayes_test_err", dir_name);

  Measurer* mes = NULL;
  Measurer* tmes = NULL;

  mes = new ClassMeasurer(machine.outputs, data, class_format, bayes_mes_name);
  mes->init();
  addToList(&measurers, 1, mes);
  if(strcmp(test_file, "")) {
    tmes = new ClassMeasurer(machine.outputs, tdata, class_format, bayes_tmes_name);
    tmes->init();
    addToList(&measurers, 1, tmes);
  } else {
    tmes = new ClassMeasurer(machine.outputs, data, class_format, bayes_tmes_name);
    tmes->init();
    addToList(&test_measurers, 1, tmes);
  }

    
  // =========== Training or Testing =================


  char load_model_name[100];
  if(strcmp(load_model, ""))
    sprintf(load_model_name, "%s/%s", dir_name, load_model);

  if(strcmp(load_model, "")) {
    bayes.load(load_model_name);
    bayes.test(measurers);
  } else {
    if(!strcmp(test_file, ""))
      bayes.crossValidate(kfold, measurers, test_measurers);
    else {
      bayes.train(NULL);
      bayes.test(measurers);
    }
    
    if(strcmp(save_model, "")) {
      char save_model_name[100];
      sprintf(save_model_name, "%s/%s", dir_name, save_model);
      bayes.save(save_model_name);
    }
  }
  
  //if you love someone, set them free
  
  for(int classe = 0;classe < n_classes;classe++) {
    delete(kmeans[classe]);
    delete(kmeans_trainer[classe]);
    delete(gmm[classe]);
    delete(trainer[classe]);
    delete(nll_meas_kmeans[classe]);
    freeList(&meas_kmeans[classe]);
    delete(nll_meas_gmm[classe]);
    freeList(&meas_gmm[classe]);
  }

  delete[] kmeans;
  delete[] kmeans_trainer;
  delete[] gmm;
  delete[] trainer;
  delete[] nll_meas_kmeans;
  delete[] meas_kmeans;
  delete[] nll_meas_gmm;
  delete[] meas_gmm;
  free(thresh);

  delete mes;
  delete tmes;
  freeList(&measurers);
  if (test_measurers)
    freeList(&test_measurers);
  
  if (equal_bayes_prior)
    free(bayes_prior);

  delete class_format;
  delete data;
  delete tdata;
  return(0);
}
