/*
 * The Broad Institute
 * SOFTWARE COPYRIGHT NOTICE AGREEMENT
 * This is copyright (2007-2008) by the Broad Institute/Massachusetts Institute 
 * of Technology.  It is licensed to You under the Gnu Public License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *    http://www.opensource.org/licenses/gpl-2.0.php
 *
 * This software is supplied without any warranty or guaranteed support
 * whatsoever. Neither the Broad Institute nor MIT can be responsible for its
 * use, misuse, or functionality.
*/

package org.broad.igv.data.wgs;

import org.broad.igv.preprocess.*;
import org.broad.igv.data.GenomeSummaryData;
import org.broad.igv.data.Dataset;

import cern.colt.list.DoubleArrayList;

import cern.jet.stat.quantile.DoubleQuantileFinder;
import cern.jet.stat.quantile.QuantileFinderFactory;

import org.broad.igv.feature.Chromosome;
import org.broad.igv.feature.Genome;

import ncsa.hdf.hdf5lib.HDF5Constants;

import org.broad.igv.feature.GenomeManager;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import org.broad.igv.h5.HDF5LocalWriter;
import org.broad.igv.h5.HDFWriter;

/**
 *
 * @author jrobinso
 */
public abstract class WGSBaseProcessor {

    /**
     * Name of the pseudo-chromosome that stands for the whole genome.  Its
     * length is taken as the genome length divided by 1000.
     */
    final public static String CHR_ALL = "All";

    private static Logger log = Logger.getLogger(WGSBaseProcessor.class);

    /** Data format version number.  This is recorded to the HDF5.  */
    private static int version = 3;

    /** The Dataset to be processed */
    private SequenceDataset dataset;

    /** Genome of the dataset; used to look up chromosomes and their lengths. */
    private Genome genome;

    /** Number of bins per tile; each zoom level z has 2^z tiles. */
    private int tileSize = 700;

    /** Upper bound on the number of zoom levels computed for a chromosome. */
    private int zoomMax = 10;

    /**
     * Binning always continues through this zoom level; the sparse-bin cutoff
     * in computeAllBins only applies to levels above it.
     */
    private int zoomMin = 2;

    /** Chromosome names whose data group has been created by openChromosomeGroup. */
    private Set<String> chromosomeGroupKeys = new HashSet<String>();

    /** Keys (chr + zoomName) of zoom groups created by openZoomGroup. */
    private Set<String> zoomGroupKeys = new HashSet<String>();

    /** Keys (chr + zoomName + dsName) of datasets created by openDataset / openVectorDataset. */
    private Set<String> datasetCache = new HashSet<String>();

    //GenomeSummaryData genomeSummaryData;

    /** Optional progress sink; stays null when the single-argument constructor is used. */
    StatusMonitor statusMonitor;

    /** Writer used for all output. */
    HDFWriter writer = new HDF5LocalWriter(); //new ZipWriter(); //

    /** Not referenced anywhere else in this class. */
    boolean inferZeroes = false;

    /**
     * Creates a processor that reports progress to a status monitor.
     *
     * @param dataset        the dataset to process
     * @param statusMonitor  receives progress increments; every use in this
     *                       class is guarded by a null check
     */
    public WGSBaseProcessor(SequenceDataset dataset, StatusMonitor statusMonitor) {
        // Delegate the dataset/genome setup to the single-argument constructor.
        this(dataset);

        this.statusMonitor = statusMonitor;
    }

    /**
     * Creates a processor without progress reporting.  The genome is taken
     * from the dataset.
     *
     * @param dataset  the dataset to process
     */
    public WGSBaseProcessor(SequenceDataset dataset) {
        this.genome = dataset.getGenome();
        this.dataset = dataset;
    }

