import java.util.ArrayList;
import java.util.Random;

/**
 * Simple batch perceptron implementation.
 *
 * <p>{@link #buildClassifier()} performs these steps:
 * <ol>
 *   <li>Adds a bias column to the training data and shifts the dataset's
 *       column indices by one.</li>
 *   <li>Initializes the learning rate, the mistake target and the weight
 *       vector (all weights start at {@code INIT_WEIGHT} unless
 *       {@code RANDOM_INIT_WEIGHT} is enabled).</li>
 *   <li>Normalizes the numeric columns.</li>
 *   <li>Finds all misclassified training rows.</li>
 *   <li>Updates the weights based on these rows.</li>
 *   <li>Repeats the previous two steps until the number of misclassified
 *       rows is at most the mistake target, or {@code MAX_ITERATIONS}
 *       passes have been made.</li>
 * </ol>
 *
 * <p>{@code flipDataPoints()} (moving negative-class points to the positive
 * side by flipping their sign) is available as a helper but is not invoked
 * by {@link #buildClassifier()}.
 *
 * <p>Followed algorithm described at:
 * http://www.willamette.edu/~gorr/classes/cs449/Classification/perceptron.html
 */
public class Perceptron {
    // Public state; kept as public fields so existing callers keep working.
    public Dataset trainingSet = null;
    public double[] learningVector = null;
    public double[] currentWeights = null;
    public double learnRate;
    public int mistakeTarget;

    // Tuning constants.
    private static final double INIT_LEARN_RATE = 0.01;
    private static final double INIT_WEIGHT = 0.0;
    private static final int MISTAKES_TARGET = 0;
    private static final boolean RANDOM_LEARN_RATE = false;
    private static final boolean RANDOM_INIT_WEIGHT = false;
    private static final boolean USE_LEARNING_VECTOR = false;
    private static final int MAX_ITERATIONS = 500;

    // Bookkeeping written during training.
    private double[] initWeights = null;
    private ArrayList<double[]> normArrays = null;

    /**
     * Creates a perceptron that will be trained on the given dataset.
     *
     * @param data the training set; it is modified in place by
     *             {@link #buildClassifier()}
     */
    public Perceptron(Dataset data){
        this.trainingSet = data;
    }

    /**
     * Trains the perceptron on {@link #trainingSet} and logs the final weights.
     *
     * <p>Note that this mutates the training set (bias column, shifted column
     * indices, normalization), so calling it twice on the same dataset would
     * add a second bias column.
     */
    public void buildClassifier(){
        addBias();
        initializeConstants();
        normalizeDataset();
        currentWeights = initializeWeights();
        learnWeightsIteratively();
        Log.writeDoubleArray(currentWeights);
    }

    /**
     * Adds the bias column to the data and shifts every recorded column index
     * (numeric, nominal and label) by one. The bias column is registered as
     * numeric column 0.
     */
    private void addBias(){
        DataProcessor processor = new DataProcessor();
        // presumably inserts the bias as the first column; the index shifts
        // below rely on that - verify against DataProcessor.addBiasColumn
        trainingSet.data = processor.addBiasColumn(trainingSet.data);

        int[] shiftedNumeric = new int[trainingSet.numericColIDXs.length + 1];
        shiftedNumeric[0] = 0;//bias column comes first
        int target = 1;
        for(int index : trainingSet.numericColIDXs){
            shiftedNumeric[target] = index + 1;
            target++;
        }
        trainingSet.numericColIDXs = shiftedNumeric;

        trainingSet.labelColIDx = trainingSet.labelColIDx + 1;

        // nominal indices are shifted in place
        for(int i = 0; i < trainingSet.nominalColIDXs.length; i++){
            trainingSet.nominalColIDXs[i] = trainingSet.nominalColIDXs[i] + 1;
        }
    }

    /** Number of weighted columns: numeric (including bias) plus nominal. */
    private int featureCount(){
        return trainingSet.numericColIDXs.length
                + trainingSet.nominalColIDXs.length;
    }

    /**
     * Resets the mistake target, the learning rate and the per-column
     * learning vector.
     */
    private void initializeConstants(){
        learningVector = new double[featureCount()];
        mistakeTarget = MISTAKES_TARGET;
        // Always set the scalar rate first: the learning-vector branch below
        // reads it, and previously it was still 0.0 at that point.
        learnRate = INIT_LEARN_RATE;
        if(USE_LEARNING_VECTOR){
            Random randNum = new Random();
            for(int idx = 0; idx < learningVector.length ; idx++){
                if (RANDOM_LEARN_RATE){
                    learningVector[idx] = 0.01 + ( randNum.nextDouble() *
                            (learnRate - 0.01) );//random rate in a range
                } else{
                    learningVector[idx] = learnRate;
                }
            }
        }
    }

    /** Normalizes the numeric columns of the training data and keeps the result of the call. */
    private void normalizeDataset(){
        DataProcessor processor = new DataProcessor();
        normArrays = processor.normalizeNumericCols(trainingSet.data,
                trainingSet.numericColIDXs, 1);
    }

