// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "HMM.h"
#include "log_add.h"

namespace Torch {

// Stores the model description; no memory is allocated here (see allocateMemory()).
//   n_states_          : number of states
//   states_            : one Distribution per state; entry 1 is dereferenced here,
//                        so at least two entries must be valid
//   prior_transitions_ : value the transition accumulators are reset to
//   data_              : optional dataset, used by allocateMemory() to size buffers
//   transitions_       : initial transition values, copied (as logs) by reset()
//   unique_states_, n_unique_states_ : optional list of distinct emission
//                        distributions; when absent, `states_` is used instead
HMM::HMM(int n_states_, Distribution **states_, real prior_transitions_,SeqDataSet* data_,real** transitions_, Distribution** unique_states_, int n_unique_states_) : Distribution()
{
  // topology
  states = states_;
  n_states = n_states_;
  transitions = transitions_;
  prior_transitions = prior_transitions_;
  data = data_;

  // the frame sizes are taken from state 1: state 0 is never asked for an
  // emission anywhere in this file
  Distribution* reference_state = states[1];
  n_observations = reference_state->n_observations;
  n_inputs = reference_state->n_inputs;

  // fall back on the per-state array when no separate list of unique
  // distributions was supplied
  const bool has_unique_list = (n_unique_states_ > 0 && unique_states_);
  if (has_unique_list) {
    unique_states = unique_states_;
    n_unique_states = n_unique_states_;
  } else {
    unique_states = states_;
    n_unique_states = n_states_;
  }

  // parameter lists are built later
  params = NULL;
  der_params = NULL;
}


void HMM::loadFILE(FILE *file)
{
  // first the transitions
  xfread(params->ptr, sizeof(real), params->n, file);
  for (int i=0;i<n_states;i++)
    log_transitions[i] = ((real*)params->ptr) + i*n_states;
  // then the emissions
  for (int i=1;i<n_unique_states-1;i++) {
    unique_states[i]->loadFILE(file);
  }
}

// Writes the model to `file` in the layout loadFILE() reads back: the first
// parameter node (the transitions), then every unique emission distribution
// except the first and last.
void HMM::saveFILE(FILE *file)
{
  // transitions: raw dump of the first node of the parameter list
  xfwrite(params->ptr, sizeof(real), params->n, file);

  // emissions: each distribution writes its own section
  int u = 1;
  while (u < n_unique_states-1) {
    unique_states[u]->saveFILE(file);
    u++;
  }
}

void HMM::allocateMemory()
{
  n_params = numberOfParams();
  log_transitions = (real**)xalloc(sizeof(real*)*n_states);
  dlog_transitions = (real**)xalloc(sizeof(real*)*n_states);
  transitions_acc = (real**)xalloc(sizeof(real*)*n_states);
  addToList(&params,n_states*n_states,(real*)xalloc(sizeof(real)*n_states*n_states));
  addToList(&der_params,n_states*n_states,(real*)xalloc(sizeof(real)*n_states*n_states));
  addToList(&outputs,n_outputs,(real*)xalloc(sizeof(real)*n_outputs));
  for (int i=1;i<n_unique_states-1;i++) {
    addToList(&params,unique_states[i]->params);
    addToList(&der_params,unique_states[i]->der_params);
  }
  for (int i=0;i<n_states;i++) {
    transitions_acc[i] = (real*)xalloc(sizeof(real)*n_states);;
    log_transitions[i] = ((real*)params->ptr) + i*n_states;
    dlog_transitions[i] = ((real*)der_params->ptr) + i*n_states;
  }
  // prepare the maximum memory
  // first find max_n_frames;
  max_n_frames = 3;
  if (data) {
    for (int i=0;i<data->n_examples;i++) {
      data->setExample(i);
      if (data->examples[data->current_example].n_real_frames+2 > max_n_frames)
        max_n_frames = data->examples[data->current_example].n_real_frames+2;
    }
  }
  log_probabilities_s = (real**)xalloc(sizeof(real*)*max_n_frames);
  log_alpha = (real**)xalloc(sizeof(real*)*max_n_frames);
  log_beta = (real**)xalloc(sizeof(real*)*max_n_frames);
  arg_viterbi = (int**)xalloc(sizeof(int*)*max_n_frames);
  viterbi_sequence = (int*)xalloc(sizeof(int)*max_n_frames);
  for (int i=0;i<max_n_frames;i++) {
    log_probabilities_s[i] = (real*)xalloc(sizeof(real)*n_states);
    log_alpha[i] = (real*)xalloc(sizeof(real)*n_states);;
    log_beta[i] = (real*)xalloc(sizeof(real)*n_states);;
    arg_viterbi[i] = (int*)xalloc(sizeof(int)*n_states);;
  }
}

// Releases everything allocateMemory() (and the buffer growth done in the
// sequence initializers) obtained. Does nothing when is_free is already set.
void HMM::freeMemory()
{
  if (!is_free) {
    is_free = true;

    freeList(&outputs,true);

    // supposition: the first node of each list is the local transition
    // block; it is released by hand
    free(params->ptr);
    free(der_params->ptr);
    // the remaining nodes belong to the emission distributions
    // NOTE(review): presumably the `false` flag keeps freeList from freeing
    // the node payloads - confirm against freeList
    freeList(&params,false);
    freeList(&der_params,false);

    // per-state accumulator rows
    for (int s=0;s<n_states;s++)
      free(transitions_acc[s]);

    // per-frame rows
    for (int t=0;t<max_n_frames;t++) {
      free(log_probabilities_s[t]);
      free(log_alpha[t]);
      free(log_beta[t]);
      free(arg_viterbi[t]);
    }

    // row-pointer tables (log/dlog rows alias the list nodes freed above)
    free(log_transitions);
    free(dlog_transitions);
    free(transitions_acc);

    // frame-pointer tables
    free(log_alpha);
    free(log_beta);
    free(arg_viterbi);
    free(viterbi_sequence);
    free(log_probabilities_s);
  }
}

// Returns the total number of parameters: the n_states x n_states transition
// matrix plus the parameters of every unique emission distribution except
// the first and the last.
int HMM::numberOfParams()
{
  int total = n_states * n_states;
  int u = 1;
  while (u < n_unique_states-1) {
    total += unique_states[u]->numberOfParams();
    u++;
  }
  return total;
}

void HMM::reset()
{
  // the emission distributions
  for (int i=1;i<n_unique_states-1;i++)
    unique_states[i]->reset();
  // for the transitions, re-initialize to initial values given in constructor
  for (int i=0;i<n_states;i++) {
    real *p = transitions[i];
    real *lp = log_transitions[i];
    for (int j=0;j<n_states;j++,lp++,p++) {
      if (*p > 0)
        *lp = log(*p);
      else
        *lp = LOG_ZERO;
    }
  }
}

void HMM::printTransitions(bool real_values, bool transitions_only)
{
  print("transitions: %d x %d\n",n_states,n_states);
  for (int i=0;i<n_states;i++) {
    for (int j=0;j<n_states;j++) {
      if (transitions_only) {
        if (log_transitions[j][i] != LOG_ZERO) {
          print("%d -> %d = %f\n",i,j,exp(log_transitions[j][i]));
        }
      } else if (real_values) {
        print("%f ",exp(log_transitions[j][i]));
      } else {
        print("%d ",(log_transitions[j][i] != LOG_ZERO));
      }
    }
    print("\n");
  }
}

// Forward recursion in the log domain. Fills log_alpha[0..n_frames+1] from
// log_transitions and log_probabilities_s (which must already be computed).
// Frame 0 is the start slot (all mass on state 0), frame n_frames+1 the end
// slot (only the last state can be reached).
void HMM::logAlpha(SeqExample* ex)
{
  const int last = n_states-1;
  const int n_frames = ex->n_frames;

  // start slot
  real* start_row = log_alpha[0];
  start_row[0] = LOG_ONE;
  for (int s=1;s<n_states;s++)
    start_row[s] = LOG_ZERO;

  // emitting frames: only the inner states 1..last-1 can be occupied
  for (int t=1;t<=n_frames;t++) {
    real* cur = log_alpha[t];
    real* prev = log_alpha[t-1];
    real* emit = log_probabilities_s[t];
    cur[0] = LOG_ZERO;
    cur[last] = LOG_ZERO;
    for (int to=1;to<last;to++) {
      real* into = log_transitions[to];
      real sum = LOG_ZERO;
      for (int from=0;from<last;from++) {
        if (into[from] == LOG_ZERO)
          continue;
        sum = log_add(sum, into[from] + emit[to] + prev[from]);
      }
      cur[to] = sum;
    }
  }

  // end slot: transitions into the last state, no emission
  real* end_row = log_alpha[n_frames+1];
  for (int s=0;s<n_states;s++)
    end_row[s] = LOG_ZERO;
  real* before_end = log_alpha[n_frames];
  real* into_last = log_transitions[last];
  for (int from=0;from<last;from++) {
    if (into_last[from] == LOG_ZERO)
      continue;
    end_row[last] = log_add(end_row[last], into_last[from] + before_end[from]);
  }
}

// Backward recursion in the log domain. Fills log_beta[0..n_frames+1].
// Side effect: also overwrites log_probabilities_s[n_frames+1] so that the
// end slot behaves like a frame where only the last state "emits" (LOG_ONE).
void HMM::logBeta(SeqExample* ex)
{
  const int last = n_states-1;
  const int end_frame = ex->n_frames+1;

  // end slot: only the last state is possible
  real* beta_end = log_beta[end_frame];
  real* emit_end = log_probabilities_s[end_frame];
  for (int s=0;s<last;s++) {
    beta_end[s] = LOG_ZERO;
    emit_end[s] = LOG_ZERO;
  }
  beta_end[last] = LOG_ONE;
  emit_end[last] = LOG_ONE;

  // walk back to frame 0, combining transition, next emission and next beta
  for (int t=ex->n_frames;t>=0;t--) {
    real* cur = log_beta[t];
    real* next = log_beta[t+1];
    real* emit_next = log_probabilities_s[t+1];
    cur[last] = LOG_ZERO;
    for (int from=0;from<last;from++) {
      real sum = LOG_ZERO;
      for (int to=1;to<n_states;to++) {
        const real log_t = log_transitions[to][from];
        if (log_t == LOG_ZERO)
          continue;
        sum = log_add(sum, log_t + emit_next[to] + next[to]);
      }
      cur[from] = sum;
    }
  }
}

// Viterbi recursion in the log domain. Reuses log_alpha to hold the best
// path scores, records the best predecessor of every state in arg_viterbi
// and finally writes the best state sequence into viterbi_sequence
// (indices 0..n_frames+1).
// A back-pointer is only written when some predecessor scores strictly
// above LOG_ZERO: entries of states that cannot be reached keep whatever
// value they held before the call.
void HMM::logViterbi(SeqExample* ex)
{
  const int last = n_states-1;
  const int n_frames = ex->n_frames;

  // start slot: the path begins in state 0
  real* start_row = log_alpha[0];
  start_row[0] = LOG_ONE;
  for (int s=1;s<n_states;s++)
    start_row[s] = LOG_ZERO;

  // emitting frames
  for (int t=1;t<=n_frames;t++) {
    real* cur = log_alpha[t];
    real* prev = log_alpha[t-1];
    real* emit = log_probabilities_s[t];
    int* back = arg_viterbi[t];
    cur[0] = LOG_ZERO;
    cur[last] = LOG_ZERO;
    for (int to=1;to<last;to++) {
      real* into = log_transitions[to];
      real best = LOG_ZERO;
      for (int from=0;from<last;from++) {
        if (into[from] == LOG_ZERO)
          continue;
        real candidate = into[from] + emit[to] + prev[from];
        if (best < candidate) {
          best = candidate;
          back[to] = from;
        }
      }
      cur[to] = best;
    }
  }

  // end slot: best transition into the last state
  // NOTE(review): this loop starts at state 1 whereas the matching loop of
  // logAlpha() starts at state 0 - confirm the difference is intended
  const int end_frame = n_frames+1;
  real* end_row = log_alpha[end_frame];
  for (int s=0;s<n_states;s++)
    end_row[s] = LOG_ZERO;
  real* before_end = log_alpha[end_frame-1];
  real* into_last = log_transitions[last];
  for (int from=1;from<last;from++) {
    if (into_last[from] == LOG_ZERO)
      continue;
    real candidate = into_last[from] + before_end[from];
    if (end_row[last] < candidate) {
      end_row[last] = candidate;
      arg_viterbi[end_frame][last] = from;
    }
  }

  // follow the back-pointers from the last state down to frame 0
  viterbi_sequence[end_frame] = last;
  for (int t=n_frames;t>=0;t--) {
    const int next_state = viterbi_sequence[t+1];
    viterbi_sequence[t] = arg_viterbi[t+1][next_state];
  }
}

// Fills log_probabilities_s[f+1][i] with the emission log-probability of
// frame f under state i. The first and last states get LOG_ZERO on every frame.
void HMM::logProbabilities(List *inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  const int last = n_states-1;
  real* in = NULL;
  real* obs = NULL;
  for (int f=0;f<ex->n_frames;f++) {
    // inputs and observations are both optional
    if (ex->inputs)
      in = ex->inputs[f];
    if (ex->observations)
      obs = ex->observations[f];

    // row f+1: row 0 is left to the start slot of the recursions
    real* row = log_probabilities_s[f+1];
    row[0] = LOG_ZERO;
    for (int s=1;s<last;s++)
      row[s] = states[s]->frameLogProbability(obs,in,f);
    row[last] = LOG_ZERO;
  }
}

// Computes, stores in log_probability and returns the log-likelihood of the
// sequence: the forward value of the last state in the end slot.
real HMM::logProbability(List *inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;

  // emissions first, then the forward pass that consumes them
  logProbabilities(inputs);
  logAlpha(ex);

  const int end_frame = ex->n_frames+1;
  const int last = n_states-1;
  log_probability = log_alpha[end_frame][last];
  return log_probability;
}

// Computes, stores in log_probability and returns the score of the best
// state path: the Viterbi value of the last state in the end slot.
real HMM::viterbiLogProbability(List *inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;

  // emissions first, then the Viterbi pass that consumes them
  logProbabilities(inputs);
  logViterbi(ex);

  const int end_frame = ex->n_frames+1;
  const int last = n_states-1;
  log_probability = log_alpha[end_frame][last];
  return log_probability;
}

// Per-sequence setup for EM: grows the per-frame buffers when the sequence
// needs more than max_n_frames slots (n_real_frames+2), then forwards the
// call to the emission distributions.
void HMM::eMSequenceInitialize(List* inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;

  const int needed = ex->n_real_frames+2;
  if (needed > max_n_frames) {
    const int first_new = max_n_frames;
    max_n_frames = needed;

    // enlarge the frame-pointer tables...
    log_probabilities_s = (real**)xrealloc(log_probabilities_s,sizeof(real*)*max_n_frames);
    log_alpha = (real**)xrealloc(log_alpha,sizeof(real*)*max_n_frames);
    log_beta = (real**)xrealloc(log_beta,sizeof(real*)*max_n_frames);
    arg_viterbi = (int**)xrealloc(arg_viterbi,sizeof(int*)*max_n_frames);
    viterbi_sequence = (int*)xrealloc(viterbi_sequence,sizeof(int)*max_n_frames);

    // ...and allocate a row for each of the new frames only
    for (int t=first_new;t<max_n_frames;t++) {
      log_probabilities_s[t] = (real*)xalloc(sizeof(real)*n_states);
      log_alpha[t] = (real*)xalloc(sizeof(real)*n_states);
      log_beta[t] = (real*)xalloc(sizeof(real)*n_states);
      arg_viterbi[t] = (int*)xalloc(sizeof(int)*n_states);
    }
  }

  // emissions (first and last unique states excluded)
  for (int u=1;u<n_unique_states-1;u++)
    unique_states[u]->eMSequenceInitialize(inputs);
}

// Per-sequence setup for gradient training: grows the per-frame buffers when
// the sequence needs more than max_n_frames slots (n_real_frames+2), forwards
// the call to the emission distributions and clears the transition derivatives.
void HMM::sequenceInitialize(List* inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;

  const int needed = ex->n_real_frames+2;
  if (needed > max_n_frames) {
    const int first_new = max_n_frames;
    max_n_frames = needed;

    // enlarge the frame-pointer tables...
    log_probabilities_s = (real**)xrealloc(log_probabilities_s,sizeof(real*)*max_n_frames);
    log_alpha = (real**)xrealloc(log_alpha,sizeof(real*)*max_n_frames);
    log_beta = (real**)xrealloc(log_beta,sizeof(real*)*max_n_frames);
    arg_viterbi = (int**)xrealloc(arg_viterbi,sizeof(int*)*max_n_frames);
    viterbi_sequence = (int*)xrealloc(viterbi_sequence,sizeof(int)*max_n_frames);

    // ...and allocate a row for each of the new frames only
    for (int t=first_new;t<max_n_frames;t++) {
      log_probabilities_s[t] = (real*)xalloc(sizeof(real)*n_states);
      log_alpha[t] = (real*)xalloc(sizeof(real)*n_states);
      log_beta[t] = (real*)xalloc(sizeof(real)*n_states);
      arg_viterbi[t] = (int*)xalloc(sizeof(int)*n_states);
    }
  }

  // emissions (first and last unique states excluded)
  for (int u=1;u<n_unique_states-1;u++)
    unique_states[u]->sequenceInitialize(inputs);

  // start the sequence with a zero gradient on every transition
  for (int r=0;r<n_states;r++) {
    real* drow = dlog_transitions[r];
    for (int c=0;c<n_states;c++)
      drow[c] = 0;
  }
}

void HMM::eMIterInitialize()
{
  for (int i=1;i<n_unique_states-1;i++)
    unique_states[i]->eMIterInitialize();
  for (int i=0;i<n_states;i++)
    for (int j=0;j<n_states;j++)
      transitions_acc[i][j] = prior_transitions;
}

void HMM::iterInitialize()
{
  for (int i=1;i<n_unique_states-1;i++)
    unique_states[i]->iterInitialize();
  for (int i=0;i<n_states;i++)
    for (int j=0;j<n_states;j++)
      transitions_acc[i][j] = prior_transitions;
}

// E-step accumulation for one sequence, weighted by exp(log_posterior).
// Relies on log_alpha, log_probability and the emission log-probabilities
// already being available for this sequence; log_beta is computed here.
// State posteriors are handed to the emission distributions, transition
// posteriors are added to transitions_acc[destination][source].
void HMM::eMAccPosteriors(List *inputs, real log_posterior)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  const int last = n_states-1;
  real* in = NULL;
  real* obs = NULL;