    /**
     * Processes the dataset and writes the result to a newly created file.
     * <p>
     * The root group receives the dataset attributes, a "features" group is
     * filled with the bin layout of every chromosome, and a "data" group is
     * filled with the binned and raw values.  The file is closed in all cases.
     *
     * @param outputFile  path of the file to create
     * @return {@code true} if processing completed, {@code false} if it was
     *         cancelled by interrupting the thread.  Note that the interrupt
     *         flag has been cleared by the time {@code false} is returned.
     */
    public boolean process(String outputFile) {
        int file = writer.createFile(outputFile);
        try {

            // openFile root group
            int root = writer.openGroup(file, "/");

            writer.writeAttribute(root, "name", dataset.getName());
            writer.writeAttribute(root, "has.data", 1);
            writer.writeAttribute(root, "normalized", 0);
            writer.writeAttribute(root, "log.values", 0);
            writer.writeAttribute(root, "version", version);
            writer.writeAttribute(root, "type", dataset.getType());
            System.out.println("window.span = " + dataset.getWindowSpan());

            // The window span is only recorded when it is positive.
            if (dataset.getWindowSpan() > 0) {
                writer.writeAttribute(root, "window.span", dataset.getWindowSpan());
            }

            // Process features, and as a side effect compute bin information
            double estTimeFraction = 0.1;   // Estimate of total time spent processing features
            int featureGroup = writer.createGroup(root, "features");
            Map<String, List<BinnedData>> binInfoMap = processFeatures(featureGroup, estTimeFraction);
            writer.closeGroup(featureGroup);

            // Process data
            int dataGroup = writer.createGroup(root, "data");
            writer.createAndWriteStringDataset(dataGroup, "track.id", new String[] {dataset.getTrackName()});

            estTimeFraction = 0.7;
            processData(dataGroup, binInfoMap, estTimeFraction);
            writer.closeGroup(dataGroup);

            // cleanup
            writer.closeGroup(root);

            return true;

        } catch (InterruptedException ex) {
            // Cancelled.  Groups that were open at the time of the interrupt
            // are not closed individually; the file is closed below.
            // TODO -- cleanup
            return false;

        } finally {
            try {
                writer.closeFile(file);
            } catch (Exception e) {
                // A close failure during cancellation is expected noise; any
                // other failure is reported instead of being silently dropped.
                // Thread.interrupted() also clears the flag, as before.
                if (!Thread.interrupted()) {
                    log.error("Error closing output file: " + outputFile, e);
                }
            }
        }
    }

    /**
     * Checks the current thread for interrupt status.  Used to support
     * cancelling.  The check clears the thread's interrupt flag.
     *
     * @throws InterruptedException if the thread had been interrupted
     */
    private void checkForInterrupt() throws InterruptedException {
        boolean cancelled = Thread.interrupted();
        if (!cancelled) {
            return;
        }
        System.out.println("Interrupted");
        throw new InterruptedException();
    }

    /**
     * Writes the feature (bin layout) groups for every chromosome of the
     * dataset, followed by the genome-wide pseudo-chromosome.
     *
     * @param featureGroup     handle of the group that receives one child
     *                         group per chromosome
     * @param estTimeFraction  fraction of the total work attributed to this
     *                         step; it is spread evenly over the chromosomes
     *                         when reporting progress
     * @return map from chromosome name to its list of binned data, one entry
     *         per zoom level.  Chromosomes that are missing from the genome
     *         have no entry.
     * @throws InterruptedException if the thread is interrupted
     */
    private Map<String, List<BinnedData>> processFeatures(int featureGroup, double estTimeFraction) throws InterruptedException {

        double progIncrement = (estTimeFraction * 100) / dataset.getChromosomes().length;
        Map<String, List<BinnedData>> result = new HashMap<String, List<BinnedData>>();

        String[] names = dataset.getChromosomes();
        for (int idx = 0; idx < names.length; idx++) {

            checkForInterrupt();

            String chr = names[idx];
            List<BinnedData> levels = processFeaturesForChromosome(chr, featureGroup);
            if (levels != null) {
                result.put(chr, levels);
            }

            if (statusMonitor != null) {
                statusMonitor.incrementStatus(progIncrement);
            }
        }

        // The genome-wide view is always recorded last.
        List<BinnedData> genomeWide = processFeaturesForChromosome(CHR_ALL, featureGroup);
        result.put(CHR_ALL, genomeWide);

        return result;
    }

