/*
* The Broad Institute
* SOFTWARE COPYRIGHT NOTICE AGREEMENT
* This is copyright (2007-2009) by the Broad Institute/Massachusetts Institute
* of Technology.  It is licensed to You under the Gnu Public License, Version 2.0
* (the "License"); you may not use this file except in compliance with
*  the License.  You may obtain a copy of the License at
*
*    http://www.opensource.org/licenses/gpl-2.0.php
*
* This software is supplied without any warranty or guaranteed support
* whatsoever. Neither the Broad Institute nor MIT can be responsible for its
* use, misuse, or functionality.
*/












package org.broad.igv.roc;

//~--- non-JDK imports --------------------------------------------------------

import cern.colt.GenericSorting;

//~--- JDK imports ------------------------------------------------------------

import java.io.*;


/**
 * Computes the ROC AUC (area under the receiver operating characteristic curve) given a
 * classification vector and a parallel set of scores.
 * <p>
 * Reference:
 * <pre>
 * ROC Graphs: Notes and Practical Considerations for Data Mining Researchers
 * Tom Fawcett
 * Intelligent Enterprise Technologies Laboratory
 * HP Laboratories Palo Alto
 * HPL-2003-4
 * January 7th, 2003
 * </pre>
 *
 * @author jrobinso
 */
public class ROC {

    /**
     * Global on/off flag. It is not read anywhere in this class; presumably it is consulted by
     * callers to decide whether ROC scoring is active.
     */
    public static boolean ENABLED = false;

    /** Number of observations analysed: the length of the shorter of the two input arrays. */
    private final int nPts;

    /** Class label per observation; a label of 1 is a positive, any other value a negative. */
    private final int[] classVector;

    /** Score per observation, parallel to {@link #classVector}. */
    private final float[] values;

    /**
     * Creates an ROC calculator over the given labels and scores.
     * <p>
     * The arrays are stored by reference, not copied, and {@link #computeAUC()} reorders them
     * in place. If the arrays differ in length only the first
     * {@code min(classVector.length, values.length)} entries are used.
     *
     * @param classVector class label for each observation (1 = positive, anything else = negative)
     * @param values      score for each observation; higher scores rank first
     */
    public ROC(int[] classVector, float[] values) {
        this.classVector = classVector;
        this.values = values;
        nPts = Math.min(classVector.length, values.length);
    }

    /**
     * Computes the area under the ROC curve.
     * <p>
     * The observations are sorted by descending score and swept once. Every negative earns one
     * unit of credit for each positive ranked strictly above it and half a unit for each
     * positive sharing its score; the total credit divided by
     * {@code positives * negatives} is the area. Treating tied scores as a single group makes the
     * result independent of the arbitrary order in which the (unstable) sort leaves them.
     * <p>
     * Side effect: the first {@code nPts} entries of the arrays passed to the constructor are
     * reordered in place into descending score order.
     *
     * @return the AUC in the range [0, 1], or {@link Double#NaN} when there is not at least one
     *         positive and one negative observation (the area is undefined in that case)
     */
    public double computeAUC() {

        // Sort only the entries that exist in BOTH arrays; sorting up to classVector.length would
        // index past the end of a shorter values array.
        GenericSorting.quickSort(0, nPts, new ROCComparator(), new ROCSwapper());

        int totalPositives = 0;
        for (int i = 0; i < nPts; i++)
        {
            if (isPositive(classVector[i]))
            {
                totalPositives++;
            }
        }
        int totalNegatives = nPts - totalPositives;

        // Without both classes neither rate is defined (the division below would be by zero).
        if (totalPositives == 0 || totalNegatives == 0)
        {
            return Double.NaN;
        }

        // Sum over all negatives of (#positives ranked above + 0.5 * #positives tied with it).
        // Every term is a multiple of 0.5, so the accumulation is exact for realistic sizes.
        double pairCredit = 0;
        int positivesAbove = 0;

        int groupStart = 0;
        while (groupStart < nPts)
        {
            // Find the end (exclusive) of the run of observations sharing this score.
            int groupEnd = groupStart + 1;
            while (groupEnd < nPts && sameScore(values[groupStart], values[groupEnd]))
            {
                groupEnd++;
            }

            int groupPositives = 0;
            for (int i = groupStart; i < groupEnd; i++)
            {
                if (isPositive(classVector[i]))
                {
                    groupPositives++;
                }
            }
            int groupNegatives = (groupEnd - groupStart) - groupPositives;

            pairCredit += groupNegatives * (positivesAbove + groupPositives / 2.0);
            positivesAbove += groupPositives;
            groupStart = groupEnd;
        }

        return pairCredit / ((double) totalPositives * totalNegatives);
    }