  // backward pass
  logBeta(ex);

  // emitting frames: frame f of the data lives in slot f+1 of alpha/beta
  for (int f=0;f<ex->n_frames;f++) {
    if (ex->inputs)
      in = ex->inputs[f];
    if (ex->observations)
      obs = ex->observations[f];

    real* alpha_prev = log_alpha[f];
    real* alpha_cur = log_alpha[f+1];
    real* beta_cur = log_beta[f+1];

    for (int to=1;to<last;to++) {
      // posterior of being in state `to` at this frame
      real log_state_posterior = log_posterior + alpha_cur[to] +
        beta_cur[to] - log_probability;
      real log_emission = states[to]->log_probabilities[f];
      states[to]->frameEMAccPosteriors(obs,log_state_posterior,in,f);

      // posterior of each existing transition from -> to at this frame
      real* into = log_transitions[to];
      real* acc_into = transitions_acc[to];
      for (int from=0;from<n_states;from++) {
        if (into[from] == LOG_ZERO)
          continue;
        acc_into[from] += exp(log_posterior + alpha_prev[from] +
          into[from] + log_emission + beta_cur[to] - log_probability);
      }
    }
  }

  // transitions into the last state: same formula without an emission term
  const int end_from_frame = ex->n_frames;
  real* alpha_end = log_alpha[end_from_frame];
  real* beta_end = log_beta[end_from_frame+1];
  real* into_last = log_transitions[last];
  real* acc_last = transitions_acc[last];
  for (int from=0;from<n_states;from++) {
    if (into_last[from] == LOG_ZERO)
      continue;
    acc_last[from] += exp(log_posterior + alpha_end[from] +
      into_last[from] + beta_end[last] - log_probability);
  }
}