    /**
     * Writes the feature group of a single chromosome: its length, the number
     * of zoom levels, one "z&lt;level&gt;" group per zoom level, and a "raw"
     * group holding every coordinate 0..length-1 plus a chunk index.
     *
     * @param chr           chromosome name, or {@link #CHR_ALL} for the
     *                      genome-wide view (length = genome length / 1000)
     * @param featureGroup  handle of the parent "features" group
     * @return the binned data of each zoom level, or {@code null} if the
     *         genome does not contain the chromosome
     * @throws InterruptedException if the thread is interrupted
     */
    private List<BinnedData> processFeaturesForChromosome(String chr, int featureGroup) throws InterruptedException {

        final int chrLength;
        if (!chr.equals(CHR_ALL)) {
            Chromosome chromosome = genome.getChromosome(chr);
            if (chromosome == null) {
                System.out.println("Missing chromosome: " + chr);
                return null;
            }
            chrLength = chromosome.getLength();
        } else {
            chrLength = (int) (genome.getLength() / 1000);
        }

        int chrGroup = writer.createGroup(featureGroup, chr);
        writer.writeAttribute(chrGroup, "length", chrLength);

        List<BinnedData> binnedDataList = computeAllBins(chr, chrLength);

        // The number of zoom levels is one more than the level of the last entry.
        int numZoomLevels = 0;
        if (binnedDataList.size() != 0) {
            BinnedData deepest = binnedDataList.get(binnedDataList.size() - 1);
            numZoomLevels = deepest.getZoomLevel() + 1;
        }
        writer.writeAttribute(chrGroup, "zoom.levels", numZoomLevels);

        for (BinnedData level : binnedDataList) {

            checkForInterrupt();

            int zoomGroup = writer.createGroup(chrGroup, "z" + level.getZoomLevel());

            writer.writeAttribute(zoomGroup, "bin.size", level.getBinSize());

            double tileWidth = tileSize * level.getBinSize();
            writer.writeAttribute(zoomGroup, "tile.width", tileWidth);

            // Count statistics of the bins at this zoom level
            writer.writeAttribute(zoomGroup, "mean.count", level.getMeanCount());
            writer.writeAttribute(zoomGroup, "median.count", level.getMedianCount());
            writer.writeAttribute(zoomGroup, "max.count", level.getMaxCount());
            writer.writeAttribute(zoomGroup, "percentile90.count", level.getPercentile90Count());

            // Bin start locations
            writer.createAndWriteVectorDataset(zoomGroup, "start", level.getLocations());

            // Boundary indices (bin number) for each tile
            writer.createAndWriteVectorDataset(zoomGroup, "tile.boundary", level.getTileBoundaries());

            // Number of points in each bin
            writer.createAndWriteVectorDataset(zoomGroup, "count", level.getCounts());

            writer.closeGroup(zoomGroup);
        }

        // Unprocessed coordinates: one entry per position, 0 .. chrLength - 1
        int rawGroup = writer.createGroup(chrGroup, "raw");

        int[] positions = new int[chrLength];
        for (int pos = 0; pos < positions.length; pos++) {
            positions[pos] = pos;
        }
        writer.createAndWriteVectorDataset(rawGroup, "start", positions);

        assert (chrLength < Integer.MAX_VALUE);
        recordRawIndex(rawGroup, chrLength);
        writer.closeGroup(rawGroup);

        writer.closeGroup(chrGroup);

        return binnedDataList;
    }

    /**
     * @return the upper bound on the number of zoom levels computed for a
     *         chromosome
     */
    protected int getZoomMax() {
        return this.zoomMax;
    }

    /**
     * Computes the binned data for each zoom level of a chromosome.
     * <p>
     * Zoom level z has 2^z tiles of {@code tileSize} bins each.  Levels are
     * added until the zoom limit is reached (1 for {@link #CHR_ALL}, otherwise
     * {@link #getZoomMax()}), or until a level above {@code zoomMin} has a
     * mean bin count below 3.  That sparse level is still included.
     *
     * @param chr        chromosome name
     * @param chrLength  chromosome length
     * @return binned data, one entry per computed zoom level, in level order
     */
    private List<BinnedData> computeAllBins(String chr, int chrLength) {
        final double binCountCutoff = 3;
        List<BinnedData> levels = new ArrayList<BinnedData>();

        int adjustedZoomMax;
        if (chr.equals(CHR_ALL)) {
            adjustedZoomMax = 1;
        } else {
            adjustedZoomMax = getZoomMax();
        }

        int z = 0;
        while (z < adjustedZoomMax) {
            int nTiles = (int) Math.pow(2, z);
            int nBins = nTiles * this.tileSize;
            double binSize = ((double) chrLength) / nBins;

            if (binSize < 0) {
                System.out.println("Negative bin size");
            }

            List<Bin> bins = allocateBins(chrLength, nBins, binSize);
            BinnedData level = computeBinnedData(z, chrLength, nTiles, bins, binSize);
            levels.add(level);

            // Stop once the bins become sparse, but never before zoomMin.
            boolean sparse = level.getMeanCount() < binCountCutoff;
            if (sparse && (z > zoomMin)) {
                break;
            }
            z++;
        }

        return levels;
    }

