import java.util.*;

/**
 * Builds a simple linear regression model over a {@link Dataset} using
 * batch gradient descent. Numeric and nominal column weights are stored in
 * a single flat array: all numeric-column weights first, then all nominal
 * ones (this ordering invariant is relied on by every per-row loop below).
 */
public class Regressor{

    //model inputs and learned state (public so callers can inspect results)
    public Dataset trainingSet;        //training data; never modified except by normalization
    public double[] currentWeights;    //learned weights: numeric cols first, then nominal
    public double[] diffPerDataPoint;  //(prediction - label) per training row
    public double totalError;          //MSE over all data points for the latest iteration
    public double learnRate;           //gradient descent learning rate (currently constant)
    //private constants and working state
    private static final double INIT_LEARN_RATE = 0.25;
    private static final double ERROR_TARGET = 50.0;
    private static final double INIT_WEIGHT = 1.0;
    private static final double MSE_NON_ZERO_CONST = 1.0;
    private static final int MAX_ITERATIONS = 500;//safety cap if descent diverges
    private double[] newWeights;       //per-iteration gradient accumulator (reused)
    private double[] predPerDataPoint; //model prediction per training row
    private double errorTarget;        //MSE threshold for convergence
    private ArrayList<Double> errorPerIteration;   //MSE history; size() == iteration count
    private ArrayList<double[]> normArrays;        //normalization params returned by DataProcessor

    //Functions-----------------------------------------------------------------

    /**
     * @param data the training dataset the model will be fit against
     */
    public Regressor(Dataset data){
	this.trainingSet = data;
    }

    /**
     * Builds a simple linear regression model using batch gradient descent:
     * initializes state, normalizes the dataset in place, then iterates
     * until the MSE drops below the error target (or the iteration cap).
     * The learned weights end up in {@link #currentWeights}.
     */
    public void buildLinearRegressionModel(){
	Log.write("Linear regression model : Generation started.");
	initializeLinearRegressionConstants();
	Log.write("Linear regression model : Initialization complete.");
	normalizeDataset();
	Log.write("Linear regression model : Dataset normalized.");
	findLocalMinimaRecursively();
	Log.write("Linear regression model : Generation done.");
	Log.write("Model Weights are: ");
	Log.writeDoubleArray(currentWeights);
    }

    //sets initial values for the linear regression predictor state
    private void initializeLinearRegressionConstants(){
	errorPerIteration = new ArrayList<Double>();
	int dataRowCount = trainingSet.data.length;
	predPerDataPoint = new double[dataRowCount];
	diffPerDataPoint = new double[dataRowCount];
	currentWeights = initializeWeights();
	learnRate = INIT_LEARN_RATE;
	errorTarget = ERROR_TARGET;
    }

    //normalizes all numeric columns of the dataset in place; the returned
    //per-column normalization parameters are kept for later reference
    private void normalizeDataset(){
	DataProcessor processor = new DataProcessor();
	normArrays = processor.normalizeNumericCols(trainingSet.data,
				       trainingSet.numericColIDXs, 1);
    }

    //returns the initial weight array (one slot per numeric + nominal column,
    //each set to INIT_WEIGHT) used as the starting point of the descent
    private double[] initializeWeights(){
	int colCount = trainingSet.numericColIDXs.length +
	    trainingSet.nominalColIDXs.length;
	double[] weights = new double[colCount];
	for(int i = 0; i < colCount; i++)
	    weights[i] = INIT_WEIGHT;//set weight for each column to default
	return weights;
    }

    //Runs batch gradient descent until the MSE drops below errorTarget.
    //A hard iteration cap guards against looping forever when the learning
    //rate makes the descent diverge instead of converge.
    private void findLocalMinimaRecursively(){
	do{
	    updateLearningRate();
	    updateWeights();//first iteration only allocates the gradient buffer
	    calculatePredictions();
	    calculateDiffperDataPoint();
	    totalError = getTotalMSE();
	    errorPerIteration.add(totalError);
	    Log.write("Iteration: " + errorPerIteration.size() +
			    " MSE: " + totalError);
	    if(errorPerIteration.size() >= MAX_ITERATIONS){
		Log.write("Gradient descent stopped: iteration cap (" +
				MAX_ITERATIONS + ") reached before convergence.");
		return;
	    }
	} while(totalError > errorTarget);
    }

