// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "Boosting.h"
#include "BoostingMeasurer.h"
#include "random.h"

namespace Torch {

static void randw(int *selected_examples, real *ex_weights, int n_examples)
{
  real *repartition = (real *)xalloc(sizeof(real)*(n_examples+1));
  repartition[0] = 0;
  for(int i = 0; i < n_examples; i++)
    repartition[i+1] = repartition[i]+ex_weights[i];

  for(int i = 0; i < n_examples; i++)
  {
    real z = uniform();
    int gauche = 0;
    int droite = n_examples;
    while(gauche+1 != droite)
    {
      int centre = (gauche+droite)/2;
      if(repartition[centre] < z)
        gauche = centre;
      else
        droite = centre;
    }
    selected_examples[i] = gauche;
//    printf("%g < %g < %g\n", repartition[gauche], z, repartition[gauche+1]);
  }
  free(repartition);
}

// Builds a boosting trainer for the given weighted-sum machine.
//   w_machine_    : machine whose sub-trainers are trained one after another
//                   and whose `weights` array receives the per-trainer weights.
//   data_         : training data set (forwarded to the Trainer base).
//   class_format_ : class encoding handed to the BoostingMeasurer in train().
// Only pointers are stored; nothing is allocated here.
Boosting::Boosting(WeightedSumMachine* w_machine_, DataSet* data_, ClassFormat *class_format_)
  : Trainer(w_machine_, data_)
{
  class_format = class_format_;
  w_machine = w_machine_;

  // Cache the machine's trainer count and combination-weight array.
  weights = w_machine_->weights;
  n_trainers = w_machine_->n_trainers;
}

void Boosting::train(List* measurers)
{
  int n_examples = data->n_examples;
  int *selected_examples = (int *)xalloc(n_examples*sizeof(int));
  real *ex_weights = (real *)xalloc(n_examples*sizeof(real));
  for(int t = 0; t < n_examples; t++)
    ex_weights[t] = 1./((real)n_examples);

  BoostingMeasurer *measurer = new BoostingMeasurer(class_format, "/dev/null");
  measurer->init();

  measurer->setData(data);
  measurer->setWeights(ex_weights);

  message("Boosting: training...");
  w_machine->n_trainers_trained = 0;

  List *the_boost_mes = NULL;
  addToList(&the_boost_mes, 1, measurer);
  for(int i = 0; i < n_trainers; i++)
  {
    randw(selected_examples, ex_weights, n_examples);
    data->pushSubset(selected_examples, n_examples);
    w_machine->trainers[i]->machine->reset();
    w_machine->trainers[i]->train(w_machine->trainers_measurers ? w_machine->trainers_measurers[i] : NULL);
    data->popSubset();

    // Find beta and all missclass
    measurer->setInputs(w_machine->trainers[i]->machine->outputs);
    w_machine->trainers[i]->test(the_boost_mes);

    // Compute new weights
    int *ptr_status = measurer->status;
    real *ptr_ex_weights = ex_weights;
    real beta = measurer->beta;
    for(int t = 0; t < n_examples; t++)
    {
      if(*ptr_status++ > 0)
        *ptr_ex_weights *= beta;
      ptr_ex_weights++;
    }

    // Normalize
    ptr_ex_weights = ex_weights;
    real z = 0;
    for(int t = 0; t < n_examples; t++)
      z += *ptr_ex_weights++;

    ptr_ex_weights = ex_weights;
    for(int t = 0; t < n_examples; t++)
      *ptr_ex_weights++ /= z;

    // Warning: no precautions, no normalization
    weights[i] = -log(beta);

    w_machine->n_trainers_trained = i+1;

    if(measurers)
      test(measurers);
  }

  freeList(&the_boost_mes);
  free(selected_examples);
  free(ex_weights);

  delete measurer;
}

// The constructor only stores borrowed pointers, and train() frees every
// buffer it allocates before returning, so there is nothing to release here.
Boosting::~Boosting()
{
  // Intentionally empty.
}

}