// Accumulation step of Viterbi training for one sequence, weighted by
// exp(log_posterior). Relies on viterbi_sequence and arg_viterbi having been
// filled by logViterbi() for this sequence.
//   inputs        : list whose first node points to the SeqExample
//   log_posterior : log of the weight given to this sequence
// Every frame is credited to the state of the best path, and every
// transition taken by the best path (including the one into the last state)
// is counted in transitions_acc[destination][source].
void HMM::viterbiAccPosteriors(List *inputs, real log_posterior)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  real* in = NULL;
  real* obs = NULL;

  // transitions_acc holds linear-domain counts (it is initialised with
  // prior_transitions and log()-ed in eMUpdate()), so the weight added to it
  // must be exp(log_posterior), not the log itself
  const real posterior = exp(log_posterior);

  // f == n_frames is the non-emitting step into the last state: it has no
  // observation but its transition still has to be counted, as is done in
  // eMAccPosteriors() and viterbiBackward()
  for (int f=0;f<=ex->n_frames;f++) {
    int i = viterbi_sequence[f+1];
    if (f<ex->n_frames) {
      if (ex->inputs)
        in = ex->inputs[f];
      if (ex->observations)
        obs = ex->observations[f];
      // the emission distributions expect the weight in the log domain
      states[i]->frameEMAccPosteriors(obs,log_posterior,in,f);
    }
    // predecessor of state i on the best path
    int j = arg_viterbi[f+1][i];
    transitions_acc[i][j] += posterior;
  }
}

