#include "filter_mask.h"
#include "utils.h" // for TinyVector comparison

// Describes and registers the two threshold arguments of the filter;
// they are registered in the order "min", "max".
void FilterGenMask::init(){
  min.set_description("lower threshold");
  max.set_description("upper threshold");

  append_arg(min,"min");
  append_arg(max,"max");
}

/**
 * Turns 'data' into a binary mask by collapsing its time dimension.
 *
 * A voxel of the result is 1.0 if the value at every time step lies
 * within [min,max], and 0.0 otherwise. On return, 'data' references the
 * mask, whose extent along timeDim is 1. Always returns true.
 */
bool FilterGenMask::process(Data<float,4>& data, Protocol& prot) const {

  const TinyVector<int,4> datashape=data.shape();
  TinyVector<int,4> maskshape=datashape;
  maskshape(timeDim)=1;

  Data<float,4> mask(maskshape);
  mask=1.0;

  const int nrep=datashape(timeDim);
  for(int i=0; i<mask.size(); i++) {
    const TinyVector<int,4> maskindex=mask.create_index(i);
    TinyVector<int,4> dataindex=maskindex;
    // The condition must be met for all time steps, so the first
    // out-of-range sample decides the voxel and the rest can be skipped.
    for(int irep=0; irep<nrep; irep++) {
      dataindex(timeDim)=irep;
      const float val=data(dataindex);
      // NOTE(review): a NaN sample compares false on both sides and thus
      // does not clear the mask bit -- confirm this is intended.
      if(val<min || val>max) {
        mask(maskindex)=0.0;
        break;
      }
    }
  }

  data.reference(mask);
  return true;
}


///////////////////////////////////////////////////////////////////////////

// Describes and registers the single filter argument "fname": the name of
// the file that process() reads as the mask.
void FilterUseMask::init() {

  fname.set_description("filename");
  append_arg(fname,"fname");
}


/**
 * Replaces 'data' by the values of all voxels selected by the mask file
 * 'fname'. A voxel is selected when the mask entry at time index 0 is
 * nonzero; all time steps of 'data' are examined.
 *
 * The selected values are collected in the order given by
 * data.create_index(0..size-1). They are stored along the second
 * dimension of a 1 x N x 1 x 1 array.
 *
 * Returns false if the mask cannot be read, or if its extents (ignoring
 * timeDim) differ from those of 'data'.
 */
bool FilterUseMask::process(Data<float,4>& data, Protocol& prot) const {
  Log<Filter> odinlog(c_label(),"process");

  Data<float,4> maskdata;
  if(maskdata.autoread(fname)<0) return false; // mask file could not be read

  // Only the non-time extents have to agree, so neutralize the time axis
  // of both shapes before comparing.
  TinyVector<int,4> maskextent(maskdata.shape());
  maskextent(timeDim)=1;
  TinyVector<int,4> dataextent(data.shape());
  dataextent(timeDim)=1;

  if(maskextent!=dataextent) {
    ODINLOG(odinlog,errorLog) << "shape mismatch: " << maskextent << "!=" << dataextent << STD_endl;
    return false;
  }

  fvector selected;
  for(int linidx=0; linidx<data.size(); ++linidx) {
    TinyVector<int,4> voxel=data.create_index(linidx);
    const float value=data(voxel);
    voxel(timeDim)=0; // the mask is consulted at its first time point only
    if(maskdata(voxel)) selected.push_back(value);
  }

  Data<float,1> selectedarr(selected);
  data.resize(1,selected.size(),1,1);
  data(0,Range::all(),0,0)=selectedarr;

  return true;
}
