/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                      Copyright (c) 1996,1997                          */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission to use, copy, modify and distribute this software and its */
/*  documentation for research, educational and individual use only, is  */
/*  hereby granted without fee, subject to the following conditions:     */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*  This software may not be used for commercial purposes without        */
/*  specific prior written permission from the authors.                  */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*                     Author :  Alan W Black                            */
/*                     Date   :  May 1996                                */
/*-----------------------------------------------------------------------*/
/*  A Classification and Regression Tree (CART) Program                  */
/*  A basic implementation of many of the techniques in                  */
/*  Breiman et al. 1984                                                  */
/*                                                                       */
/*  Added decision list support, Feb 1997                                */
/*                                                                       */
/*=======================================================================*/

#include <stdlib.h>
#include <iostream.h>
#include <fstream.h>
#include <string.h>
#include "EST.h"
#include "EST_Wagon.h"

static void do_summary(WNode &tree,WDataSet &ds);
static void test_tree_float(WNode &tree,WDataSet &ds);
static void test_tree_class(WNode &tree,WDataSet &ds);
static void test_tree_cluster(WNode &tree,WDataSet &dataset);
static int wagon_split(WNode &node);
static WQuestion find_best_question(WVectorVector &dset);
static void construct_binary_ques(int feat,WQuestion &test_ques);
static float construct_float_ques(int feat,WQuestion &ques,WVectorVector &ds);
static float construct_class_ques(int feat,WQuestion &ques,WVectorVector &ds);
static void wgn_set_up_data(WVectorVector &data,const WVectorList &ds,int held_out,int in);

static int margin;

#if defined(INSTANTIATE_TEMPLATES)
// Instantiate class
#include "../base_class/EST_TList.cc"
template class EST_TList<WVector *>;
template class EST_TItem<WVector *>;
#include "../base_class/EST_TVector.cc"
template class EST_TVector<WVector *>;
#endif

void wgn_load_datadescription(EST_String fname)
{
    // Load field description for a file
    wgn_dataset.init(fname);
}

void wgn_load_dataset(EST_String fname)
{
    // Read the data set from a filename.  One vector per line
    // Assume all numbers are numbers and non-nums are categorical
    EST_TokenStream ts;
    WVector *v;
    int nvec=0,i;

    if (ts.open(fname) == -1)
	wagon_error(EST_String("unable to open data file \"")+
		    fname+"\"");
    ts.set_PunctuationSymbols("");
    ts.set_PrePunctuationSymbols("");
    ts.set_SingleCharSymbols("");

    for ( ;!ts.eof(); )
    {
	v = new WVector(wgn_dataset.width());
	i = 0;
	do 
	{
	    int type = wgn_dataset.dtype(i);
	    if (type == wndt_float)
		v->set_flt_val(i,atof(ts.get().string()));
	    else if (type == wndt_binary)
		v->set_int_val(i,atoi(ts.get().string()));
	    else if (type == wndt_cluster)
		v->set_int_val(i,atoi(ts.get().string()));
	    else if (type == wndt_ignore)
	    {
		ts.get();  // skip it
		v->set_int_val(i,0);
	    }
	    else // should check the different classes 
	    {
		EST_String s = ts.get().string();
		int n = wgn_discretes.discrete(type).name(s); 
		if (n == -1)
		{
		    cout << "Bad value " << s << " in field " <<
			wgn_dataset.feat_name(i) << " vector " << 
			    wgn_dataset.samples() << endl;
		    n = 0;
		}
		v->set_int_val(i,n);
	    }
	    i++;
	}
	while (!ts.eoln() && i<wgn_dataset.width());
	nvec ++;
	if (i != wgn_dataset.width())
	{
	    wagon_error(EST_String("Data vector contains ")+itoString(i)+
			" parameters instead of "+
			itoString(wgn_dataset.width()));
	}
	if (!ts.eoln())
	{
	    cerr << "Data vector " << nvec << 
		" contains too many parameters instead of " 
		<< wgn_dataset.width() << endl;
	    wagon_error(EST_String("extra parameter(s) from ")+
			ts.peek().string());
	}
	wgn_dataset.append(v);
    }

    cout << "Dataset of " << wgn_dataset.samples() << " vectors of " <<
	wgn_dataset.width() << " parameters " << endl;
    ts.close();
}