// M-step: updates the emission distributions, then re-estimates the
// outgoing transitions of every state but the last by normalising the
// accumulators over the existing transitions (log_transitions and
// transitions_acc are indexed [destination][source]).
// A source state whose accumulators sum to zero (or less) keeps its current
// transitions: normalising would compute log(0) - log(0) and store NaN.
void HMM::eMUpdate()
{
  // first the states
  for (int i=1;i<n_unique_states-1;i++) {
    unique_states[i]->eMUpdate();
  }
  // then the transitions;
  for (int i=0;i<n_states-1;i++) {
    // total mass leaving state i through the existing transitions
    real sum_trans_acc = 0;
    for (int j=0;j<n_states;j++) {
      if (log_transitions[j][i] == LOG_ZERO)
        continue;
      sum_trans_acc += transitions_acc[j][i];
    }
    // nothing was accumulated for this state: leave it untouched
    if (!(sum_trans_acc > 0))
      continue;
    real log_sum = log(sum_trans_acc);
    for (int j=0;j<n_states;j++) {
      if (log_transitions[j][i] == LOG_ZERO)
        continue;
      // NOTE(review): a single zero accumulator still yields log(0) here
      log_transitions[j][i] = log(transitions_acc[j][i]) - log_sum;
    }
  }
}

