// Program description handed to CmdLine::info() in main().
const char *help = "Mixture of MLP (c) Trebolloc & Co 2001\n";

#include "ConnectedMachine.h"
#include "Linear.h"
#include "FileDataSet.h"
#include "MseCriterion.h"
#include "Sigmoid.h"
#include "Tanh.h"
#include "MseMeasurer.h"
#include "TwoClassFormat.h"
#include "OneHotClassFormat.h"
#include "ClassMeasurer.h"
#include "SaturationMeasurer.h"
#include "StochasticGradient.h"
#include "GMTrainer.h"
#include "Mixer.h"
#include "Softmax.h"
#include "CmdLine.h"

using namespace Torch;

//======= The MLP-expert ================
// One expert of the mixture: a two-stage network
//   Linear(n_entrees -> n_cachees) -> Tanh -> Linear(n_cachees -> n_sorties) -> Tanh
// (French names: entrees = inputs, cachees = hidden, sorties = outputs).
// The four sub-machines are allocated with new in the constructor and deleted in the
// destructor, so an MLP owns them exclusively.
class MLP : public ConnectedMachine
{
  public:
    Linear *cachees;       // inputs -> hidden units
    Tanh *cachees_tanh;    // tanh on the hidden units
    Linear *sorties;       // hidden units -> outputs
    Tanh *sorties_tanh;    // tanh on the outputs

    // n_entrees: input size, n_cachees: hidden units, n_sorties: output size.
    MLP(int n_entrees, int n_cachees, int n_sorties);
    virtual ~MLP();

  private:
    // The destructor deletes the sub-machines, so a member-wise copy would share the
    // pointers and delete them twice. Copying is disabled: declared but never
    // defined (pre-C++11 non-copyable idiom).
    MLP(const MLP &);
    MLP &operator=(const MLP &);
};

// Allocates the four stages of the expert, initialises each of them, then appends
// them to this ConnectedMachine with addFCL() in the same order.
// The initializer list runs in member-declaration order
// (cachees, cachees_tanh, sorties, sorties_tanh), i.e. the same allocation order as before.
MLP::MLP(int n_entrees, int n_cachees, int n_sorties)
  : cachees(new Linear(n_entrees, n_cachees)),
    cachees_tanh(new Tanh(n_cachees)),
    sorties(new Linear(n_cachees, n_sorties)),
    sorties_tanh(new Tanh(n_sorties))
{
  // Every stage is initialised before any of them is attached; keep this order,
  // since init() may draw from the random generator.
  cachees->init();
  cachees_tanh->init();
  sorties->init();
  sorties_tanh->init();

  // Attach the stages one after another: hidden stage first, then output stage.
  addFCL(cachees);
  addFCL(cachees_tanh);
  addFCL(sorties);
  addFCL(sorties_tanh);
}

// Frees the sub-machines the constructor allocated with new, in allocation order.
MLP::~MLP()
{
  // hidden stage
  delete cachees;  delete cachees_tanh;
  // output stage
  delete sorties;  delete sorties_tanh;
}
//=======================================