    /**
     * Creates the bins for one zoom level of a chromosome.  Called by
     * computeAllBins once per zoom level.
     *
     * @param chrLengh  chromosome length
     * @param nBins     number of bins at this zoom level (tiles * tile size)
     * @param binSize   bin width, computed by the caller as chromosome length / nBins
     * @return the bins; the caller reads each bin's start and feature count
     */
    protected abstract List<Bin> allocateBins(int chrLengh, int nBins, double binSize);

    /**
     * Records a chunk index for the raw coordinates of a chromosome.
     * <p>
     * The chromosome is divided into chunks of 10000 positions.  Entry n of
     * the "index" dataset is the position at which chunk n starts, capped at
     * the last position ({@code chrLength - 1}).  The chunk size is written as
     * the "index.span" attribute.
     * <p>
     * The values are computed directly rather than by stepping through every
     * position; the result is the same as the position-by-position scan.
     *
     * @param groupId    handle of the group receiving the attribute and dataset
     * @param chrLength  number of raw positions
     */
    private void recordRawIndex(int groupId, int chrLength) {
        // Kept as a double so the "index.span" attribute is written with the same type as before.
        double chunkSize = 10000;
        int nChunks = (int) (chrLength / chunkSize) + 1;
        int[] indices = new int[nChunks];
        int lastIndex = chrLength - 1;

        for (int n = 0; n < nChunks; n++) {
            int boundary = (int) (n * chunkSize);

            // With no positions at all every entry is the (negative) max
            // location index, matching the fill-in rule for chunks without data.
            indices[n] = (chrLength > 0) ? Math.min(boundary, lastIndex) : lastIndex;
        }

        writer.writeAttribute(groupId, "index.span", chunkSize);
        writer.createAndWriteVectorDataset(groupId, "index", indices);
    }

    /**
     * Writes one vector dataset per statistic ("min", "mean", "max", "median",
     * "percentile10", "percentile90", "percentile98", "stddev") to the group.
     * Element i of each vector comes from {@code stats[i]}; a null entry is
     * recorded as NaN in every vector.
     *
     * @param stats  statistics, one element per row; elements may be null
     * @param group  handle of the group receiving the datasets
     */
    private void recordStats(DataStatistics[] stats, int group) {
        int nPts = stats.length;

        // The order must match the assignments in the loop below.
        float[][] columns = new float[8][];
        for (int c = 0; c < columns.length; c++) {
            columns[c] = new float[nPts];
            // Start from NaN; only rows with statistics are overwritten.
            Arrays.fill(columns[c], Float.NaN);
        }
        float[] min = columns[0];
        float[] mean = columns[1];
        float[] max = columns[2];
        float[] median = columns[3];
        float[] percentile10 = columns[4];
        float[] percentile90 = columns[5];
        float[] percentile98 = columns[6];
        float[] stddev = columns[7];

        for (int i = 0; i < nPts; i++) {
            DataStatistics stat = stats[i];
            if (stat == null) {
                continue;
            }
            min[i] = (float) stat.getMin();
            mean[i] = (float) stat.getMean();
            max[i] = (float) stat.getMax();
            median[i] = (float) stat.getMedian();
            percentile10[i] = (float) stat.getPercentile10();
            percentile90[i] = (float) stat.getPercentile90();
            // NOTE(review): "percentile98" is filled from getPercentile90().
            // This looks like a copy-paste slip -- confirm whether
            // DataStatistics offers a 98th-percentile accessor and use it.
            percentile98[i] = (float) stat.getPercentile90();
            stddev[i] = (float) stat.getStdDev();
        }

        writer.createAndWriteVectorDataset(group, "min", min);
        writer.createAndWriteVectorDataset(group, "mean", mean);
        writer.createAndWriteVectorDataset(group, "max", max);
        writer.createAndWriteVectorDataset(group, "median", median);
        writer.createAndWriteVectorDataset(group, "percentile10", percentile10);
        writer.createAndWriteVectorDataset(group, "percentile90", percentile90);
        writer.createAndWriteVectorDataset(group, "percentile98", percentile98);
        writer.createAndWriteVectorDataset(group, "stddev", stddev);
    }