static void summary_results(WNode &tree)
{
    if (wgn_test_file != "")  // load in test data
    {
//	WDataSet ts;
	wgn_load_datadescription(wgn_desc_file);
	wgn_load_dataset(wgn_test_file);
    }

    do_summary(tree,wgn_dataset);
}

static void do_summary(WNode &tree,WDataSet &ds)
{
    if (wgn_dataset.dtype(0) == wndt_cluster)
	test_tree_cluster(tree,ds);
    else if (wgn_dataset.dtype(0) >= wndt_class)
	test_tree_class(tree,ds);
    else
	test_tree_float(tree,ds);

}

void wgn_build_tree()
{
    // Build init node and split it while reducing the impurity
    WNode *top = new WNode();

    wgn_set_up_data(top->get_data(),wgn_dataset.get_data(),wgn_held_out,TRUE);

    margin = 0;
    wagon_split(*top);  // recursively split data;

    if (wgn_held_out > 0)
    {
	wgn_set_up_data(top->get_data(),wgn_dataset.get_data(),
			wgn_held_out,FALSE);
	top->held_out_prune();
    }
	
    if (wgn_prune)
	top->prune();

    *wgn_coutput << *top;

    summary_results(*top);
    
}

static void wgn_set_up_data(WVectorVector &data,const WVectorList &ds,int held_out,int in)
{
    // Set data ommitting held_out percent if in is true
    // or only including 100-held_out percent if in is false
    int i,j;
    EST_TBI *d;

    if (!in)
	cout << "Doing held out data stage\n";
    // Make it definitely big enough
    data.resize(ds.length());
    
    for (j=i=0,d=ds.head(); d != 0; d=next(d),j++)
    {
	if ((in) && ((j%100) >= held_out))
	    data(i++) = ds(d);
//	else if ((!in) && ((j%100 < held_out)))
//	    data(i++) = ds(d);
	else if (!in)
	    data(i++) = ds(d);
//	if ((in) && (j < held_out))
//	    data(i++) = ds(d);	    
//	else if ((!in) && (j >=held_out))
//	    data(i++) = ds(d);	    
    }
    // make it the actual size
    data.resize(i);
}

static void test_tree_class(WNode &tree,WDataSet &dataset)
{
    // Test tree against data to get summary of results
    EST_StrStr_KVL pairs;
    EST_StrList lex;
    EST_TBI *p;
    WVectorList ds;
    EST_String predict,real;
    int i,type;

    ds = dataset.get_data();
    for (p=ds.head(); p != 0; p=next(p))
    {
	predict = tree.predict((*ds(p)));
	type = dataset.dtype(0);
	real = wgn_discretes[type].name(ds(p)->get_int_val(0));
	pairs.add_item(real,predict,1);
    }
    for (i=0; i<wgn_discretes[dataset.dtype(0)].size(); i++)
	lex.append(wgn_discretes[dataset.dtype(0)].name(i));

    const EST_FMatrix &m = confusion(pairs,lex);
    print_confusion(m,pairs,lex);
    
}