int main(int argc, char **argv)
{
  char *file;
  char *valid_file;
  int n_inputs;
  int n_targets;
  int n_hu;
  int max_load;
  real accuracy;
  real learning_rate;
  real decay;
  int max_iter;
  bool regression;
  char *file_model, *test_model;
  int n_experts;
  int n_hug;
  int the_seed;

  CmdLine cmd;

  cmd.info(help);

  cmd.addText("\nArguments:");
  cmd.addSCmdArg("file", &file, "the train or test file");
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");
  cmd.addICmdArg("n_targets", &n_targets, "output dimension of the data");

  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units for experts");
  cmd.addICmdOption("-nhug", &n_hug, 25, "number of hidden units for gater");
  cmd.addBCmdOption("-rm", &regression, false, "regression mode");
  cmd.addICmdOption("-ne", &n_experts, 10, "number of experts");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-seed", &the_seed, -1, "c'est *the seed* mec");
  cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
  cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
  cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
  cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load");
  cmd.addSCmdOption("-valid", &valid_file, "", "validation file");
  cmd.addSCmdOption("-sm", &file_model, "", "file to save the model");
  cmd.addSCmdOption("-test", &test_model, "", "model file to test");

  cmd.read(argc, argv);

  if(the_seed == -1)
    seed();
  else
    manual_seed((long)the_seed);

  ConnectedMachine mixture;

  // ============ The experts ===============
  ConnectedMachine experts;

  MLP **mlp = new MLP *[n_experts];
  for(int i = 0; i < n_experts; i++)
  {
    mlp[i] = new MLP(n_inputs, n_hu, n_targets);
    mlp[i]->init();
    experts.addMachine(mlp[i]);
  }

  experts.init();

  // ============ The gater =================
  Linear mixinl(n_inputs, n_hug);
  Tanh mixint(n_hug);
  Linear mixoutl(n_hug, n_experts);
  Softmax mixouts(n_experts);
  mixouts.setBOption("compute shift", true);

  mixinl.init();
  mixint.init();
  mixoutl.init();
  mixouts.init();

  //============== The rest =================
  mixture.addMachine(&mixinl);
  mixture.addMachine(&experts);
  mixture.addLayer();
  mixture.addMachine(&mixint);
  mixture.connectOn(&mixinl);
  mixture.addLayer();
  mixture.addMachine(&mixoutl);
  mixture.connectOn(&mixint);
  mixture.addLayer();
  mixture.addMachine(&mixouts);
  mixture.connectOn(&mixoutl);
  mixture.addLayer();

  Mixer mix(n_experts, n_targets);
  mix.init();
  mixture.addMachine(&mix);
  mixture.connectOn(&mixouts);
  mixture.connectOn(&experts);
  mixture.init();

  //=============== The datas ===============
  FileDataSet data(file, n_inputs, n_targets, false, max_load);
  data.setBOption("normalize inputs", true);
  data.init();

  FileDataSet *valid_data = NULL;
  if( strcmp(valid_file, "") )
  {
    valid_data = new FileDataSet(valid_file, n_inputs, n_targets);
    valid_data->init();
    valid_data->normalizeUsingDataSet(&data);
  }

  //=============== The measurers ===========

  // The class format
  ClassFormat *class_format = NULL;
  if(!regression)
  {
    if(n_targets == 1)
      class_format = new TwoClassFormat(&data);
    else
      class_format = new OneHotClassFormat(&data);
  }

  // On the train data...
  List *measurers = NULL;
  MseMeasurer *mse_meas = new MseMeasurer(mixture.outputs, &data, "the_mse_soft");
  mse_meas->init();
  addToList(&measurers, 1, mse_meas);

  ClassMeasurer *class_meas = NULL;
  if(!regression)
  {
    class_meas = new ClassMeasurer(mixture.outputs, &data, class_format, "the_class_err_soft");
    class_meas->init();
    addToList(&measurers, 1, class_meas);
  }

  // On the validation data...
  MseMeasurer *valid_mse_meas = NULL;
  ClassMeasurer *valid_class_meas = NULL;
  if( strcmp(valid_file, "") )
  {
    valid_mse_meas = new MseMeasurer(mixture.outputs, valid_data, "the_valid_mse_soft");
    valid_mse_meas->init();
    addToList(&measurers, 1, valid_mse_meas);

    if(!regression)
    {
      valid_class_meas = new ClassMeasurer(mixture.outputs, valid_data, class_format, "the_valid_class_err_soft");
      valid_class_meas->init();
      addToList(&measurers, 1, valid_class_meas);
    }
  }

  //=============== The trainer =============
  MseCriterion mse(n_targets);
  mse.init();

  StochasticGradient opt;
  opt.setIOption("max iter", max_iter);
  opt.setROption("end accuracy", accuracy);
  opt.setROption("learning rate", learning_rate);
  opt.setROption("learning rate decay", decay);
  GMTrainer trainer(&mixture, &data, &mse, &opt);

  message("Number of parameters: %d", mixture.n_params);

  if( strcmp(test_model, "") )
  {
    trainer.load(test_model);
    trainer.test(measurers);
  }
  else
  {
    trainer.train(measurers);
    
    if( strcmp(file_model, "") )
      trainer.save(file_model);
  }

  // Destroy all

  for(int i = 0; i < n_experts; i++)
    delete mlp[i];
  delete[] mlp;

  if(strcmp(valid_file, ""))
  {
    delete valid_data;
    delete valid_mse_meas;
    if(!regression)
      delete valid_class_meas;
  }

  delete mse_meas;
  if(!regression)
  {
    delete class_meas;
    delete class_format;
  }

  freeList(&measurers);

  return(0);
}