// Gradient accumulation for one sequence, scaled by -(*alpha).
// Relies on log_alpha, log_probability and the emission log-probabilities
// already being available for this sequence; log_beta is computed here.
// State posteriors are passed to the emission distributions through
// frameBackward(); transition posteriors are accumulated in
// dlog_transitions[destination][source].
void HMM::backward(List *inputs, real *alpha)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  const int last = n_states-1;
  real* in = NULL;
  real* obs = NULL;

  // backward pass
  logBeta(ex);

  // emitting frames: frame f of the data lives in slot f+1 of alpha/beta
  for (int f=0;f<ex->n_frames;f++) {
    if (ex->inputs)
      in = ex->inputs[f];
    if (ex->observations)
      obs = ex->observations[f];

    real* alpha_prev = log_alpha[f];
    real* alpha_cur = log_alpha[f+1];
    real* beta_cur = log_beta[f+1];

    for (int to=1;to<last;to++) {
      // scaled posterior of being in state `to` at this frame
      real state_grad = -(*alpha) * exp(alpha_cur[to] +
        beta_cur[to] - log_probability);
      real log_emission = states[to]->log_probabilities[f];
      states[to]->frameBackward(obs,&state_grad,in,f);

      real* into = log_transitions[to];
      for (int from=0;from<n_states;from++) {
        if (into[from] == LOG_ZERO)
          continue;
        // scaled posterior of the transition from -> to at this frame
        real trans_grad = -(*alpha) * exp(alpha_prev[from] +
          into[from] + log_emission + beta_cur[to] - log_probability);
        dlog_transitions[to][from] += trans_grad;
        // every existing transition leaving `from` also receives
        // -trans_grad times its own probability
        for (int other=0;other<n_states;other++) {
          const real log_t = log_transitions[other][from];
          if (log_t == LOG_ZERO)
            continue;
          dlog_transitions[other][from] -= trans_grad * exp(log_t);
        }
      }
    }
  }

  // transitions into the last state: same update without an emission term
  const int end_from_frame = ex->n_frames;
  real* alpha_end = log_alpha[end_from_frame];
  real* beta_end = log_beta[end_from_frame+1];
  real* into_last = log_transitions[last];
  for (int from=0;from<n_states;from++) {
    if (into_last[from] == LOG_ZERO)
      continue;
    real trans_grad = -(*alpha) * exp(alpha_end[from] +
      into_last[from] + beta_end[last] - log_probability);
    dlog_transitions[last][from] += trans_grad;
    for (int other=0;other<n_states;other++) {
      const real log_t = log_transitions[other][from];
      if (log_t == LOG_ZERO)
        continue;
      dlog_transitions[other][from] -= trans_grad * exp(log_t);
    }
  }
}