static void test_tree_cluster(WNode &tree,WDataSet &dataset)
{
    // Test tree against data to get summary of results for cluster trees
    WVectorList ds;
    WNode *leaf;
    int real;
    int right_cluster=0;
    EST_SuffStats ranking, meandist;
    EST_TBI *p;

    ds = dataset.get_data();
    for (p=ds.head(); p != 0; p=next(p))
    {
	leaf = tree.predict_node((*ds(p)));
	real = ds(p)->get_int_val(0);
	meandist += leaf->get_impurity().cluster_distance(real);
	right_cluster += leaf->get_impurity().in_cluster(real);
	ranking += leaf->get_impurity().cluster_ranking(real);
    }

    // Want number in right class, mean distance in sds, mean ranking
    if (wgn_coutput != &cout)   // save in output file
	*wgn_coutput << ";; Right cluster " << right_cluster << " (" <<
	    (int)(100.0*(float)right_cluster/(float)ds.length()) << 
		"%) mean ranking " << ranking.mean() << " mean distance "
		    << meandist.mean() << endl;
    cout << "Right cluster " << right_cluster << " (" <<
	(int)(100.0*(float)right_cluster/(float)ds.length()) << 
	    "%) mean ranking " << ranking.mean() << " mean distance "
		<< meandist.mean() << endl;


}

static void test_tree_float(WNode &tree,WDataSet &dataset)
{
    // Test tree against data to get summary of results FLOAT
    EST_TBI *p;
    WVectorList ds;
    float predict,real;
    EST_SuffStats x,y,xx,yy,xy,se,e;
    double cor,error;

    ds = dataset.get_data();
    for (p=ds.head(); p != 0; p=next(p))
    {
	predict = tree.predict((*ds(p)));
	real = ds(p)->get_flt_val(0);
	x += predict;
	y += real;
	error = predict-real;
	se += error*error;
	e += fabs(error);
	xx += predict*predict;
	yy += real*real;
	xy += predict*real;
    }

    cor = (xy.mean() - (x.mean()*y.mean()))/
	(sqrt(xx.mean()-(x.mean()*x.mean())) *
	 sqrt(yy.mean()-(y.mean()*y.mean())));

    if (wgn_coutput != &cout)   // save in output file
	*wgn_coutput 
	    << ";; RMSE " << ftoString(sqrt(se.mean()),4,1)
	    << " Correlation is " << ftoString(cor,4,1)
	    << " Mean (abs) Error " << ftoString(e.mean(),4,1)
	    << " (" << ftoString(e.stddev(),4,1) << ")" << endl;
	
    cout << "RMSE " << ftoString(sqrt(se.mean()),4,1)
	<< " Correlation is " << ftoString(cor,4,1)
	<< " Mean (abs) Error " << ftoString(e.mean(),4,1)
	<< " (" << ftoString(e.stddev(),4,1) << ")" << endl;
    
}

static int wagon_split(WNode &node)
{
    // Split given node (if possible)
    WQuestion q;
    WNode *l,*r;
    int i;

    node.set_impurity(WImpurity(node.get_data()));
    q = find_best_question(node.get_data());

    if (q.get_score() < node.get_impurity().measure())
    {
	// Ok its worth a split
	l = new WNode();
	r = new WNode();
	wgn_find_split(q,node.get_data(),l->get_data(),r->get_data());
	node.set_subnodes(l,r);
	node.set_question(q);
	if (!wgn_quiet)
	{
	    for (i=0; i < margin; i++)
		cout << " ";
	    cout << q << endl;
	}
	margin++;
	wagon_split(*l);
	margin++;
	wagon_split(*r);
	margin--;
	return TRUE;
    }
    else
    {
	if (!wgn_quiet)
	{
	    for (i=0; i < margin; i++)
		cout << " ";
	    cout << "stopped samples: " << node.samples() << " impurity: " 
		<< node.get_impurity() << endl;
	}
	margin--;
	return FALSE;
    }
}

void wgn_find_split(WQuestion &q,WVectorVector &ds,
		    WVectorVector &y,WVectorVector &n)
{
    int i, iy, in;

    y.resize(q.get_yes());
    n.resize(q.get_no());
    
    for (iy=in=i=0; i < ds.size(); i++)
	if (q.ask(*ds(i)) == TRUE)
	    y(iy++) = ds(i);
	else
	    n(in++) = ds(i);

}

