import java.util.*;

public class DataProcessor{

    /**
     * Splits a dataset into a training set and a test set. The test set is the
     * contiguous row range [testStartIDx, testEndIDx); every other row, in its
     * original order, goes to the training set. Rows are copied, so the
     * returned arrays do not alias {@code dataset}.
     *
     * @param dataset      dataset rows; every row is assumed to be as long as
     *                     the first row
     * @param testStartIDx first row (inclusive) of the test range
     * @param testEndIDx   last row (exclusive) of the test range
     * @return a list holding the training set at index 0 and the test set at
     *         index 1
     * @throws IllegalArgumentException if {@code testStartIDx < 0},
     *         {@code testEndIDx < testStartIDx} or
     *         {@code testEndIDx > dataset.length}
     */
    public ArrayList<double[][]> splitTrainTest(double[][] dataset,
                                                int testStartIDx, int testEndIDx){
        if(testStartIDx < 0 || testEndIDx < testStartIDx
           || testEndIDx > dataset.length){
            throw new IllegalArgumentException("Invalid test range ["
                + testStartIDx + ", " + testEndIDx + ") for a dataset of "
                + dataset.length + " rows");
        }
        int colCount = (dataset.length == 0) ? 0 : dataset[0].length;
        int testRecordCount = testEndIDx - testStartIDx;
        double[][] testset = new double[testRecordCount][colCount];
        double[][] trainset =
            new double[dataset.length - testRecordCount][colCount];
        // the test rows are one contiguous slice of the dataset
        setFoldData(testset, dataset, testStartIDx);
        // all remaining rows go to the train set; the range test must be the
        // same half-open interval used for the test slice above
        int trainRowCounter = 0;
        for(int rowCounter = 0; rowCounter < dataset.length; rowCounter++){
            if(rowCounter >= testStartIDx && rowCounter < testEndIDx) continue;
            System.arraycopy(dataset[rowCounter], 0,
                             trainset[trainRowCounter], 0, colCount);
            trainRowCounter++;
        }
        ArrayList<double[][]> output = new ArrayList<double[][]>();
        output.add(trainset);
        output.add(testset);
        return output;
    }

    /**
     * Splits the dataset into k consecutive folds (rows are copied). The first
     * k-1 folds each hold {@code dataset.length / k} rows (integer division);
     * the last fold holds all remaining rows. When k exceeds the number of
     * rows, the first k-1 folds are therefore empty.
     *
     * @param dataset dataset rows
     * @param k       number of folds
     * @return the folds in dataset order; if {@code k <= 0} or the dataset is
     *         empty, a single-element list holding {@code dataset} itself
     *         (not a copy)
     */
    public ArrayList<double[][]> getKfolds(double[][] dataset, int k){
        ArrayList<double[][]> output = new ArrayList<double[][]>();
        if(k <= 0 || dataset.length == 0){
            output.add(dataset);
            return output;
        }
        int allFoldsRecordCount = 0;
        int foldRecordCount = dataset.length / k;
        int colCount = dataset[0].length;
        Log.write("Each fold will have: "+ foldRecordCount);
        // first k-1 folds get exactly foldRecordCount rows each
        for(int foldCounter = 0; foldCounter < k-1; foldCounter++){
            double[][] fold = new double[foldRecordCount][colCount];
            setFoldData(fold, dataset, allFoldsRecordCount);
            allFoldsRecordCount += foldRecordCount;
            Log.write("Records read so far : "+ allFoldsRecordCount);
            output.add(fold);
        }
        // the last fold absorbs the remainder of the integer division
        if(allFoldsRecordCount != dataset.length){
            int diff = dataset.length - allFoldsRecordCount;
            Log.write("Last fold will have : "+ diff);
            double[][] fold = new double[diff][colCount];
            setFoldData(fold, dataset, allFoldsRecordCount);
            output.add(fold);
        }
        return output;
    }

    // Copies fold.length consecutive rows of dataset, starting at startIDx,
    // into fold; each row copies dataset[0].length columns.
    private void setFoldData(double[][] fold, double[][] dataset, int startIDx){
        if(dataset.length == 0) return;
        int colCount = dataset[0].length;
        for(int foldRowCounter = 0; foldRowCounter < fold.length; foldRowCounter++){
            System.arraycopy(dataset[startIDx + foldRowCounter], 0,
                             fold[foldRowCounter], 0, colCount);
        }
    }

    /**
     * Normalizes the given numeric columns of {@code dataset} in place.
     *
     * @param dataset           dataset rows (modified in place)
     * @param numericIDXs       indices of the columns to normalize
     * @param normalizationType 0 = min-max ({@code (x-min)/(max-min)}),
     *                          1 = zero mean / unit variance
     *                          ({@code (x-mean)/stdDev})
     * @return for type 0: [maxValues, minValues] (pass these to
     *         {@link #normalizeRecord}); for type 1: [mean, stdDev];
     *         for any other type: {@code null}, and the dataset is untouched.
     *         Each array is indexed like {@code numericIDXs}.
     */
    public ArrayList<double[]> normalizeNumericCols(double[][] dataset,
                                                    int[] numericIDXs,
                                                    int normalizationType) {
        if(normalizationType == 0){
            return shiftScaleNormalization(dataset, numericIDXs);
        } else if(normalizationType == 1){
            return zeroMeanUnitVarNormalization(dataset, numericIDXs);
        }
        return null;
    }