    /**
     * Builds the starting weight vector, one weight per feature column, and
     * remembers it in {@code initWeights}.
     *
     * @return the freshly created weight vector
     */
    private double[] initializeWeights(){
        int colCount = featureCount();
        double[] weights = new double[colCount];
        if (RANDOM_INIT_WEIGHT){
            Random randNum = new Random();
            for(int idx = 0; idx < colCount ; idx++){
                weights[idx] = 0.0 + ( randNum.nextDouble() *
                        (INIT_WEIGHT - 0.0) );//random weight in a range
            }
        } else{
            for(int idx = 0; idx < colCount ; idx++){
                weights[idx] = INIT_WEIGHT;//set same weight for each column
            }
        }
        initWeights = weights;
        return weights;
    }

    /**
     * Repeats "find mistakes, update weights" until the number of mistakes
     * found in a pass is at most {@link #mistakeTarget}, or until
     * {@code MAX_ITERATIONS} passes have run.
     */
    private void learnWeightsIteratively(){
        ArrayList<Integer> mistakes = null;
        int iteration = 0;
        do{
            iteration++;
            mistakes = getWronglyClassifiedPoints();
            updateLearningRate(iteration);
            updateWeights(mistakes);
            Log.write("Iterations: " + iteration + ", Mistakes: "
                    + mistakes.size() );
            // The cap is checked on completed passes, so exactly
            // MAX_ITERATIONS passes are allowed (the old check stopped one
            // pass early).
        } while(mistakes.size() > mistakeTarget
                && iteration < MAX_ITERATIONS);
    }

    /**
     * Scores every training row with the current weights and collects the
     * rows whose score has a different sign than the label.
     *
     * @return indices of the misclassified rows, in row order
     */
    private ArrayList<Integer> getWronglyClassifiedPoints() {
        ArrayList<Integer> output = new ArrayList<Integer>();
        int rowCount = trainingSet.data.length;
        int colCount = currentWeights.length;
        int labelColIDx = trainingSet.labelColIDx;
        for(int row = 0; row < rowCount ; row++){
            // NOTE(review): the dot product uses the first colCount columns
            // of the row; this assumes the label column is not among them -
            // confirm against the dataset layout.
            double prediction = 0.0;
            for(int col = 0; col < colCount; col++){
                prediction += currentWeights[col] *
                        trainingSet.data[row][col];
            }
            if( isDifferentSign(prediction,
                    trainingSet.data[row][labelColIDx]) ){
                output.add(row);//misclassified point
            }
        }
        return output;
    }

    /**
     * Tells whether two values lie on opposite sides of zero.
     *
     * <p>A zero {@code value1} is treated as positive. A zero {@code value2}
     * is never reported as different. If {@code value1} is NaN, a message is
     * logged and {@code false} is returned.
     *
     * @param value1 first value (the prediction)
     * @param value2 second value (the label)
     * @return {@code true} if the signs differ under the rules above
     */
    private boolean isDifferentSign(double value1, double value2) {
        if(value1 < 0) return (value2 > 0);
        if(value1 > 0) return (value2 < 0);
        if(value1 == 0) return (!(value2 >= 0));//treating 0 same as positive
        Log.write("Unhandled sign condition detected.");
        return false;//only reached when value1 is NaN
    }

    /**
     * Sets the learning rate for the given pass; currently a constant.
     *
     * @param iterationCount 1-based number of the current pass (unused for now)
     */
    private void updateLearningRate(int iterationCount){
        learnRate = INIT_LEARN_RATE;//keeping it constant as of now
    }

    /**
     * Applies one batch update: for every column, adds
     * {@code learnRate * sum(label * value)} over the misclassified rows to
     * the current weight. The weight vector is replaced by a new array.
     *
     * @param mistakes indices of the misclassified rows
     */
    private void updateWeights(ArrayList<Integer> mistakes){
        int labelCol = trainingSet.labelColIDx;
        double[] newWeights = new double[currentWeights.length];
        for(int colIDx = 0; colIDx < currentWeights.length; colIDx++){
            double correction = 0.0;
            for(int rowIDx : mistakes){
                double multiplier = trainingSet.data[rowIDx][labelCol];
                correction += multiplier * trainingSet.data[rowIDx][colIDx];
            }
            correction = learnRate * correction;
            newWeights[colIDx] = currentWeights[colIDx] + correction;
        }
        currentWeights = newWeights;//update weights vector
    }

    /**
     * Returns a copy of the training data in which every row with a negative
     * label is multiplied by -1. Not called from within this class at present.
     *
     * @return the sign-flipped copy
     */
    private double[][] flipDataPoints(){
        int rowCount = trainingSet.data.length;
        // NOTE(review): the column count adds the label column INDEX to the
        // feature count, which looks unintended - confirm before using this
        // method.
        int colCount = trainingSet.numericColIDXs.length + trainingSet
                .nominalColIDXs.length + trainingSet.labelColIDx;
        double[][] data = new double[rowCount][colCount];
        for(int row = 0; row < rowCount ; row++){
            int flipMultiplier = 1;//default 1 for positive class data points
            if(trainingSet.data[row][trainingSet.labelColIDx] < 0){
                flipMultiplier = -1;//negative for negative class data points
            }
            for(int col = 0; col < colCount; col++){
                data[row][col] = flipMultiplier * trainingSet.data[row][col];
            }
        }
        return data;
    }
}