static WQuestion find_best_question(WVectorVector &dset)
{
    //  Ask all possible questions and find the best one
    int i;
    float bscore,tscore;
    WQuestion test_ques, best_ques;

    bscore = tscore = WGN_HUGE_VAL;
    best_ques.set_score(bscore);
    // test each feature with each possible question
    for (i=1;i < wgn_dataset.width(); i++)
    {
	if (wgn_dataset.dtype(i) == wndt_binary)
	{
	    construct_binary_ques(i,test_ques);
	    tscore = wgn_score_question(test_ques,dset);
	}
	else if (wgn_dataset.dtype(i) == wndt_float)
	{
	    tscore = construct_float_ques(i,test_ques,dset);
	}
	else if (wgn_dataset.dtype(i) == wndt_ignore)
	    tscore = WGN_HUGE_VAL;
	else if (wgn_csubset && (wgn_dataset.dtype(i) >= wndt_class))
	{
	    wagon_error("subset selection temporarily deleted");
//	    tscore = construct_class_ques_subset(i,test_ques,dset);
	}
	else if (wgn_dataset.dtype(i) >= wndt_class)
	    tscore = construct_class_ques(i,test_ques,dset);
	if (tscore < bscore)
	{
	    best_ques = test_ques;
	    best_ques.set_score(tscore);
	    bscore = tscore;
	}
    }

    return best_ques;
}

static float construct_class_ques(int feat,WQuestion &ques,WVectorVector &ds)
{
    // Find out which member of a class gives the best split
    float tscore,bscore = WGN_HUGE_VAL;
    int cl;
    WQuestion test_q;

    test_q.set_fp(feat);
    test_q.set_oper(wnop_is);
    ques = test_q;
    
    for (cl=0; cl < wgn_discretes[wgn_dataset.dtype(feat)].size(); cl++)
    {
	test_q.set_operand1(EST_Val(cl));
	tscore = wgn_score_question(test_q,ds);
	if (tscore < bscore)
	{
	    ques = test_q;
	    bscore = tscore;
	}
    }

    return bscore;
}

#if 0
static float construct_class_ques_subset(int feat,WQuestion &ques,
					 WVectorVector &ds)
{
    // Find out which subset of a class gives the best split.
    // We first measure the subset of the data for each member of 
    // of the class.  Then order those splits.  Then go through finding
    // where the best split of that ordered list is.  This is described
    // on page 247 of Breiman et al.
    float tscore,bscore = WGN_HUGE_VAL;
    LISP l;
    int cl;

    ques.set_fp(feat);
    ques.set_oper(wnop_is);
    float *scores = new float[wgn_discretes[wgn_dataset.dtype(feat)].size()];
    
    // Only do it for exists values
    for (cl=0; cl < wgn_discretes[wgn_dataset.dtype(feat)].size(); cl++)
    {
	ques.set_operand(flocons(cl));
	scores[cl] = wgn_score_question(ques,ds);
    }

    LISP order = sort_class_scores(feat,scores);
    if (order == NIL)
	return WGN_HUGE_VAL;
    if (siod_llength(order) == 1)
    {   // Only one so we know the best "split"
	ques.set_oper(wnop_is);
	ques.set_operand(car(order));
	return scores[(int)FLONM(car(order))];
    }

    ques.set_oper(wnop_in);
    LISP best_l = NIL;
    for (l=cdr(order); CDR(l) != NIL; l = cdr(l))
    {
	ques.set_operand(l);
	tscore = wgn_score_question(ques,ds);
	if (tscore < bscore)
	{
	    best_l = l;
	    bscore = tscore;
	}

    }

    if (best_l != NIL)
    {
	if (siod_llength(best_l) == 1)
	{
	    ques.set_oper(wnop_is);
	    ques.set_operand(car(best_l));
	}
	else if (equal(cdr(order),best_l) != NIL)
	{
	    ques.set_oper(wnop_is);
	    ques.set_operand(car(order));
	}
	else
	{
	    cout << "Found a good subset" << endl;
	    ques.set_operand(best_l);
	}
    }
    return bscore;
}