    // Standardizes each numeric column in place to (x - mean) / stdDev.
    // Constant columns (stdDev == 0) are only shifted, avoiding division by
    // zero. Returns [mean, stdDev], indexed like numericIDXs.
    private ArrayList<double[]> zeroMeanUnitVarNormalization(double[][] data,
                                                             int[] numericIDXs){
        double[] mean = getMeanVector(data, numericIDXs);
        double[] stdDev = getStdDevVector(data, numericIDXs, mean);
        for(double[] row : data){
            for(int colCounter = 0; colCounter < numericIDXs.length; colCounter++){
                int colIDx = numericIDXs[colCounter];
                double newVal = row[colIDx] - mean[colCounter];
                if(stdDev[colCounter] > 0)
                    newVal /= stdDev[colCounter];
                row[colIDx] = newVal;
            }
        }
        ArrayList<double[]> output = new ArrayList<double[]>();
        output.add(mean);
        output.add(stdDev);
        return output;
    }

    // Per-column arithmetic mean of the numeric columns; all zeros when data
    // is empty.
    private double[] getMeanVector(double[][] data, int[] numericIDXs){
        double[] output = new double[numericIDXs.length];
        if(data.length == 0) return output;//no data ; zero means
        for(double[] row : data){
            for(int colCounter = 0; colCounter < output.length; colCounter++){
                //add sum of each feature value
                output[colCounter] += row[numericIDXs[colCounter]];
            }
        }
        //convert sum to mean
        for(int index = 0; index < output.length; index++){
            output[index] = output[index]/data.length;
        }
        return output;
    }

    // Per-column population standard deviation of the numeric columns, given
    // their precomputed means; all zeros when data is empty.
    private double[] getStdDevVector(double[][] data, int[] numericIDXs,
                                     double[] mean){
        double[] output = new double[numericIDXs.length];
        if(data.length == 0) return output;
        for(double[] row : data){
            for(int colCounter = 0; colCounter < output.length; colCounter++){
                double delta = row[numericIDXs[colCounter]] - mean[colCounter];
                output[colCounter] += delta * delta;
            }
        }
        for(int index = 0; index < output.length; index++){
            output[index] = Math.sqrt(output[index] / data.length);
        }
        return output;
    }

    // Min-max normalizes the numeric columns in place and returns
    // [maxValues, minValues], indexed like numericIDXs. An empty dataset
    // yields all-zero bounds and is left untouched.
    private ArrayList<double[]> shiftScaleNormalization(double[][] dataset,
                                                        int[] numericIDXs){
        double[] minValues = new double[numericIDXs.length];
        double[] maxValues = new double[numericIDXs.length];
        if(dataset.length > 0){
            // seed bounds with the first row
            for(int colIDx = 0; colIDx < numericIDXs.length; colIDx++){
                minValues[colIDx] = dataset[0][numericIDXs[colIDx]];
                maxValues[colIDx] = dataset[0][numericIDXs[colIDx]];
            }
            // widen bounds using the remaining rows
            for(int rowIDx = 1; rowIDx < dataset.length; rowIDx++){
                for(int colIDx = 0; colIDx < numericIDXs.length; colIDx++){
                    double value = dataset[rowIDx][numericIDXs[colIDx]];
                    if(value < minValues[colIDx])
                        minValues[colIDx] = value;
                    if(value > maxValues[colIDx])
                        maxValues[colIDx] = value;
                }
            }
            // bounds are final now; reuse the per-record formula so training
            // rows and later records are scaled identically
            for(double[] row : dataset){
                normalizeRecord(row, numericIDXs, maxValues, minValues);
            }
        }
        Log.write("Dataset normalized..");
        ArrayList<double[]> output = new ArrayList<double[]>();
        output.add(maxValues);
        output.add(minValues);
        return output;
    }

    /**
     * Min-max normalizes one record in place using previously computed
     * bounds: {@code (x - min) / (max - min)}. For a constant column
     * ({@code max <= min}) the value is only shifted by {@code min}, avoiding
     * division by zero.
     *
     * @param record         the record to normalize (modified in place)
     * @param numericColIDXs indices of the numeric columns within the record
     * @param maxValues      per-column maxima, indexed like numericColIDXs
     * @param minValues      per-column minima, indexed like numericColIDXs
     */
    public void normalizeRecord(double[] record, int[] numericColIDXs,
                                double[] maxValues, double[] minValues){
        for(int colIDx = 0; colIDx < numericColIDXs.length; colIDx++){
            int recordIDx = numericColIDXs[colIDx];
            double range = maxValues[colIDx] - minValues[colIDx];
            double newVal = record[recordIDx] - minValues[colIDx];
            if(range > 0)
                newVal /= range;
            record[recordIDx] = newVal;
        }
    }

