// Copyright (C) 2002 Johnny Mariethoz (Johnny.Mariethoz@idiap.ch)
//                and Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "HtkSeqDataSet.h"
#include "string_utils.h"
#include "IOHtk.h"

namespace Torch {

// Builds a dataset made of a single sequence read from one HTK file.
// `file` and `max_load` are passed unchanged to the IOHtk reader.
HtkSeqDataSet::HtkSeqDataSet(char* file,int max_load)
{
  // Start from an empty dataset without dictionary.
  dict = NULL;
  examples = NULL;
  n_per_frame = 125000;
  n_real_examples = 0;
  n_targets = 0;
  n_inputs = 0;
  n_frames = 0;
  n_observations = 0;
  current_frame = 0;

  // A single reader backs the single example.
  htk = (IOHtk**)xalloc(sizeof(IOHtk*));
  IOHtk* reader = new IOHtk(file,max_load);
  htk[0] = reader;
  kind = reader->kind;
  n_observations = reader->n_cols;

  // One example slot and one file-name slot, filled by prepareData().
  examples = (SeqExample*) xalloc(sizeof(SeqExample));
  file_names = (char**) xalloc(sizeof(char*));
  n_file_names = 1;
  n_real_examples = 1;

  prepareData();
}

// Builds a dataset with one sequence per HTK file.
//   files    : array of n_files file names, each handed to an IOHtk reader
//   n_files  : number of entries in files; must be at least 1
//   max_load : passed unchanged to every IOHtk reader
// The parameter kind and the number of observations per frame are taken
// from the first file.
HtkSeqDataSet::HtkSeqDataSet(char** files, int n_files, int max_load)
{
  dict = NULL;
  n_per_frame = 125000;
  n_real_examples = 0;
  examples = NULL;
  // kept in sync with the single-file constructor, which also resets it
  n_targets = 0;
  n_inputs = 0;
  n_frames = 0;
  current_frame = 0;
  n_observations = 0;

  // htk[0] is dereferenced below, so an empty list cannot be accepted
  if (!files || n_files <= 0)
    error("HtkSeqDataSet: at least one HTK file is required");

  examples = (SeqExample*)xalloc(sizeof(SeqExample) * n_files);
  file_names = (char**) xalloc(sizeof(char*)*n_files);
  htk = (IOHtk**)xalloc(sizeof(IOHtk*)*n_files);
  n_file_names = n_files;
  n_real_examples = n_files;
  for(int i = 0;i < n_files;i++)
    htk[i] = new IOHtk(files[i],max_load);
  kind = htk[0]->kind;
  n_observations = htk[0]->n_cols;
  prepareData();
}

// Destructor: all the cleanup is delegated to freeMemory().
HtkSeqDataSet::~HtkSeqDataSet()
{
  this->freeMemory();
}

void HtkSeqDataSet::write(char* dir_to_save){
	FILE* f = NULL;
	char* save_file_name;

	for(int i=0;i<n_file_names;i++){
		save_file_name = strConcat(2,dir_to_save,file_names[i]);
		if((f=fopen(save_file_name,"r")))
			error("file: %s exist!",save_file_name);
		free(save_file_name);
	}
	
	HTKhdr header;
	bool new_file = true;
  header.sampSize = n_observations * 4;
	header.sampPeriod = htk[0]->samp_period;
	header.sampKind = htk[0]->str2ParmKind(htk[0]->kind);
#ifdef USEDOUBLE
	float* temp = (float*)xalloc(n_observations*sizeof(float));
#endif
  for (int i=0;i<n_examples;i++) {
		setExample(i);
		SeqExample* ex = (SeqExample*)inputs->ptr;
		save_file_name = strConcat(2,dir_to_save,ex->name);
		if(!(f=fopen(save_file_name,"r"))){
			new_file = true;
			header.nSamples = ex->n_frames;
			f=fopen(save_file_name,"w");
			xfwrite(&header.nSamples,sizeof(long),1,f);
			xfwrite(&header.sampPeriod,sizeof(long),1,f);
			xfwrite(&header.sampSize,sizeof(short),1,f);
			xfwrite(&header.sampKind,sizeof(short),1,f);

		}else{
			fclose(f);
			f=fopen(save_file_name,"a+");
		}

		real** data = ex->observations;
		for(int j=0;j<ex->n_frames;j++){

#ifdef USEDOUBLE
			for(int k=0;k<n_observations;k++)
				temp[k] = (float)data[j][k];
			xfwrite(temp, sizeof(float), n_observations, f);
#else
			xfwrite(data[j], sizeof(real), n_observations, f);
#endif
		}
		if(!new_file){
			rewind(f);
			xfread(&header.nSamples,sizeof(long),1,f);
			header.nSamples += ex->n_frames; 
			rewind(f);
			xfwrite(&header.nSamples,sizeof(long),1,f);
		}
		new_file = false;
		free(save_file_name);
	}

#ifdef USEDOUBLE
			free(temp);
#endif
			fclose (f);
	//for(int i = 0;i < n_examples;i++)
	//	htk[i]->write(dir_to_save);
}


void HtkSeqDataSet::freeMemory()
{
  for(int i = 0;i < n_real_examples;i++)
  {
    if (examples[i].n_alignments > 0) {
      free(examples[i].alignment);
      examples[i].alignment = NULL;
    }
    if(examples[i].n_seqtargets > 0) {
      for (int j=0;j<examples[i].n_seqtargets;j++)
        free(examples[i].seqtargets[j]);
      free(examples[i].seqtargets);
      examples[i].seqtargets = NULL;
    }
    if(examples[i].n_real_frames > 0) {
      free(examples[i].observations);
    }

    if(examples[i].selected_frames) {
      free(examples[i].selected_frames);
      examples[i].selected_frames = NULL;
    }
  }
  free(examples);

  for (int i = 0;i<n_file_names;i++){
    delete htk[i];
  }
  free(htk);
  free(file_names);
}

// Looks for the example whose name matches `name`.
// Before the comparison, `name` is reduced to the part that follows the
// last '*' found at or before its fifth character from the end, and
// everything from the last '.' on (the extension) is cut off.
// If there is no such '*', the comparison starts at the beginning of
// `name`; if there is no '.', nothing is cut.
// Returns the index of the first matching example, or -1 if none matches.
// Side effect: setExample() is called on every example visited, so the
// current example of the dataset changes.
int HtkSeqDataSet::findExample(char* name)
{
  int index = -1;
  // work on a local copy, since the name is modified in place
  size_t len = strlen(name);
  char* strip_name = (char*)xalloc(sizeof(char)*(len+1));
  strcpy(strip_name,name);

  // walk back to the '*', but never past the beginning of the buffer
  char* start = strip_name + (len >= 5 ? len - 5 : 0);
  while (start > strip_name && *start != '*') start--;
  if (*start == '*')
    start++;

  // drop the extension, if any
  char* end = strrchr(strip_name,'.');
  if (end)
    *end = '\0';

  // now try to find the equivalent name in the dataset
  for (int i=0;i<n_examples;i++) {
    setExample(i);
    SeqExample* ex = (SeqExample*)inputs->ptr;
    if (!strcmp(ex->name,start)) {
      index = i;
      break;
    }
  }
  free(strip_name);
  return index;
}

// Stores the dictionary pointer. readTargets() and readAlignments() stop
// with an error when no dictionary has been set.
void HtkSeqDataSet::setDictionary(Dictionary* dict_)
{
  this->dict = dict_;
}

// Reads the word targets of the examples from a label file.
// Expected layout, as parsed here: one header line, then for each sentence
// a line holding the file name, one word per line, and a line starting
// with '.' that ends the sentence.
// Each word is looked up in the dictionary and its index is stored as one
// single-valued sequence target of the matching example. Sentences whose
// name matches no example are skipped; lines without any token are ignored.
// Stops with an error if no dictionary is set, the file cannot be opened,
// a word is unknown, or an example ends up without targets.
// NOTE(review): the index returned by findExample() is used directly on
// `examples`; this presumably assumes that example i of setExample() is
// examples[i] -- confirm if examples can be selected or reordered.
void HtkSeqDataSet::readTargets(char* file)
{
  if (!dict)
    error("cannot read targets without a dictionary");

  FILE *f=fopen(file,"r");
  if (!f)
    error("readTargets: file %s cannot be read",file);
  const int line_len=300;
  char line[300];
  char word[300];
  char name[300];
  // first line is the header: not interesting
  fgets(line,line_len,f);
  // then process one sentence at a time; testing the result of fgets ends
  // the loop on end of file as well as on a read error
  while (fgets(line,line_len,f)) {
    // first line of each sentence contains the name of the file
    if (sscanf(line,"%s",name) != 1)
      continue;
    // find the corresponding example
    int index = findExample(name);
    if (index<0) {
      // unknown example: skip its sentence up to the closing "."
      while (fgets(line,line_len,f) && line[0] != '.')
        ;
      continue;
    }
    SeqExample* ex = &examples[index];

    // the number of words is not known in advance: grow the buffer on demand
    int capacity = 16;
    int *indices = (int*)xalloc(sizeof(int)*capacity);

    // then, read the rest of the sentence until "."
    int n_targ = 0;
    while (fgets(line,line_len,f) && line[0] != '.') {
      // get the word
      if (sscanf(line,"%s",word) != 1)
        continue;
      // find it in the dictionary
      int w_index = dict->findWord(word);
      if (w_index < 0)
        error("readTargets: cannot find word %s in dictionary",word);
      if (n_targ == capacity) {
        int new_capacity = 2*capacity;
        int *grown = (int*)xalloc(sizeof(int)*new_capacity);
        for (int i=0;i<n_targ;i++)
          grown[i] = indices[i];
        free(indices);
        indices = grown;
        capacity = new_capacity;
      }
      indices[n_targ++] = w_index;
    }

    // release targets left by a previous read of the same example
    if (ex->seqtargets) {
      for (int i=0;i<ex->n_seqtargets;i++)
        free(ex->seqtargets[i]);
      free(ex->seqtargets);
    }

    // now adjust the seqtargets
    ex->n_seqtargets = n_targ;
    ex->seqtargets = (real**)xalloc(sizeof(real*)*n_targ);
    for (int i=0;i<n_targ;i++) {
      ex->seqtargets[i] = (real*)xalloc(sizeof(real));
      ex->seqtargets[i][0] = (real)indices[i];
    }
    free(indices);
  }
  fclose(f);

  // finally, check that all examples have a seqtarget
  for (int i=0;i<n_examples;i++) {
    setExample(i);
    if (!examples[current_example].seqtargets)
      error("readTargets: cannot find targets for example named %s",
        examples[current_example].name);
  }
}

void HtkSeqDataSet::setNPerFrame(int n_per_frame_)
{
  n_per_frame = n_per_frame_;
}

void HtkSeqDataSet::readAlignments(char* file, bool needs_all_examples)
{
  if (!dict)
    error("cannot read alignments without a dictionary");

  FILE *f=fopen(file,"r");
  if (!f)
    error("readAlignments: file %s cannot be read",file);
  int line_len=300;
  char line[300];
  char phoneme[300];
  // first line is the header: not intersting
  fgets(line,line_len,f);
  // then process each sentence at a time
  while (!feof(f)) {
    // first line of each sentence contains the name of the file
    fgets(line,line_len,f);
    if (feof(f))
      break;
    // find the corresponding example
    char name[300];
    sscanf(line,"%s",name);
    int index = findExample(name);
    if (index<0) {
      fgets(line,line_len,f);
      while (line[0] != '.' && !feof(f)) 
        fgets(line,line_len,f);
      continue;
    }
    SeqExample* ex = &examples[index];
    int *ali = (int*)xalloc(ex->n_frames*sizeof(int));
    int *tf = (int*)xalloc(ex->n_frames*sizeof(int));

    // then, read the rest of the sentence until "."
    int n_align = 0;
    fgets(line,line_len,f);
    while (line[0] != '.' && !feof(f)) {
      // get the alignment information
      int time_from;
      int time_to;
      sscanf(line,"%d %d %s",&time_from,&time_to,phoneme);
      // add the time_frame to the alignments
      tf[n_align] = time_to / n_per_frame;
      int p_index = dict->findPhoneme(phoneme);
      if (p_index<0)
        error("cannot find phoneme %s of alignment in dictionary",phoneme);
      ali[n_align] = p_index;
      n_align++;
      fgets(line,line_len,f);
    }
    // now create the alignment
    if (ex->n_alignments > 0) {
      free(ex->alignment);
      free(ex->alignment_phoneme);
    }
    ex->n_alignments = n_align;
    ex->alignment = (int*)xalloc(sizeof(int)*n_align);
    ex->alignment_phoneme = (int*)xalloc(sizeof(int)*n_align);
    for (int i=0;i<n_align;i++) {
      ex->alignment[i] = tf[i];
      ex->alignment_phoneme[i] = ali[i];
    }
    free(ali);
    free(tf);
  }
  // finally, check that all examples have an alignment
  if (needs_all_examples) {
    for (int i=0;i<n_examples;i++) {
      setExample(i);
      if (!examples[current_example].alignment)
        error("readAlignment: cannot find alignment for example named %s",
          examples[current_example].name);
    }
  }
  fclose(f);
}

void HtkSeqDataSet::prepareData()
{
  for (int i=0;i<n_real_examples;i++) {
    int example_n_frames = htk[i]->n_lines;
    examples[i].n_frames = example_n_frames;
    examples[i].inputs = NULL;
    examples[i].observations = (real**)xalloc(example_n_frames * sizeof(real*));
    examples[i].seqtargets = NULL;
    for(int j=0;j<example_n_frames;j++){
      examples[i].observations[j] = &htk[i]->data[j*n_observations];
    }
    examples[i].n_real_frames = example_n_frames;
    examples[i].selected_frames = NULL;
    examples[i].n_seqtargets = 0;
    examples[i].seqtargets = NULL;
    examples[i].inputs = NULL;
    examples[i].n_alignments = 0;
    examples[i].alignment = NULL;
    file_names[i] = htk[i]->file_name;
    examples[i].name = htk[i]->file_name;
  }
}

void HtkSeqDataSet::createMaskFromParam(bool* mask)
{
  htk[0]->createMaskFromParam(mask);
}

}