    /**
     * Single definition of the positive class, used for both the class totals and the sweep so
     * that unexpected label values cannot be counted one way in one place and another way
     * elsewhere.
     *
     * @param label a class label
     * @return true if the label is 1
     */
    private static boolean isPositive(int label) {
        return label == 1;
    }

    /**
     * Tests whether two scores belong to the same tie group.
     *
     * @param a first score
     * @param b second score
     * @return true if the scores are numerically equal (including 0.0 and -0.0) or both NaN
     */
    private static boolean sameScore(float a, float b) {
        return a == b || Float.compare(a, b) == 0;
    }

    /**
     * Swapper handed to the colt sort: exchanges two observations, keeping the label and score
     * arrays of the enclosing instance in step.
     */
    class ROCSwapper implements cern.colt.Swapper {

        /**
         * Exchanges observations {@code a} and {@code b} in both parallel arrays.
         *
         * @param a index of the first observation
         * @param b index of the second observation
         */
        public void swap(int a, int b) {

            int tmpClass = classVector[a];
            classVector[a] = classVector[b];
            classVector[b] = tmpClass;

            float tmpValue = values[a];
            values[a] = values[b];
            values[b] = tmpValue;
        }
    }


    /**
     * Comparator handed to the colt sort: orders observation indexes by descending score.
     */
    class ROCComparator implements cern.colt.function.IntComparator {

        /**
         * Compares the scores at two indexes so that the larger score sorts first.
         * <p>
         * {@link Float#compare(float, float)} is used with its arguments reversed; it defines a
         * total order (NaN ranks above every other value), so the comparison stays consistent
         * even if a score is NaN.
         *
         * @param i index of the first observation
         * @param j index of the second observation
         *
         * @return a negative value if the score at {@code i} is larger, a positive value if it is
         *         smaller, and 0 if the two are equal
         */
        public int compare(int i, int j) {
            return Float.compare(values[j], values[i]);
        }
    }


    /**
     * Ad-hoc driver: reads the file {@code test.txt} from the working directory and prints its
     * AUC. Each non-blank line must hold a score and an integer class label separated by a tab.
     * TODO -- turn this into a unit test
     *
     * @param args ignored
     *
     * @throws IOException if {@code test.txt} cannot be opened or read
     */
    public static void main(String[] args) throws IOException {

        // Read the file once, buffering the lines so the arrays can be sized afterwards.
        java.util.List<String> lines = new java.util.ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader("test.txt"));
        try
        {
            String nextLine;
            while ((nextLine = br.readLine()) != null)
            {
                // Skip blank lines (e.g. a trailing newline) rather than failing to parse them.
                if (nextLine.trim().length() > 0)
                {
                    lines.add(nextLine);
                }
            }
        }
        finally
        {
            // Release the file handle even if reading fails.
            br.close();
        }

        int[] classVector = new int[lines.size()];
        float[] values = new float[lines.size()];

        for (int n = 0; n < classVector.length; n++)
        {
            String[] tokens = lines.get(n).split("\t");
            values[n] = Float.parseFloat(tokens[0]);
            classVector[n] = Integer.parseInt(tokens[1]);
        }

        ROC roc = new ROC(classVector, values);
        System.out.println("AUC = " + roc.computeAUC());
    }

}