    /* Converts the numeric columns to nominal by threshold mechanism;
       updates to the dataset are meant to be applied in place.
       TODO: not implemented yet - currently a no-op.
    */
    public void numericToNominal(double[][] dataset, int[] numericIDXs) {
    }

    /**
     * Generates candidate split thresholds for each numeric column: the
     * midpoints between consecutive distinct values of that column.
     *
     * @return map from column index to its ascending thresholds, or
     *         {@code null} when the dataset is empty
     */
    public HashMap<Integer,double[]> getThresholds(double[][] dataset,
                                                  int[] numericIDXs){
        if(dataset.length == 0) return null;
        HashMap<Integer, double[]> output = new HashMap<Integer, double[]>();
        for(int colIdx : numericIDXs){
            TreeSet<Double> uniqueValues = findUniqueValues(dataset, colIdx);
            output.put(colIdx, generateThresholds(uniqueValues));
        }
        return output;
    }

    /**
     * Collects the distinct values of one column, in ascending order.
     *
     * @return the sorted distinct values, or {@code null} when the dataset is
     *         empty
     */
    public TreeSet<Double> findUniqueValues(double[][] dataset, int colIdx){
        if(dataset.length == 0) return null;
        TreeSet<Double> output = new TreeSet<Double>();
        for(double[] row : dataset){
            //relying on treeset's property to add only unique & ordered values
            output.add(row[colIdx]);
        }
        return output;
    }

    // Given an ordered set of unique values, returns the unique midpoints
    // (x<i> + x<i+1>) / 2 of consecutive values, ascending. An empty (or
    // null) set gives no thresholds; a single value is its own threshold.
    private double[] generateThresholds(TreeSet<Double> uniqueValues){
        // check emptiness BEFORE touching the first element
        if(uniqueValues == null || uniqueValues.isEmpty()) return new double[0];
        Iterator<Double> valueCursor = uniqueValues.iterator();
        // consume the first value explicitly instead of comparing values,
        // since 0.0 == -0.0 would wrongly skip a distinct element
        double lastIDxValue = valueCursor.next();
        if(!valueCursor.hasNext()) return new double[]{lastIDxValue};
        // a set guards against midpoints collapsing under float rounding
        TreeSet<Double> thSet = new TreeSet<Double>();
        while(valueCursor.hasNext()){
            double currentIDxValue = valueCursor.next();
            thSet.add((lastIDxValue + currentIDxValue) / 2);
            lastIDxValue = currentIDxValue;
        }
        return setToArray(thSet.iterator(), new double[thSet.size()]);
    }

    /**
     * Lists every candidate split for the dataset: one threshold split
     * ({@code type 't'}) per threshold of each numeric column, and one
     * class split ({@code type 'c'}) per nominal column carrying that
     * column's distinct values. Numeric columns with no entry in
     * {@code thresholds} contribute no options.
     *
     * @return the split options, or {@code null} when the dataset is empty
     */
    public ArrayList<SplitOption> getSplitOptions(double[][] dataset,
                                                  int[] numericIDXs,
                                                  int[] nominalIDXs,
                                                  HashMap<Integer,double[]> thresholds){
        if(dataset.length == 0) return null;
        ArrayList<SplitOption> output = new ArrayList<SplitOption>();
        for(int numericColIdx : numericIDXs){
            double[] thresholdArray = thresholds.get(numericColIdx);
            if(thresholdArray == null) continue; // no thresholds for column
            // each threshold value translates to one split option
            for(double val : thresholdArray){
                SplitOption sop = new SplitOption();
                sop.attributeIDx = numericColIdx;
                sop.type = 't';
                sop.value = val;
                output.add(sop);
            }
        }
        for(int nominalColIdx : nominalIDXs){
            // each nominal column generates one split option
            SplitOption sop = new SplitOption();
            sop.attributeIDx = nominalColIdx;
            sop.type = 'c';
            TreeSet<Double> classesSet = findUniqueValues(dataset, nominalColIdx);
            sop.classes = setToArray(classesSet.iterator(),
                                     new double[classesSet.size()]);
            output.add(sop);
        }
        return output;
    }

    // Drains the iterator into target (which must be large enough) and
    // returns target. Set.toArray cannot produce a primitive double[].
    private double[] setToArray(Iterator<Double> cursor, double[] target){
        int counter = 0;
        while(cursor.hasNext()){
            target[counter++] = cursor.next();
        }
        return target;
    }
}

/**
 * One candidate split produced by DataProcessor.getSplitOptions.
 */
class SplitOption{
    /** Zero-based index of the column this split tests. */
    public int attributeIDx;
    /** Split kind: 't' for a numeric threshold, 'c' for nominal classes. */
    public char type;
    /** Threshold value; only meaningful when type is 't'. */
    public double value;
    /** Distinct nominal values (held as doubles); only meaningful when type is 'c'. */
    public double[] classes;
}