    /**
     * Opens the group for a chromosome, creating it the first time the
     * chromosome is requested through this method.
     *
     * @param chr        chromosome name, used as the group name
     * @param dataGroup  handle of the containing data group
     * @return handle of the chromosome group
     */
    public int openChromosomeGroup(String chr, int dataGroup) {
        boolean alreadyCreated = chromosomeGroupKeys.contains(chr);
        if (!alreadyCreated) {
            int created = writer.createGroup(dataGroup, chr);
            // Remember the group only after it has been created.
            chromosomeGroupKeys.add(chr);
            return created;
        }
        return writer.openGroup(dataGroup, chr);
    }

    /**
     * Opens the group for a zoom level of a chromosome, creating it the first
     * time the (chromosome, zoom) pair is requested through this method.
     *
     * @param chr       chromosome name
     * @param zoomName  name of the zoom group
     * @param chrGroup  handle of the containing chromosome group
     * @return handle of the zoom group
     */
    public int openZoomGroup(String chr, String zoomName, int chrGroup) {
        String key = chr + zoomName;

        boolean alreadyCreated = zoomGroupKeys.contains(key);
        if (!alreadyCreated) {
            int created = writer.createGroup(chrGroup, zoomName);
            // Remember the group only after it has been created.
            zoomGroupKeys.add(key);
            return created;
        }
        return writer.openGroup(chrGroup, zoomName);
    }

    /**
     * Opens a 2-D float dataset, creating it with dimensions 1 x nCols the
     * first time the (chromosome, zoom, name) triple is requested.
     *
     * @param chr        chromosome name
     * @param zoomName   name of the zoom group
     * @param dsName     dataset name
     * @param nCols      number of columns; only used when the dataset is created
     * @param zoomGroup  handle of the containing zoom group
     * @return handle of the dataset
     */
    public int openDataset(String chr, String zoomName, String dsName, int nCols, int zoomGroup) {

        String key = chr + zoomName + dsName;

        if (!datasetCache.contains(key)) {
            // A single row of nCols values
            long[] dims = new long[]{1, nCols};
            int created = writer.createDataset(zoomGroup, dsName, HDF5Constants.H5T_NATIVE_FLOAT, dims);
            datasetCache.add(key);
            return created;
        }

        return writer.openDataset(zoomGroup, dsName);
    }

    /**
     * Opens a 1-D float dataset, creating it with a single element the first
     * time the (chromosome, zoom, name) triple is requested.
     *
     * @param chr        chromosome name
     * @param zoomName   name of the zoom group
     * @param dsName     dataset name
     * @param zoomGroup  handle of the containing zoom group
     * @return handle of the dataset
     */
    public int openVectorDataset(String chr, String zoomName, String dsName, int zoomGroup) {
        String key = chr + zoomName + dsName;

        if (!datasetCache.contains(key)) {
            // A vector with one element
            long[] dims = new long[]{1};
            int created = writer.createDataset(zoomGroup, dsName, HDF5Constants.H5T_NATIVE_FLOAT, dims);
            datasetCache.add(key);
            return created;
        }

        return writer.openDataset(zoomGroup, dsName);
    }

    /**
     * @return a new list holding the dataset's chromosome names followed by
     *         {@link #CHR_ALL}
     */
    private List<String> getAllChromosomes() {
        String[] names = dataset.getChromosomes();
        List<String> all = new ArrayList<String>(names.length + 1);
        for (String name : names) {
            all.add(name);
        }
        all.add(CHR_ALL);
        return all;
    }