static LISP sort_class_scores(int feat,float *scores)
{
    // returns sorted list of (non WGN_HUGE_VAL) items
    int i;
    LISP items = NIL;
    LISP l;

    for (i=0; i < wgn_discretes[wgn_dataset.dtype(feat)].size(); i++)
    {
	if (scores[i] != WGN_HUGE_VAL)
	{
	    if (items == NIL)
		items = cons(flocons(i),NIL);
	    else
	    {
		for (l=items; l != NIL; l=cdr(l))
		{
		    if (scores[i] < scores[(int)FLONM(car(l))])
		    {
			CDR(l) = cons(car(l),cdr(l));
			CAR(l) = flocons(i);
			break;
		    }
		}
		if (l == NIL)
		    items = l_append(items,cons(flocons(i),NIL));
	    }
	}
    }
    return items;
}
#endif

static float construct_float_ques(int feat,WQuestion &ques,WVectorVector &ds)
{
    // Find out a split of the range that gives the best score 
    // Naively does this by paritioning the range into float_range_split slots
    float tscore,bscore = WGN_HUGE_VAL;
    int d;
    float p;
    WQuestion test_q;
    float max,min,val,incr;

    test_q.set_fp(feat);
    test_q.set_oper(wnop_lessthan);
    ques = test_q;

    min = max = ds(0)->get_flt_val(feat);  /* set up some value */
    for (d=0; d < ds.size(); d++)
    {
	val = ds(d)->get_flt_val(feat);
	if (val < min)
	    min = val;
	else if (val > max)
	    max = val;
    }
    if (max == min)  // we're pure
	return WGN_HUGE_VAL;
    incr = (max-min)/wgn_float_range_split;  
    // so do float_range-1 splits
    for (p=min+incr; p <= max; p += incr )
    {
	test_q.set_operand1(EST_Val(p));
	tscore = wgn_score_question(test_q,ds);
	if (tscore < bscore)
	{
	    ques = test_q;
	    bscore = tscore;
	}
    }

    return bscore;
}

static void construct_binary_ques(int feat,WQuestion &test_ques)
{
    // construct a question.  Not sure about this in general
    // of course continuous/categorical features will require different
    // rule and non-binary ones will require some test point

    test_ques.set_fp(feat);
    test_ques.set_oper(wnop_binary);
    test_ques.set_operand1(EST_Val(""));
}

static float score_question_set(WQuestion &q, WVectorVector &ds, int ignorenth)
{
    // score this question as a possible split by finding
    // the sum of the impurities when ds is split with this question
    WImpurity y,n;
    int d;
    WVector *wv;

    for (d=0; d < ds.size(); d++)
    {
	if ((ignorenth < 2) ||
	    (d%ignorenth != ignorenth-1))
	{
	    wv = ds(d);
	    if (q.ask(*wv) == TRUE)
		y.cumulate((*wv)[0]);
	    else
		n.cumulate((*wv)[0]);
	}
    }

    q.set_yes(y.samples());
    q.set_no(n.samples());

    int min_cluster;

    if ((wgn_balance == 0.0) ||
	(ds.size()/wgn_balance < wgn_min_cluster_size))
	min_cluster = wgn_min_cluster_size;
    else 
	min_cluster = (int)(ds.size()/wgn_balance);

    if ((y.samples() < min_cluster) ||
	(n.samples() < min_cluster))
	return WGN_HUGE_VAL;

    if (y.measure() == 0.0)
	return n.measure();
    else if (n.measure() == 0.0)
       return y.measure();
    else    
	return (y.measure() + n.measure())/2.0;

}

float wgn_score_question(WQuestion &q, WVectorVector &ds)
{
    // This isn't v-fold cross validation !! Have to think more
    // about this
    EST_SuffStats a;
    int v;
    float best;

    for (v=0; v < wgn_vfold; v++)
    {
	best = score_question_set(q,ds,v);
	a+=best;
    }

    if (wgn_vfold > 2)
	best = score_question_set(q,ds,1);

    return a.mean();
}
