// Copyright (C) 2002 Johnny Mariethoz  (Johnny.Mariethoz@idiap.ch)
//                and Bison Ravi  (francois.belisle@idiap.ch)
//                and Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "SeqDataSet.h"
#include "Distribution.h"

namespace Torch {

// printf/scanf-style conversion specifier for a `real` value:
// "%lf" when USEDOUBLE is defined, "%f" otherwise.
// NOTE(review): presumably `real` is `double` exactly when USEDOUBLE is
// defined -- confirm against the typedef. The macro is not referenced in the
// part of the file visible here.
#ifdef USEDOUBLE
#define REAL_FORMAT "%lf"
#else
#define REAL_FORMAT "%f"
#endif

// Creates an empty sequence dataset: no examples, no current example/frame,
// no normalization statistics. Also registers the two boolean options that
// switch mean/stdv normalization on (both default to false).
SeqDataSet::SeqDataSet()
{
  // Storage that is filled in later by whoever loads the data.
  examples = NULL;
  kind = NULL;
  seqtargets = NULL;
  targets = NULL;
  inputs = NULL;
  file_names = NULL;

  // Counters and cursors all start at zero.
  n_seqtargets = n_targets = 0;
  n_frames = 0;
  current_example = current_frame = 0;
  n_file_names = 0;

  // User-settable switches consulted by normalize(), saveFILE() and loadFILE().
  addBOption("normalize inputs", &norm_inputs, false,"normalize the inputs by mean/stdv");
  addBOption("normalize observations", &norm_observations, false,"normalize the observations by mean/stdv");

  // Statistics buffers are allocated lazily by normalize().
  mean_in = mean_ob = stdv_in = stdv_ob = NULL;
}

// Releases the `inputs` list and the four normalization buffers.
// The buffers may still be NULL (normalize() never ran); free(NULL) is a no-op.
SeqDataSet::~SeqDataSet()
{
  freeList(&inputs);

  real* stat_buffers[] = { mean_in, mean_ob, stdv_in, stdv_ob };
  const unsigned int n_buffers = sizeof(stat_buffers) / sizeof(stat_buffers[0]);
  for (unsigned int b = 0; b < n_buffers; b++)
    free(stat_buffers[b]);
}

// Finishes the set-up of the dataset. The call order matters: each step
// relies on the state left by the previous one.
void SeqDataSet::init()
{
  // Base-class initialisation first.
  DataSet::init();
  // Add the node of the `inputs` list with a NULL payload; setRealExample()
  // later points inputs->ptr at the current SeqExample.
  addToList(&inputs, n_inputs, NULL);
  // Refresh tot_n_frames; normalize() uses it to size its scratch arrays.
  totNFrames();
  // Applies mean/stdv normalization when the corresponding options are set.
  normalize();
}

// Removes, in place, the frames that `unlikely_distr` scores at least as well
// as `likely_distr`, and returns how many frames disappeared.
//
// For each example both distributions score every frame; the decision for
// frame j sums the log-probabilities over the window [j-range, j+range] and
// keeps the frame when the "likely" sum is >= the "unlikely" sum. Kept frames
// are compacted to the front of the example's frame arrays.
//
//   likely_distr / unlikely_distr : the two competing distributions
//   obs_offset : offset applied inside each frame vector before scoring
//                (the same offset is applied to the inputs vector)
//   range      : number of neighbouring frames on EACH side of the window
//
// The first and last `range` frames of every example have no full window and
// are always dropped.
// NOTE(review): a rejected frame's inputs buffer is freed but its observation
// buffer is only unlinked (the free is commented out), and frames dropped at
// the borders are never freed -- confirm how frame buffers are allocated.
int SeqDataSet::removeUnlikelyFrames(Distribution* likely_distr, Distribution* unlikely_distr,int obs_offset,int range){
  likely_distr->eMIterInitialize();
  unlikely_distr->eMIterInitialize();
  totNFrames();
  int tot = tot_n_frames;
  for(int i=0;i<n_real_examples;i++){
    setExample(i);
    SeqExample* ex = (SeqExample*)inputs->ptr;
    unlikely_distr->eMSequenceInitialize(inputs);
    likely_distr->eMSequenceInitialize(inputs);

    // Score every frame with both distributions; the results are read back
    // from log_probabilities[] below.
    for(int j=0;j<ex->n_frames;j++){
      real* obs = &ex->observations[j][obs_offset];
      real* in = ex->inputs ? &ex->inputs[j][obs_offset] : NULL;
      unlikely_distr->frameLogProbability(obs,in,j);
      likely_distr->frameLogProbability(obs,in,j);
    }

    // `range` frames before and after the current frame take part in the
    // decision (window of 2*range+1 frames).
    int num_frames = 0;

    for(int j=range;j<ex->n_frames-range;j++){
      real p_unlikely = unlikely_distr->log_probabilities[j];
      real p_likely = likely_distr->log_probabilities[j];
      for(int k=1;k<range+1;k++){
        p_unlikely +=  unlikely_distr->log_probabilities[j+k] + unlikely_distr->log_probabilities[j-k];
        p_likely += likely_distr->log_probabilities[j+k] + likely_distr->log_probabilities[j-k];
      }
      if(p_likely >= p_unlikely){
        // Keep: slide the frame down to the next free slot.
        ex->observations[num_frames] = ex->observations[j];
        if (ex->inputs)
          ex->inputs[num_frames] = ex->inputs[j];
        num_frames++;
      }else{
        //free(ex->observations[j]);
        ex->observations[j] = NULL;
        if (ex->inputs) {
          free(ex->inputs[j]);
          ex->inputs[j] = NULL;
        }
      }
    }

    // Clear every slot past the compacted prefix.
    for(int j=num_frames;j<ex->n_frames;j++){
      ex->observations[j] = NULL;
      if (ex->inputs)
        ex->inputs[j] = NULL;
    }
    ex->n_frames = num_frames;
    ex->n_real_frames = ex->n_frames;
  }

  // Recount once, after all examples are processed. Recounting inside the
  // loop walked every example on every iteration (quadratic) for the same
  // final value of tot_n_frames.
  totNFrames();
  return (tot-tot_n_frames);
}

// Non-destructive variant of removeUnlikelyFrames(): instead of deleting
// frames it records the keep/drop decision of frame j in frame_to_keep[j].
//
//   likely_distr / unlikely_distr : the two competing distributions
//   frame_to_keep : output array indexed by the frame position INSIDE an
//                   example; it must hold at least as many entries as the
//                   longest example. With several examples each one writes
//                   to the same indices, so later examples overwrite earlier
//                   ones.
//   obs_offset    : offset applied inside each frame vector before scoring
//   range         : number of neighbouring frames on each side of the window
//
// Returns the total frame count minus the number of frames marked as kept.
// The first and last `range` frames of an example have no full window; they
// are counted as removed by the return value and are now explicitly marked
// false (they used to be left unwritten).
int SeqDataSet::xwavesRemoveUnlikelyFrames(Distribution* likely_distr, Distribution* unlikely_distr,bool* frame_to_keep,int obs_offset,int range){
  likely_distr->eMIterInitialize();
  unlikely_distr->eMIterInitialize();
  totNFrames();
  int tot = tot_n_frames;
  int num_frames = 0;
  for(int i=0;i<n_real_examples;i++){
    setExample(i);
    SeqExample* ex = (SeqExample*)inputs->ptr;
    unlikely_distr->eMSequenceInitialize(inputs);
    likely_distr->eMSequenceInitialize(inputs);

    // Score every frame with both distributions; the results are read back
    // from log_probabilities[] below.
    for(int j=0;j<ex->n_frames;j++){
      real* obs = &ex->observations[j][obs_offset];
      real* in = ex->inputs ? &ex->inputs[j][obs_offset] : NULL;
      unlikely_distr->frameLogProbability(obs,in,j);
      likely_distr->frameLogProbability(obs,in,j);
    }

    // `range` frames before and after the current frame take part in the
    // decision (window of 2*range+1 frames).
    for(int j=0;j<ex->n_frames;j++){
      if(j < range || j >= ex->n_frames-range){
        // Border frame without a complete window: never kept.
        frame_to_keep[j] = false;
        continue;
      }
      real p_unlikely = unlikely_distr->log_probabilities[j];
      real p_likely = likely_distr->log_probabilities[j];
      for(int k=1;k<range+1;k++){
        p_unlikely +=  unlikely_distr->log_probabilities[j+k] + unlikely_distr->log_probabilities[j-k];
        p_likely += likely_distr->log_probabilities[j+k] + likely_distr->log_probabilities[j-k];
      }
      if(p_likely >= p_unlikely){
        num_frames++;
        //printf("%d %.3f %.3f speech\n",j,p_likely,p_unlikely);
        frame_to_keep[j] = true;
      }else{
        frame_to_keep[j] = false;
        //printf("%d %.3f %.3f sil\n",j,p_likely,p_unlikely);
      }
    }
  }
  return (tot-num_frames);
}


void SeqDataSet::selectBootstrap(){
  int n = n_examples;
  for (int j=0;j<n;j++) {
    setExample(j);
    int n_selected_frames = n_frames;
    int* selected_frames = (int*)xalloc(sizeof(int)*n_selected_frames);
    for (int l=0;l<n_frames;l++)
      selected_frames[l] = (int)floor(bounded_uniform(0,n_frames));
    setSelectedFrames(selected_frames,n_selected_frames);
    free(selected_frames);
  }
}

void SeqDataSet::linearSegmentation(int ith_segment,int n_segment){
  int n = n_examples;
  // linear segmentation
  for (int j=0;j<n;j++) {
    setExample(j);
    int n_frames_per_state = n_frames/(n_segment-1);
    int from = (ith_segment-1)*n_frames_per_state;
    int to = (ith_segment == n_segment-2 ? n_frames : ith_segment*n_frames_per_state);
    int n_selected_frames = to - from;
    int* selected_frames = (int*)xalloc(sizeof(int)*n_selected_frames);
    int k = 0;
    for (int l=from;l<to;l++,k++)
      selected_frames[k] = l;
    setSelectedFrames(selected_frames,n_selected_frames);
    free(selected_frames);
  }
}

// Recomputes tot_n_frames as the sum of n_frames over the n_examples
// currently selected examples. Side effect: visits every example with
// setExample(), so the last one visited remains the current example.
void SeqDataSet::totNFrames()
{
  tot_n_frames = 0;

  int e = 0;
  while (e < n_examples) {
    setExample(e);
    tot_n_frames += n_frames;
    e++;
  }
}

void SeqDataSet::setRealExample(int t)
{
  current_example = t;

  SeqExample *ex = &examples[current_example];
  inputs->ptr = ex;
  seqtargets = ex->seqtargets;
  // to keep compatibility with DataSet, suppose first value is target
  if (seqtargets)
    targets = ex->seqtargets[0];
  n_seqtargets = ex->n_seqtargets;
  n_frames = ex->n_frames;
  setFrame(0);
}

void SeqDataSet::setFrame(int t)
{
  if(examples[current_example].selected_frames)
    current_frame = examples[current_example].selected_frames[t];
  else
    current_frame = t;

  examples[current_example].current_frame = current_frame;
}

// Free-function counterpart of SeqDataSet::setFrame() working directly on an
// example: stores in ex->current_frame either the `frame`-th entry of the
// example's frame selection, or `frame` itself when there is no selection.
void setFrameExample(SeqExample* ex, int frame)
{
  ex->current_frame = ex->selected_frames ? ex->selected_frames[frame] : frame;
}


void SeqDataSet::setSelectedFrames(int *selected_frames, int n_selected_frames)
{
  int t = current_example;
  if (examples[t].n_frames != examples[t].n_real_frames)
    error("cannot select frames, of an example with already selected frames");
  examples[t].n_frames = n_selected_frames;

  examples[t].selected_frames = (int*)xalloc(sizeof(int) * n_selected_frames);
  int *f_from = selected_frames;
  int *f_to = examples[t].selected_frames;

  for(int i = 0;i < n_selected_frames;i++) 
    *f_to++ = *f_from++;
}


void SeqDataSet::unsetSelectedFrames()
{
  int t = current_example;
  if(examples[t].selected_frames) {
    free(examples[t].selected_frames);
    examples[t].selected_frames = NULL;
  }
  examples[t].n_frames = examples[t].n_real_frames;
}

// Removes the frame selection from each of the n_examples selected examples.
// Every example is made current in turn, so the last one stays current.
void SeqDataSet::unsetAllSelectedFrames(){
  int e = 0;
  while (e < n_examples) {
    setExample(e);
    unsetSelectedFrames();
    e++;
  }
}

void SeqDataSet::removeUnusedData(bool* mask){
  int new_vect_size = 0;

  /* find new vector size */
  for(int i=0;i<n_observations;i++){
    if(!mask[i])
      new_vect_size++;
  }
  for(int i=0;i<n_real_examples;i++)
    for(int j=0;j<examples[i].n_frames;j++){
      int k=0;
      real* vect = examples[i].observations[j];
      for(int l=0;l<n_observations;l++)
        if(!mask[l]){
          vect[k] = vect[l];
          k++;
        }
    }
  n_observations = new_vect_size;
}


// Rebuilds the dataset so that every frame becomes an example of its own.
//
// A new `examples` array is allocated with one single-frame entry per frame
// of the old examples. Frame buffers (observations, inputs, per-frame
// targets) are NOT copied: the new examples take over the old pointers, and
// only the old per-example pointer tables are freed. All new examples that
// come from the same old example share its `name` pointer. Alignments are
// discarded. Finally the example selection is reset to the identity.
//
// NOTE(review): when an old example has targets, seqtargets[j] is read for
// every frame j, i.e. one target per frame is assumed.
// NOTE(review): inputs->ptr, seqtargets and targets still refer to the old,
// freed data until the next setExample() call.
void SeqDataSet::toOneFramePerExample()
{
  int old_n_examples = n_real_examples;
  SeqExample* old_examples = examples;

  // Kept for its side effects on tot_n_frames and the current example.
  totNFrames();

  // Size the new arrays from the examples that are actually converted below.
  // tot_n_frames is counted over the n_examples *selected* examples, so it is
  // too small whenever the selection is a subset of the real examples, and
  // the fill loop would run past the end of the allocations.
  int total_frames = 0;
  for(int i = 0;i < old_n_examples;i++)
    total_frames += old_examples[i].n_frames;

  file_names = (char**) xrealloc(file_names,total_frames*sizeof(char*));
  examples = (SeqExample*)xalloc(sizeof(SeqExample) * total_frames);

  // n_real_examples doubles as the write cursor into the new array.
  n_real_examples = 0;

  for(int i = 0;i < old_n_examples;i++)
  {
    for(int j = 0;j < old_examples[i].n_frames;j++, n_real_examples++)
    {
      examples[n_real_examples].n_frames = 1;
      examples[n_real_examples].n_real_frames = 1;
      examples[n_real_examples].selected_frames = NULL;
      examples[n_real_examples].name = old_examples[i].name;
      file_names[n_real_examples] = examples[n_real_examples].name;
      examples[n_real_examples].n_alignments = 0;
      examples[n_real_examples].alignment = NULL;
      examples[n_real_examples].alignment_phoneme = NULL;

      if(old_examples[i].n_seqtargets > 0) {
        examples[n_real_examples].n_seqtargets = 1;
        examples[n_real_examples].seqtargets = (real**) xalloc(sizeof(real*));
        examples[n_real_examples].seqtargets[0] = old_examples[i].seqtargets[j];
      } else {
        examples[n_real_examples].seqtargets = NULL;
        examples[n_real_examples].n_seqtargets = 0;
      }

      if(n_observations > 0) {
        examples[n_real_examples].observations = (real**) xalloc(sizeof(real*));
        examples[n_real_examples].observations[0] = old_examples[i].observations[j];
      } else
        examples[n_real_examples].observations = NULL;

      if(n_inputs > 0) {
        examples[n_real_examples].inputs = (real**) xalloc(sizeof(real*));
        examples[n_real_examples].inputs[0] = old_examples[i].inputs[j];
      } else
        examples[n_real_examples].inputs = NULL;
    }

    // Free only the old pointer tables; the frame buffers they pointed to now
    // belong to the new single-frame examples.
    if(old_examples[i].seqtargets) {
      free(old_examples[i].seqtargets);
      old_examples[i].seqtargets = NULL;
    }

    if(old_examples[i].n_alignments>0) {
      free(old_examples[i].alignment);
      free(old_examples[i].alignment_phoneme);
      old_examples[i].alignment = NULL;
      old_examples[i].alignment_phoneme = NULL;
    }

    if(old_examples[i].observations) {
      free(old_examples[i].observations);
      old_examples[i].observations = NULL;
    }

    if(old_examples[i].inputs) {
      free(old_examples[i].inputs);
      old_examples[i].inputs = NULL;
    }

    if(old_examples[i].selected_frames) {
      free(old_examples[i].selected_frames);
      old_examples[i].selected_frames = NULL;
    }
    // The name string is now referenced by the new examples.
    old_examples[i].name = NULL;
  }
  free(old_examples);
  n_examples = n_real_examples;

  // Back to the identity selection over the new examples.
  selected_examples = (int *)xrealloc(selected_examples,sizeof(int)*n_examples);
  for(int i = 0; i < n_examples; i++)
    selected_examples[i] = i;

}

// Prints a human-readable dump of the dataset: the `kind` string when set,
// the example / observation / input counts, then for every example its frame
// count followed by one line per frame listing the input values and the
// observation values.
void SeqDataSet::display()
{
  if (kind)
    message("kind: %s", kind);
  print("n examples: %d\n", n_real_examples);
  print("n observations: %d\n", n_observations);
  print("n inputs: %d\n", n_inputs);

  for(int i = 0;i < n_real_examples;i++)
  {
    print("n frames: %d in example: %d\n", examples[i].n_frames, i);
    for(int j = 0;j < examples[i].n_frames;j++)
    {
      if (n_inputs>0) {
        print("inputs ");
        // Trailing space keeps consecutive values (and the following "obs"
        // label) apart, exactly as done for the observations below.
        for(int k = 0;k < n_inputs;k++)
          print("%g ", examples[i].inputs[j][k]);
      }
      if (n_observations>0) {
        print("obs ");
        for(int k = 0;k < n_observations;k++)
          print("%g ", examples[i].observations[j][k]);
      }
      print("\n");
    }
  }  
}

// Normalizes inputs and/or observations by mean and standard deviation,
// depending on the "normalize inputs" / "normalize observations" options.
//
// For each enabled kind of data: the statistics buffers are allocated on
// first use, the frame vectors of all real examples are gathered and handed
// to MSTDVNormalize() (defined elsewhere; expected to fill the mean and stdv
// arrays), and every component d of every frame is replaced by
// (value - mean[d]) / stdv[d].
//
// The number of frames is counted here over the real examples, the same set
// the gather loop walks. It no longer relies on tot_n_frames, which is
// counted over the current example selection and can disagree with it.
// NOTE(review): a zero standard deviation yields a division by zero.
void SeqDataSet::normalize()
{
  if(norm_inputs) {
    if(!mean_in) {
      mean_in = (real*)xalloc(n_inputs * sizeof(real));
      stdv_in = (real*)xalloc(n_inputs * sizeof(real));
    }

    int n_all_frames = 0;
    for(int ex = 0;ex < n_real_examples;ex++)
      n_all_frames += examples[ex].n_frames;

    // Flat view over every input frame of every real example.
    real** all_inputs = (real**)xalloc(n_all_frames * sizeof(real*));

    int k=0;
    for(int ex = 0;ex < n_real_examples;ex++)
      for(int i = 0;i < examples[ex].n_frames;i++,k++)
        all_inputs[k] = examples[ex].inputs[i];

    MSTDVNormalize(all_inputs, mean_in, stdv_in, n_all_frames, n_inputs);

    for(int i = 0;i < n_real_examples;i++)
      for (int j=0;j<examples[i].n_frames;j++)
        for(int d = 0;d < n_inputs;d++)
          examples[i].inputs[j][d] = (examples[i].inputs[j][d]-mean_in[d])/stdv_in[d];
    free(all_inputs);
  }

  if(norm_observations) {
    if(!mean_ob) {
      mean_ob = (real*)xalloc(n_observations * sizeof(real));
      stdv_ob = (real*)xalloc(n_observations * sizeof(real));
    }

    int n_all_frames = 0;
    for(int ex = 0;ex < n_real_examples;ex++)
      n_all_frames += examples[ex].n_frames;

    // Flat view over every observation frame of every real example.
    real** all_observations = (real**)xalloc(n_all_frames*sizeof(real*));

    int k=0;
    for(int ex = 0;ex < n_real_examples;ex++)
      for(int i = 0;i < examples[ex].n_frames;i++,k++)
        all_observations[k] = examples[ex].observations[i];

    MSTDVNormalize(all_observations, mean_ob, stdv_ob, n_all_frames, n_observations);

    for(int i = 0;i < n_real_examples;i++)
      for (int j=0;j<examples[i].n_frames;j++)
        for(int d = 0;d < n_observations;d++)
          examples[i].observations[j][d] = (examples[i].observations[j][d]-mean_ob[d])/stdv_ob[d];
    free(all_observations);
  }
}

// Normalizes this dataset with the statistics held by another dataset.
//
//   data_norm : dataset whose normalization flags decide what is processed
//               and whose mean/stdv arrays are applied
//
// For each kind of data enabled in `data_norm`, every component d becomes
// (value - mean[d]) / stdv[d]. When the vector sizes of the two datasets
// differ, nothing is changed for that kind of data and a warning is issued.
void SeqDataSet::normalizeUsingDataSet(SeqDataSet *data_norm)
{
  if(data_norm->norm_inputs) {
    if(n_inputs != data_norm->n_inputs) {
      warning("SeqDataSet: the normalization machine has not the good input size");
    } else {
      const real* mu = data_norm->mean_in;
      const real* sigma = data_norm->stdv_in;
      for(int e = 0;e < n_real_examples;e++) {
        SeqExample& ex = examples[e];
        for(int f = 0;f < ex.n_frames;f++)
          for(int d = 0;d < n_inputs;d++) {
            real& value = ex.inputs[f][d];
            value = (value-mu[d])/sigma[d];
          }
      }
    }
  }

  if(data_norm->norm_observations) {
    if(n_observations != data_norm->n_observations) {
      warning("SeqDataSet: the normalization machine has not the good observation size");
    } else {
      const real* mu = data_norm->mean_ob;
      const real* sigma = data_norm->stdv_ob;
      for(int e = 0;e < n_real_examples;e++) {
        SeqExample& ex = examples[e];
        for(int f = 0;f < ex.n_frames;f++)
          for(int d = 0;d < n_observations;d++) {
            real& value = ex.observations[f][d];
            value = (value-mu[d])/sigma[d];
          }
      }
    }
  }
}

// Intentionally empty: this class defines no generic way of reading targets,
// so `file` is ignored here.
// NOTE(review): presumably overridden by format-specific subclasses -- confirm
// in the class declaration.
void SeqDataSet::readTargets(char* file)
{
  // nothing generic...
}

// Intentionally empty: this class defines no generic way of reading
// alignments, so both `file` and `needs_all_examples` are ignored here.
// NOTE(review): presumably overridden by format-specific subclasses -- confirm
// in the class declaration.
void SeqDataSet::readAlignments(char* file, bool needs_all_examples)
{
  // nothing generic...
}

void SeqDataSet::saveFILE(FILE *file)
{
  if(norm_inputs)
  {
    xfwrite(mean_in, sizeof(real), n_inputs, file);
    xfwrite(stdv_in, sizeof(real), n_inputs, file);
  }

  if(norm_observations)
  {
    xfwrite(mean_ob, sizeof(real), n_observations, file);
    xfwrite(stdv_ob, sizeof(real), n_observations, file);
  }
}

void SeqDataSet::loadFILE(FILE *file)
{
  if(norm_inputs) {
    for(int i = 0;i < n_real_examples;i++)
      for (int j=0;j<examples[i].n_frames;j++)
        for(int d = 0;d < n_inputs;d++)
          examples[i].inputs[j][d] = examples[i].inputs[j][d]*stdv_in[d]+mean_in[d];

    xfread(mean_in, sizeof(real), n_inputs, file);
    xfread(stdv_in, sizeof(real), n_inputs, file);

    for(int i = 0;i < n_real_examples;i++)
      for (int j=0;j<examples[i].n_frames;j++)
        for(int d = 0;d < n_inputs;d++)
          examples[i].inputs[j][d] = (examples[i].inputs[j][d]-mean_in[d])/stdv_in[d];
  }

  if(norm_observations) {
    for(int i = 0;i < n_real_examples;i++)
      for (int j=0;j<examples[i].n_frames;j++)
        for(int d = 0;d < n_observations;d++)
          examples[i].observations[j][d] = examples[i].observations[j][d]*stdv_ob[d]+mean_ob[d];

    xfread(mean_ob, sizeof(real), n_observations, file);
    xfread(stdv_ob, sizeof(real), n_observations, file);

    for(int i = 0;i < n_real_examples;i++)
      for (int j=0;j<examples[i].n_frames;j++)
        for(int d = 0;d < n_observations;d++)
          examples[i].observations[j][d] = (examples[i].observations[j][d]-mean_ob[d])/stdv_ob[d];
  }
}


}