    /**
     * Writes the data groups for every chromosome that the genome knows.
     * <p>
     * For each such chromosome the binned values are written per zoom level
     * (when the chromosome has data), followed by a "raw" group holding the
     * unbinned values and their statistics.  A chromosome without data still
     * gets a "raw" group, written with a null row and null statistics.
     * Progress is reported only for chromosomes that have data.
     *
     * @param dataGroup        handle of the "data" group
     * @param binInfoMap       binned data per chromosome, as produced while
     *                         processing features
     * @param estTimeFraction  fraction of the total work attributed to this
     *                         step, used for progress reporting
     * @throws java.lang.InterruptedException if the thread is interrupted
     */
    private void processData(
            int dataGroup,
            Map<String, List<BinnedData>> binInfoMap,
            double estTimeFraction) throws InterruptedException {

        List<String> allChromosomes = getAllChromosomes();

        int nSteps = allChromosomes.size();
        double procProgIncrement = (estTimeFraction * 0.8 * 100) / nSteps;
        double rawDataProgIncrement = (estTimeFraction * 0.2 * 100) / nSteps;

        // Loop through chromosomes
        for (String chr : allChromosomes) {

            checkForInterrupt();

            // Chromosomes unknown to the genome are skipped.
            if (genome.getChromosome(chr) == null) {
                continue;
            }

            int chrGroup = openChromosomeGroup(chr, dataGroup);

            // There is a single data row (one sample) per chromosome.
            float[][] allData = new float[1][];
            DataStatistics[] stats = new DataStatistics[1];

            float[] data = this.getDataForChromosome(chr);

            if (data == null || data.length == 0) {
                // Row and statistics stay null.
                log.info("No data for  chr: " + chr);
            } else {
                checkForInterrupt();
                allData[0] = data;

                processDataForChromosome(chrGroup, binInfoMap.get(chr), chr);

                if (statusMonitor != null) {
                    statusMonitor.incrementStatus(procProgIncrement);
                    statusMonitor.incrementStatus(rawDataProgIncrement);
                }

                stats[0] = ProcessingUtils.computeStats(data);
            }

            int rawGroup = openZoomGroup(chr, "raw", chrGroup);

            writer.createAndWriteDataset(rawGroup, "value", allData);
            recordStats(stats, rawGroup);

            writer.closeGroup(rawGroup);

            writer.closeGroup(chrGroup);
        }
    }

    /**
     * Gets the data value array for a chromosome.
     *
     * @param chr  chromosome name
     * @return the dataset's values for the chromosome, or {@code null} for
     *         {@link #CHR_ALL} (genome-wide data is not implemented yet)
     */
    float[] getDataForChromosome(String chr) {
        //TODO: genome-wide data (previously genomeSummaryData.getData(sample))
        return chr.equals(CHR_ALL) ? null : dataset.getData(chr);
    }

    /**
     * Writes the mean value of every bin, for each zoom level of a chromosome.
     * <p>
     * For each zoom level with at least one bin, the mean of the data in each
     * bin is computed (NaN when the bin has no data) and recorded in the
     * level's zoom group.
     *
     * @param chrGroup        handle of the chromosome group
     * @param binnedDataList  binned data of the chromosome, one entry per zoom level
     * @param chr             name of the chromosome being processed
     * @throws InterruptedException if the thread is interrupted
     */
    private void processDataForChromosome(int chrGroup,
            List<BinnedData> binnedDataList, String chr) throws InterruptedException {

        // The chromosome's data does not depend on the zoom level; fetch it once.
        float[] data = getDataForChromosome(chr);

        // Loop through zoom levels
        for (BinnedData binnedData : binnedDataList) {

            checkForInterrupt();

            List<? extends Bin> bins = binnedData.getBins();
            if (bins.size() == 0) {
                continue;
            }

            // Mean value, 1 element per bin.
            float[] mean = new float[bins.size()];
            for (int b = 0; b < bins.size(); b++) {
                float[] binData = getDataForBin(data, bins.get(b));
                mean[b] = (binData == null) ? Float.NaN : computeMean(binData);
            }

            String zoomName = "z" + binnedData.getZoomLevel();
            int zoomGroup = openZoomGroup(chr, zoomName, chrGroup);

            recordStats("mean", mean, zoomGroup, chr, zoomName);
            writer.closeGroup(zoomGroup);
        }
    }
    
    /**
     * Computes the arithmetic mean with single-precision accumulation.
     *
     * @param data  values to average
     * @return sum of the values divided by their number (NaN for an empty array)
     */
    private float computeMean(float [] data) {
        float total = 0;
        for (float value : data) {
            total += value;
        }
        return total / data.length;
    }