// Gradient accumulation along the best path found by logViterbi().
// For every step of the path (the last one being the non-emitting step into
// the last state) the emission distribution of the visited state receives
// `alpha`, and dlog_transitions[destination][source] is updated for the
// transition that was taken.
void HMM::viterbiBackward(List *inputs, real *alpha)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  real* in = NULL;
  real* obs = NULL;

  for (int f=0;f<=ex->n_frames;f++) {
    const int to = viterbi_sequence[f+1];

    // only the first n_frames steps carry an observation
    const bool emitting = (f<ex->n_frames);
    if (emitting) {
      if (ex->inputs)
        in = ex->inputs[f];
      if (ex->observations)
        obs = ex->observations[f];
      states[to]->frameBackward(obs,alpha,in,f);
    }

    // transition taken by the best path to reach `to`
    const int from = arg_viterbi[f+1][to];
    dlog_transitions[to][from] -= *alpha;

    // every existing transition leaving `from` also receives *alpha times
    // its own probability
    for (int other=0;other<n_states;other++) {
      const real log_t = log_transitions[other][from];
      if (log_t == LOG_ZERO)
        continue;
      dlog_transitions[other][from] += *alpha * exp(log_t);
    }
  }
}

// Finds the best state path for the given sequence; the result is left in
// viterbi_sequence (see logViterbi()).
void HMM::decode(List* input)
{
  SeqExample* ex = (SeqExample*)input->ptr;

  // make sure the buffers are large enough, then score and decode
  eMSequenceInitialize(input);
  logProbabilities(input);
  logViterbi(ex);
}

// Releases the buffers owned by the model; freeMemory() returns immediately
// when is_free is already set.
HMM::~HMM()
{
  this->freeMemory();
}

}