    //placeholder for an adaptive schedule; the rate is constant for now
    private void updateLearningRate(){
	learnRate = INIT_LEARN_RATE;//keeping it constant as of now
    }

    //Invariant: all numeric cols' weights come first, then nominal columns.
    //On the very first iteration there are no predictions yet, so we only
    //allocate the gradient buffer; afterwards we accumulate the gradient
    //and apply the descent step.
    private void updateWeights(){
	if(newWeights == null){//first iteration
	    newWeights = new double[currentWeights.length];
	}else{
	    updateNewWeights();
	    updateCurrentWeights();
	}
    }

    //Accumulates the batch gradient (sum over rows of diff * feature value)
    //into newWeights. The accumulator is zeroed first: without this reset,
    //gradients would leak from one iteration into the next (and the buffer's
    //initial contents would bias the first update), making the descent diverge.
    private void updateNewWeights(){
	Arrays.fill(newWeights, 0.0);//BUGFIX: reset gradient accumulator each iteration
	int rowCount = trainingSet.data.length;
	for(int rowCounter = 0; rowCounter < rowCount; rowCounter++){
	    //iterate numeric then nominal cols explicitly, because the label
	    //column index might sit anywhere in between them
	    int weightCounter = 0;
	    for(int index : trainingSet.numericColIDXs){
		newWeights[weightCounter] +=
		    (diffPerDataPoint[rowCounter] *
		     trainingSet.data[rowCounter][index]);
		weightCounter++;
	    }//numeric columns iteration ends here
	    for(int index : trainingSet.nominalColIDXs){
		newWeights[weightCounter] +=
		    (diffPerDataPoint[rowCounter] *
		     trainingSet.data[rowCounter][index]);
		weightCounter++;
	    }//nominal columns iteration ends here
	}//rows iteration ends here
    }

    //Applies one descent step: w <- w - learnRate * gradient / recordCount
    //(dividing by the record count averages the batch gradient)
    private void updateCurrentWeights(){
	int recordCount = trainingSet.data.length;
	for(int counter = 0; counter < currentWeights.length; counter++){
	    double oldVal = currentWeights[counter];
	    double newVal = oldVal -
			( (learnRate * newWeights[counter]) / recordCount );
	    currentWeights[counter] = newVal;//update the value
	}//counter iteration ends here
    }

    //Computes the model's prediction for every row as the dot product of
    //currentWeights with the row's numeric-then-nominal feature values
    private void calculatePredictions(){
	int rowCount = trainingSet.data.length;
	for(int rowCounter = 0; rowCounter < rowCount; rowCounter++){
	    //iterate numeric then nominal cols explicitly, because the label
	    //column index might sit anywhere in between them
	    double rowSum = 0.0; int weightCounter = 0;
	    for(int index : trainingSet.numericColIDXs){
		rowSum += (currentWeights[weightCounter] *
			   trainingSet.data[rowCounter][index]);
		weightCounter++;
	    }//numeric columns iteration ends here
	    for(int index : trainingSet.nominalColIDXs){
		rowSum += (currentWeights[weightCounter] *
			   trainingSet.data[rowCounter][index]);
		weightCounter++;
	    }//nominal columns iteration ends here
	    predPerDataPoint[rowCounter] = rowSum;
	}//rows iteration ends here
    }

    //Fills diffPerDataPoint with (prediction - label) for every row
    private void calculateDiffperDataPoint(){
	int rowCount = predPerDataPoint.length;
	int labelIDx = trainingSet.labelColIDx;
	for(int rowCounter = 0; rowCounter < rowCount; rowCounter++){
	    diffPerDataPoint[rowCounter] =
		predPerDataPoint[rowCounter] -
		trainingSet.data[rowCounter][labelIDx];
	}
    }

    //Returns the mean squared error over all data points using the
    //already-computed diffPerDataPoint values
    private double getTotalMSE(){
	double sum = 0.0;
	int rowCount = diffPerDataPoint.length;
	for(int rowIDx = 0; rowIDx < rowCount; rowIDx++)
	    sum += (diffPerDataPoint[rowIDx] * diffPerDataPoint[rowIDx]);
	return (MSE_NON_ZERO_CONST * (sum / rowCount));
    }
}