    /**
     * Records one statistic of a zoom level: the per-bin values as row 0 of
     * the dataset named {@code type}, and their median as the single value of
     * the dataset named {@code "median." + type}.
     *
     * @param type       name of the statistic, used as the dataset name
     * @param data       per-bin values
     * @param zoomGroup  handle of the zoom group
     * @param chr        chromosome name
     * @param zoomName   name of the zoom group
     */
    protected void recordStats(String type, float[] data, int zoomGroup, String chr, String zoomName) {
        // Per-bin values
        int valuesId = openDataset(chr, zoomName, type, data.length, zoomGroup);
        writer.writeDataRow(valuesId, 0, data, data.length);
        writer.closeDataset(valuesId);

        // Median over all bins
        float median = (float) ProcessingUtils.computeMedian(data);
        String medianName = "median." + type;
        int medianId = openVectorDataset(chr, zoomName, medianName, zoomGroup);
        writer.writeDataValue(medianId, 0, median);
        writer.closeDataset(medianId);
    }

    /**
     * Selects the data values that belong to a bin.  Called once per bin while
     * the per-bin means of a zoom level are computed.
     *
     * @param data  data values of the chromosome being processed
     * @param bin   the bin to select values for
     * @return the values of the bin; a {@code null} result is recorded by the
     *         caller as a NaN mean
     */
    protected abstract float[] getDataForBin(float[] data, Bin bin);

    /**
     * Builds the binned data of one zoom level: the tile boundaries and the
     * count statistics (mean, max, 10th / 50th / 90th percentile) of the bins.
     *
     * @param zoomLevel  zoom level of the bins
     * @param chrLength  chromosome length
     * @param nTiles     number of tiles at this zoom level
     * @param bins       the bins, assumed ordered by start -- TODO confirm
     *                   with the allocateBins implementations
     * @param binSize    bin width
     * @return the binned data with boundaries and statistics set
     */
    private BinnedData computeBinnedData(int zoomLevel, double chrLength, 
            int nTiles, List<Bin> bins, double binSize) {

        // Tile breaks: entry t is the index of the first bin starting at or
        // beyond the end of tile t.  Linear search; tiles that are not reached
        // before the bins run out keep the default 0.
        int[] tileBoundaries = new int[nTiles];
        double tileLength = chrLength / nTiles;
        int binNumber = 0;
        int tileNumber = 0;

        while ((tileNumber < nTiles - 1) && (binNumber < bins.size())) {
            double tileEnd = (tileNumber + 1) * tileLength;
            while ((binNumber < bins.size()) && (bins.get(binNumber).getStart() < tileEnd)) {
                binNumber++;
            }
            tileBoundaries[tileNumber] = binNumber;
            tileNumber++;
        }

        // Boundary for last tile number is end
        tileBoundaries[nTiles - 1] = bins.size() - 1;

        BinnedData result = new BinnedData(zoomLevel, binSize, bins, tileBoundaries);

        // Requested quantiles: 10th, 50th and 90th percentile of the bin counts
        DoubleArrayList percentiles = new DoubleArrayList(3);
        percentiles.add(0.1);
        percentiles.add(0.5);
        percentiles.add(0.90);

        DoubleQuantileFinder finder = QuantileFinderFactory.newDoubleQuantileFinder(
                false, Long.MAX_VALUE, 0.0010, 1.0E-4,
                percentiles.size(), null);

        // Accumulate the total and maximum count over all bins
        float total = 0.0F;
        float largest = 0.0F;
        for (Bin bin : bins) {
            int count = bin.getFeatureCount();
            total += count;
            largest = Math.max(largest, count);
            finder.add(count);
        }

        result.setMeanCount(total / bins.size());
        result.setMaxCount(largest);

        DoubleArrayList quantiles = finder.quantileElements(percentiles);
        result.setPercentile10(quantiles.get(0));
        result.setMedianCount(quantiles.get(1));
        result.setPercentile90(quantiles.get(2));

        return result;
    }

    /**
     * Sets the upper bound on the number of zoom levels computed for a
     * chromosome.
     *
     * @param zoomMax  the new bound
     */
    public void setZoomMax(final int zoomMax) {
        this.zoomMax = zoomMax;
    }
}
