// Copyright (C) 2002 Samy Bengio (samy.bengio@idiap.ch)
//                and Bison Ravi (francois.belisle@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "MatSeqDataSet.h"
#include "string_utils.h"

namespace Torch {

// scanf conversion specifier used below when parsing `real` values from
// text files.
// NOTE(review): presumably USEDOUBLE makes `real` a double (hence "%lf")
// and a float otherwise (hence "%f") -- confirm against the definition
// of `real`.
#ifdef USEDOUBLE
#define REAL_FORMAT "%lf"
#else
#define REAL_FORMAT "%f"
#endif

// Default constructor: only records that the examples are not borrowed
// from a StdDataSet.
// NOTE(review): examples, file_names and the associated counters are not
// set here; presumably the SeqDataSet constructor initializes them --
// confirm, since the destructor calls freeMemory().
MatSeqDataSet::MatSeqDataSet()
  : SeqDataSet()
{
  this->examples_from_stddataset = false;
}

// Builds a dataset made of a single sequence read from one matrix file.
//
// file            : path of the matrix file (header: number of rows, number
//                   of columns, followed by the values).
// n_inputs_       : number of input columns per frame, or -1 to deduce it.
// n_observations_ : number of observation columns per frame, or -1 to deduce it.
// n_targets_      : number of target columns per frame, or -1 to deduce it.
// bin             : true if the file is binary, false if it is text.
// max_load        : forwarded to readMat(); when > 0 it caps the number of
//                   frames that are loaded.
//
// At most one of the three sizes may be -1; it is then computed as the
// column count of the file minus the two other sizes.
// NOTE(review): error() is assumed not to return, as everywhere else in
// this file.
MatSeqDataSet::MatSeqDataSet(char* file, int n_inputs_ , int n_observations_,  
                              int n_targets_, bool bin, int max_load)
{
  examples_from_stddataset = false;
  n_real_examples = 0;

  // Count how many sizes the caller asked us to deduce.
  int n_unknown = 0;
  if(n_inputs_ == -1)
    n_unknown++;
  if(n_observations_ == -1)
    n_unknown++;
  if(n_targets_ == -1)
    n_unknown++;

  if(n_unknown > 1)
    error(" only one field at a time can have value -1 ");

  n_inputs = n_inputs_;
  n_observations = n_observations_;
  n_targets = n_targets_;

  if(n_unknown == 1)
  {
    // Read the header only once, whatever the field to deduce.
    const int n_columns = columns(file, bin);
    int deduced;

    if(n_inputs_ == -1)
      deduced = n_inputs = n_columns - (n_observations_ + n_targets_);
    else if(n_observations_ == -1)
      deduced = n_observations = n_columns - (n_inputs_ + n_targets_);
    else
      deduced = n_targets = n_columns - (n_inputs_ + n_observations_);

    // A negative size would silently be treated as "no such field" by
    // readMat() and shift every value read afterwards.
    if(deduced < 0)
      error("MatSeqDataSet: file %s has only %d columns, fewer than the requested sizes", file, n_columns);
  }

  // One file == one example == one stored file name.
  examples = (SeqExample*) xalloc(sizeof(SeqExample));
  file_names = (char**) xalloc(sizeof(char*));
  n_file_names = 0;
  readMat(file, bin, max_load);
  n_real_examples++;
  n_file_names++;
}

// Builds a dataset with one sequence per file.
//
// files           : array of n_files paths of matrix files.
// n_files         : number of entries in files.
// n_inputs_       : number of input columns per frame, or -1 to deduce it.
// n_observations_ : number of observation columns per frame, or -1 to deduce it.
// n_targets_      : number of target columns per frame, or -1 to deduce it.
// bin             : true if the files are binary, false if they are text.
// max_load        : forwarded to readMat(); when > 0 it caps the number of
//                   frames loaded from each file.
//
// At most one of the three sizes may be -1; it is then computed as the
// column count of the FIRST file minus the two other sizes.
// NOTE(review): error() is assumed not to return, as everywhere else in
// this file.
MatSeqDataSet::MatSeqDataSet(char** files, int n_files, int n_inputs_ , int n_observations_,  int n_targets_, bool bin, int max_load)
{
  examples_from_stddataset = false;
  n_real_examples = 0;

  // Count how many sizes the caller asked us to deduce.
  int n_unknown = 0;
  if(n_inputs_ == -1)
    n_unknown++;
  if(n_observations_ == -1)
    n_unknown++;
  if(n_targets_ == -1)
    n_unknown++;

  if(n_unknown > 1)
    error(" only one field at a time can have value -1 ");

  n_inputs = n_inputs_;
  n_observations = n_observations_;
  n_targets = n_targets_;

  if(n_unknown == 1)
  {
    // files[0] is about to be read: make sure it exists.
    if(n_files < 1)
      error("MatSeqDataSet: at least one file is needed to deduce a size given as -1");

    //!! WARNING! This is not very robust in the sense that
    //!! it only looks at the dimension in the first file
    //!! and assumes that it's the same in the others...
    const int n_columns = columns(files[0], bin);
    int deduced;

    if(n_inputs_ == -1)
      deduced = n_inputs = n_columns - (n_observations_ + n_targets_);
    else if(n_observations_ == -1)
      deduced = n_observations = n_columns - (n_inputs_ + n_targets_);
    else
      deduced = n_targets = n_columns - (n_inputs_ + n_observations_);

    // A negative size would silently be treated as "no such field" by
    // readMat() and shift every value read afterwards.
    if(deduced < 0)
      error("MatSeqDataSet: file %s has only %d columns, fewer than the requested sizes", files[0], n_columns);
  }

  examples = (SeqExample*)xalloc(sizeof(SeqExample) * n_files);
  file_names = (char**) xalloc(sizeof(char*)*n_files);
  n_file_names = 0;

  // readMat() fills slot n_real_examples / n_file_names, so both counters
  // advance together after each file.
  for(int i = 0;i < n_files;i++)
  {
    readMat(files[i], bin, max_load);
    n_real_examples++;
    n_file_names++;
  }
}


// Wraps an existing StdDataSet: every selected example of `data` becomes a
// one-frame sequence whose storage is borrowed from `data` (only the small
// per-example pointer tables are allocated here).
//
// data            : source dataset; its all_inputs / all_targets rows are
//                   referenced, not copied.
// n_inputs_       : how many of the leading values of each input row are
//                   inputs, or a negative value to deduce it.
// n_observations_ : how many of the following values are observations, or
//                   a negative value to deduce it.
//
// The two sizes must add up to data->n_inputs; at most one may be negative.
// NOTE(review): error() is assumed not to return, as everywhere else in
// this file.
MatSeqDataSet::MatSeqDataSet(StdDataSet* data, int n_inputs_, int n_observations_)
{
  n_file_names = 0;
  file_names = NULL;
  examples_from_stddataset = true;

  // Test the two *parameters*: the member n_observations is not set yet.
  if (n_inputs_ < 0 && n_observations_ < 0)
    error("n_inputs and n_observations cannot be equal to -1 at the same time");

  if (n_inputs_ < 0) {
    n_observations = n_observations_;
    n_inputs = data->n_inputs - n_observations;
  } else if (n_observations_ < 0) {
    n_inputs = n_inputs_;
    n_observations = data->n_inputs - n_inputs;
  } else if (n_inputs_ + n_observations_ != data->n_inputs) {
    error("n_inputs (%d) + n_obs (%d) != data->n_inputs (%d)",n_inputs_,n_observations_,data->n_inputs);
  } else {
    n_inputs = n_inputs_;
    n_observations = n_observations_;
  }

  // The deduced size is negative when the given one exceeds data->n_inputs.
  if (n_inputs < 0 || n_observations < 0)
    error("MatSeqDataSet: n_inputs (%d) and n_obs (%d) do not fit in data->n_inputs (%d)",n_inputs,n_observations,data->n_inputs);

  n_targets = data->n_targets;

  n_real_examples = data->n_examples;
  examples = (SeqExample*)xalloc(data->n_examples * sizeof(SeqExample));

  for(int example = 0; example < n_real_examples; example++)
  {
    // Position of this example in the underlying storage of `data`.
    int index = data->selected_examples[example];
    examples[example].n_frames = 1;
    examples[example].n_real_frames = 1;

    if(n_targets > 0)
    {
      // Points straight into data->all_targets: nothing to free later.
      examples[example].n_seqtargets = 1;
      examples[example].seqtargets = &data->all_targets[index];
    }
    else
    {
      examples[example].n_seqtargets = 0;
      examples[example].seqtargets = NULL;
    }

    examples[example].selected_frames = NULL;
    examples[example].name = NULL;
    examples[example].n_alignments = 0;
    examples[example].alignment = NULL;
    examples[example].alignment_phoneme = NULL;
   
    if(n_inputs > 0) {
      // One-entry frame table pointing at the start of the input row.
      examples[example].inputs = (real**) xalloc(sizeof(real*));
      examples[example].inputs[0] = data->all_inputs[index];
    } else
      examples[example].inputs = NULL;

    if(n_observations > 0) {
      // Observations are stored right after the n_inputs first values.
      examples[example].observations = (real**) xalloc(sizeof(real*));
      examples[example].observations[0] = data->all_inputs[index] + n_inputs;
    } else
      examples[example].observations = NULL;

  }
}


// Destructor: everything this object allocated is released by freeMemory().
MatSeqDataSet::~MatSeqDataSet()
{
  this->freeMemory();
}

int MatSeqDataSet::columns(char* file, bool bin)
{
 FILE* f = fopen(file, "r");
  if(!f)
    error("MatSeqDataSet: file %s cannot be read", file);
  
  int n_ex;
  int n;
  
  if(!bin)
    fscanf(f, "%d %d\n", &n_ex, &n);
  else
  {
    xfread(&n_ex, sizeof(int), 1, f);
    xfread(&n, sizeof(int), 1, f);
  }
  
  fclose(f);
  
  return n;
  
}

int MatSeqDataSet::lines(char* file, bool bin)
{
  
  FILE* f = fopen(file, "r");
  if(!f)
    error("MatSeqDataSet: file %s cannot be read", file);
  
  int n_ex;
  int n;
  
  if(!bin)
    fscanf(f, "%d %d\n", &n_ex, &n);
  else
  {
    xfread(&n_ex, sizeof(int), 1, f);
    xfread(&n, sizeof(int), 1, f);
  }
  
  fclose(f);
  
  return n_ex;
  
}

// Returns the total number of rows announced by the headers of the
// n_files files listed in `files` (each one queried through the
// single-file overload of lines()).
int MatSeqDataSet::lines(char** files, int n_files, bool bin)
{
  int total = 0;
  int idx = 0;

  while(idx < n_files)
  {
    total += lines(files[idx], bin);
    ++idx;
  }

  return total;
}


void MatSeqDataSet::readMat(char* file, bool bin, int max_load)
{

  FILE* f = fopen(file, "r");
  if(!f)
    error("MatSeqDataSet: file %s cannot be read", file);
  
  int n_ex;
  int n;
  
  if(!bin)
    fscanf(f, "%d %d\n", &n_ex, &n);
  else
  {
    xfread(&n_ex, sizeof(int), 1, f);
    xfread(&n, sizeof(int), 1, f);
  }
  
  int example_n_frames = max_load > 0 && max_load < n_ex ? max_load : n_ex;
  
  examples[n_real_examples].n_frames =  example_n_frames;
    
  //allocation of memory
    
  if(n_inputs > 0) {
    examples[n_real_examples].inputs = (real**)xalloc(example_n_frames * sizeof(real*));
    for(int j = 0;j < example_n_frames;j++)
      examples[n_real_examples].inputs[j] = (real*)xalloc(n_inputs * sizeof(real));
  } else
    examples[n_real_examples].inputs = NULL;
    
  if(n_observations > 0) {
    examples[n_real_examples].observations = (real**)xalloc(example_n_frames * sizeof(real*));
    for(int j = 0;j < example_n_frames;j++)
      examples[n_real_examples].observations[j] = (real*)xalloc(n_observations * sizeof(real));
  } else
    examples[n_real_examples].observations = NULL;
    
  if(n_targets > 0) {
    examples[n_real_examples].seqtargets = (real**)xalloc(sizeof(real*) * example_n_frames);
    for(int j = 0;j < example_n_frames;j++)
      examples[n_real_examples].seqtargets[j] = (real*)xalloc(n_targets * sizeof(real));
  } else
    examples[n_real_examples].seqtargets = NULL;
    
  //reading of examples

  for(int j = 0;j < example_n_frames;j++) {
    if(!bin) {
      for(int k = 0;k < n_inputs;k++)
        fscanf(f, REAL_FORMAT, &examples[n_real_examples].inputs[j][k]);

      for(int k = 0;k < n_observations;k++)
        fscanf(f, REAL_FORMAT, &examples[n_real_examples].observations[j][k]);

      for(int k = 0;k < n_targets;k++)
        fscanf(f, REAL_FORMAT, &examples[n_real_examples].seqtargets[j][k]);

    } else {
      xfread(&examples[n_real_examples].inputs[j], sizeof(real), n_inputs, f);
      xfread(&examples[n_real_examples].observations[j], sizeof(real), n_observations, f);
      xfread(&examples[n_real_examples].seqtargets[j], sizeof(real), n_targets, f);
    }
  }
    
  examples[n_real_examples].n_real_frames = example_n_frames;
  examples[n_real_examples].selected_frames = NULL;
  examples[n_real_examples].name = NULL;
  examples[n_real_examples].n_seqtargets = n_targets > 0 ? example_n_frames : 0;
  examples[n_real_examples].n_alignments = 0;
  examples[n_real_examples].alignment = NULL;
  examples[n_real_examples].alignment_phoneme = NULL;
  
  char* start = strBaseName(file);
  file_names[n_file_names] = (char*)xalloc(sizeof(char)*(strlen(start)+1));
  strcpy(file_names[n_file_names],start);
  examples[n_real_examples].name = file_names[n_file_names];
  
  fclose(f);
}

void MatSeqDataSet::freeMemory()
{
  
  for(int example = 0;example < n_real_examples;example++)
  {
    setExample(example);
    if(!examples_from_stddataset) {
      for (int i=0;i<examples[example].n_seqtargets;i++)
        free(examples[example].seqtargets[i]);
      if (examples[example].n_seqtargets > 0) {
        free(examples[example].seqtargets);
        examples[example].seqtargets = NULL;
      }
      for(int frame = 0;frame < examples[example].n_frames;frame++) {
        if(n_inputs > 0) 
          free(examples[example].inputs[frame]);
        
        if(n_observations > 0) 
          free(examples[example].observations[frame]);
      }
    }
  
    if (examples[example].inputs) {
      free(examples[example].inputs);
      examples[example].inputs = NULL;
    }
    if (examples[example].observations) {
      free(examples[example].observations);
      examples[example].observations = NULL;
    }
    if (examples[example].n_alignments>0) {
      free(examples[example].alignment);
      free(examples[example].alignment_phoneme);
      examples[example].alignment = NULL;
      examples[example].alignment_phoneme = NULL;
    }
    
    if(examples[example].selected_frames) {
      free(examples[example].selected_frames);
      examples[example].selected_frames = NULL;
    }
  }
  
  free(examples);
  for (int i = 0;i<n_file_names;i++){
     if(file_names[i] != NULL){
	free(file_names[i]);
	file_names[i] = NULL;
     }
  }
  free(file_names);
}

}

