/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.gatk.utils;

import cern.jet.math.Arithmetic;
import cern.jet.random.Normal;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import java.io.Serializable;
import java.util.Comparator;
import java.util.TreeSet;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.NormalDistribution;
import org.apache.commons.math.distribution.NormalDistributionImpl;
import org.broadinstitute.gatk.engine.GenomeAnalysisEngine;
import org.broadinstitute.gatk.utils.collections.Pair;
import org.broadinstitute.gatk.utils.exceptions.GATKException;

public class MannWhitneyU {
    /** Colt normal distribution constructed with mean 0 and standard deviation 1; its cdf backs the normal approximation. */
    private static final Normal STANDARD_NORMAL = new Normal(0.0, 1.0, null);
    /**
     * Commons-math normal distribution constructed with mean 0 and standard deviation 1, used to invert
     * cumulative probabilities into z-values.
     * NOTE(review): the third argument (0.01) is presumably the inverse-cumulative accuracy - confirm against commons-math.
     */
    private static final NormalDistribution APACHE_NORMAL = new NormalDistributionImpl(0.0, 1.0, 0.01);
    /** Natural log of sqrt(2*pi). */
    private static final double LNSQRT2PI = Math.log(Math.sqrt(Math.PI * 2));
    /** All observations added so far, ordered by the comparator chosen in the constructor. */
    private final TreeSet<Pair<Number, USet>> observations;
    /** Number of {@code add} calls made with {@link USet#SET1}. */
    private int sizeSet1;
    /** Number of {@code add} calls made with any label other than {@link USet#SET1}. */
    private int sizeSet2;
    /** How exact p-values are computed when the recursive calculation is used. */
    private final ExactMode exactMode;

    public MannWhitneyU(ExactMode mode, boolean dither) {
        this.observations = dither ? new TreeSet<Pair<Number, USet>>(new DitheringComparator()) : new TreeSet<Pair<Number, USet>>(new NumberedPairComparator());
        this.sizeSet1 = 0;
        this.sizeSet2 = 0;
        this.exactMode = mode;
    }

    /** Creates an empty test using {@link ExactMode#POINT} with dithering enabled. */
    public MannWhitneyU() {
        this(ExactMode.POINT, true);
    }

    /**
     * Creates an empty test using {@link ExactMode#POINT}.
     *
     * @param dither whether equal-valued observations are ordered randomly
     */
    public MannWhitneyU(boolean dither) {
        this(ExactMode.POINT, dither);
    }

    /**
     * Creates an empty test with dithering enabled.
     *
     * @param mode how exact p-values are derived when the recursive calculation is used
     */
    public MannWhitneyU(ExactMode mode) {
        this(mode, true);
    }

    /**
     * Records one observation and bumps the size counter of the set it belongs to.
     * Any label other than {@link USet#SET1} is counted towards set 2.
     *
     * NOTE(review): the counter is incremented regardless of whether the TreeSet accepted the element.
     * With the non-dithering comparator two equal values compare as 0, so the second one would be
     * rejected by the set while still being counted - confirm this is intended.
     *
     * @param n2  the observed value
     * @param set the set the value was drawn from
     */
    public void add(Number n2, USet set) {
        final Pair<Number, USet> observation = new Pair<Number, USet>(n2, set);
        this.observations.add(observation);
        if (set != USet.SET1) {
            this.sizeSet2++;
            return;
        }
        this.sizeSet1++;
    }

    /**
     * Computes the rank sums of both sets from the U statistic.
     *
     * U1 is obtained from {@code calculateOneSidedU(observations, SET1)}, which counts, for every
     * SET1 observation, the non-SET1 observations that precede it in sorted order. The remaining
     * quantities follow from the identities U1 + U2 = n1 * n2 and R = U + n(n+1)/2, where n1 and n2
     * are the set sizes.
     *
     * @return the pair (R1, R2)
     */
    public Pair<Long, Long> getR1R2() {
        // Widen to long first so neither n*(n+1) nor n1*n2 can overflow int arithmetic.
        final long n1 = this.sizeSet1;
        final long n2 = this.sizeSet2;
        final long u1 = calculateOneSidedU(this.observations, USet.SET1);
        // U2 is the complement of U1 with respect to the number of cross-set pairs (n1 * n2),
        // not with respect to the product of the triangular numbers.
        final long u2 = n1 * n2 - u1;
        final long r1 = u1 + n1 * (n1 + 1L) / 2L;
        final long r2 = u2 + n2 * (n2 + 1L) / 2L;
        return new Pair<Long, Long>(r1, r2);
    }

    /**
     * Runs the one-sided test, passing {@code lessThanOther} as the dominator to
     * {@code calculateOneSidedU}.
     *
     * @param lessThanOther the set used as the dominator for the U statistic
     * @return the pair produced by {@code calculateP} for the (n, m, u) of this test, or (NaN, NaN)
     *         when either set has no recorded observations
     */
    @Requires(value={"lessThanOther != null"})
    @Ensures(value={"validateObservations(observations) || Double.isNaN(result.getFirst())", "result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    public Pair<Double, Double> runOneSidedTest(USet lessThanOther) {
        final long uStatistic = calculateOneSidedU(this.observations, lessThanOther);
        final boolean firstSetIsTarget = lessThanOther == USet.SET1;
        final int targetSize = firstSetIsTarget ? this.sizeSet1 : this.sizeSet2;
        final int otherSize = firstSetIsTarget ? this.sizeSet2 : this.sizeSet1;
        if (targetSize == 0 || otherSize == 0) {
            // Nothing to compare against: report "undefined" rather than a misleading number.
            return new Pair<Double, Double>(Double.NaN, Double.NaN);
        }
        return calculateP(targetSize, otherSize, uStatistic, false, this.exactMode);
    }

    /**
     * Runs the two-sided test using the U value and set label chosen by {@code calculateTwoSidedU}.
     *
     * @return the pair produced by {@code calculateP} with {@code twoSided = true}, or (NaN, NaN)
     *         when either set has no recorded observations
     */
    @Ensures(value={"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    public Pair<Double, Double> runTwoSidedTest() {
        final Pair<Long, USet> chosen = calculateTwoSidedU(this.observations);
        final long uStatistic = (Long)chosen.first;
        final boolean firstSetChosen = chosen.second == USet.SET1;
        final int chosenSize = firstSetChosen ? this.sizeSet1 : this.sizeSet2;
        final int otherSize = firstSetChosen ? this.sizeSet2 : this.sizeSet1;
        if (chosenSize == 0 || otherSize == 0) {
            // Nothing to compare against: report "undefined" rather than a misleading number.
            return new Pair<Double, Double>(Double.NaN, Double.NaN);
        }
        return calculateP(chosenSize, otherSize, uStatistic, true, this.exactMode);
    }

    /**
     * Picks the p-value calculation based on the two set sizes:
     * <ul>
     *   <li>n &gt; 8 and m &gt; 8, or n &gt; 5 and m &gt; 7: normal approximation;</li>
     *   <li>otherwise, if either size exceeds 8: {@code calculatePFromTable};</li>
     *   <li>otherwise: the recursive calculation, using {@code exactMode}.</li>
     * </ul>
     *
     * NOTE(review): the contract strings refer to {@code m} and {@code n}, while the decompiled
     * parameters are named {@code m2} and {@code n2}.
     *
     * @param n2        size of the first set
     * @param m2        size of the second set
     * @param u2        the U statistic
     * @param twoSided  whether a two-sided p-value is requested
     * @param exactMode mode forwarded to the recursive calculation
     * @return the pair returned by the selected calculation
     */
    @Requires(value={"m > 0", "n > 0"})
    @Ensures(value={"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    protected static Pair<Double, Double> calculateP(int n2, int m2, long u2, boolean twoSided, ExactMode exactMode) {
        final boolean bothLarge = n2 > 8 && m2 > 8;
        final boolean largeEnough = n2 > 5 && m2 > 7;
        if (bothLarge || largeEnough) {
            return calculatePNormalApproximation(n2, m2, u2, twoSided);
        }
        if (n2 > 8 || m2 > 8) {
            return calculatePFromTable(n2, m2, u2, twoSided);
        }
        return calculatePRecursively(n2, m2, u2, twoSided, exactMode);
    }

    /**
     * Despite its name, performs no table lookup: it forwards directly to the normal approximation.
     *
     * @param n2       size of the first set
     * @param m2       size of the second set
     * @param u2       the U statistic
     * @param twoSided whether a two-sided p-value is requested
     * @return the result of {@code calculatePNormalApproximation} for the same arguments
     */
    public static Pair<Double, Double> calculatePFromTable(int n2, int m2, long u2, boolean twoSided) {
        final Pair<Double, Double> approximation = calculatePNormalApproximation(n2, m2, u2, twoSided);
        return approximation;
    }

    /**
     * Computes a z-value via {@code getZApprox} and converts it to a p-value with the standard
     * normal cdf.
     *
     * One-sided: p = cdf(z). Two-sided: p = 2 * cdf(z) when z is negative, otherwise 2 * (1 - cdf(z)).
     *
     * @param n2       size of the first set
     * @param m2       size of the second set
     * @param u2       the U statistic
     * @param twoSided whether a two-sided p-value is requested
     * @return the pair (z, p)
     */
    @Requires(value={"m > 0", "n > 0"})
    @Ensures(value={"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    public static Pair<Double, Double> calculatePNormalApproximation(int n2, int m2, long u2, boolean twoSided) {
        final double zScore = getZApprox(n2, m2, u2);
        if (!twoSided) {
            return new Pair<Double, Double>(zScore, STANDARD_NORMAL.cdf(zScore));
        }
        final double tail;
        if (zScore < 0.0) {
            tail = STANDARD_NORMAL.cdf(zScore);
        } else {
            tail = 1.0 - STANDARD_NORMAL.cdf(zScore);
        }
        return new Pair<Double, Double>(zScore, 2.0 * tail);
    }

    /**
     * Standardises U: z = (u - mean) / sqrt(var), with mean = (n*m + 1) / 2 and
     * var = n*m*(n + m + 1) / 12.
     *
     * NOTE(review): the "+ 1" in the mean shifts u by half a unit, which looks like a continuity
     * correction - confirm the intent.
     *
     * @param n2 size of the first set
     * @param m2 size of the second set
     * @param u2 the U statistic
     * @return the z-value
     */
    @Requires(value={"m > 0", "n > 0"})
    @Ensures(value={"! Double.isNaN(result)", "! Double.isInfinite(result)"})
    private static double getZApprox(int n2, int m2, long u2) {
        // The product is formed in long arithmetic so it cannot overflow int.
        final long pairCount = (long)n2 * (long)m2;
        final double expected = ((double)pairCount + 1.0) / 2.0;
        final double variance = (double)pairCount * ((double)(n2 + m2) + 1.0) / 12.0;
        return ((double)u2 - expected) / Math.sqrt(variance);
    }

    /**
     * Approximates a p-value for U through a transformed rank sum that is evaluated with
     * {@code uniformSumHelper}.
     *
     * With R = u + n(n+1)/2, a = sqrt(m(n+m+1)) and b = (n/2)(1 - sqrt((n+m+1)/m)), the value
     * z = b + R/a is computed. The result is 1 when z &lt; 0, 0 when z &gt; n, and otherwise
     * s = (1/n!) * uniformSumHelper(z, floor(z), n, 0), returned as 1 - s when z &gt; n/2 and as s
     * otherwise.
     *
     * All intermediate quantities are computed in long/double arithmetic: the ratio (n+m+1)/m must
     * be a real-valued division (integer division truncates it, e.g. 15/10 becomes 1), and the
     * products n(n+1) and m(n+m+1) must not wrap around in int arithmetic.
     *
     * @param n2 size of the first set
     * @param m2 size of the second set; must be non-zero
     * @param u2 the U statistic
     * @return the approximate p-value
     * @throws ArithmeticException if {@code m2} is zero (the ratio (n+m+1)/m is undefined)
     */
    public static double calculatePUniformApproximation(int n2, int m2, long u2) {
        if (m2 == 0) {
            // Keeps the failure mode the former integer division had for m == 0, instead of
            // silently propagating NaN through the floating-point formula.
            throw new ArithmeticException("/ by zero: m must be non-zero for the uniform approximation");
        }
        final long rankSum = u2 + (long)n2 * ((long)n2 + 1L) / 2L;
        final double totalPlusOne = (double)n2 + (double)m2 + 1.0;
        final double a2 = Math.sqrt((double)m2 * totalPlusOne);
        final double b2 = (double)n2 / 2.0 * (1.0 - Math.sqrt(totalPlusOne / (double)m2));
        final double z = b2 + (double)rankSum / a2;
        if (z < 0.0) {
            return 1.0;
        }
        if (z > (double)n2) {
            return 0.0;
        }
        final double scaledSum = 1.0 / Arithmetic.factorial((int)n2) * uniformSumHelper(z, (int)Math.floor(z), n2, 0);
        if (z > (double)n2 / 2.0) {
            return 1.0 - scaledSum;
        }
        return scaledSum;
    }

    /**
     * Evaluates the alternating sum over k = k2..m2 of (-1)^k * binomial(n2, k) * (z - k)^n2.
     * Returns 0 when {@code k2 > m2}.
     *
     * The terms are accumulated from k = m2 down to k = k2, so each term is added to the sum of
     * all higher-index terms.
     *
     * @param z  the point at which the sum is evaluated
     * @param m2 the last index of the sum
     * @param n2 the binomial upper argument and the exponent
     * @param k2 the first index of the sum
     * @return the value of the sum
     */
    private static double uniformSumHelper(double z, int m2, int n2, int k2) {
        double tailSum = 0.0;
        for (int k = m2; k >= k2; --k) {
            final int sign = k % 2 == 0 ? 1 : -1;
            tailSum = (double)sign * Arithmetic.binomial((long)n2, (long)k) * Math.pow(z - (double)k, n2) + tailSum;
        }
        return tailSum;
    }

    /**
     * Walks the observations in sorted order and accumulates two totals:
     * for each SET1 point, the running count of non-SET1 points seen so far; and for each
     * non-SET1 point, the running count of SET1 points seen so far. No accumulation happens
     * while the previously visited label is {@code null} (which is the case for the first point).
     *
     * The smaller total is returned. It is labelled SET1 when the total accumulated at non-SET1
     * points is strictly smaller, and SET2 otherwise (including ties).
     *
     * NOTE(review): the contract requires {@code result.first > 0}, yet both totals start at 0
     * and can stay there - confirm whether {@code >= 0} was intended.
     *
     * @param observed the sorted observations
     * @return the smaller total and its set label
     */
    @Requires(value={"observed != null", "observed.size() > 0"})
    @Ensures(value={"result != null", "result.first > 0"})
    public static Pair<Long, USet> calculateTwoSidedU(TreeSet<Pair<Number, USet>> observed) {
        int firstCount = 0;
        int secondCount = 0;
        long totalAtSecond = 0L;
        long totalAtFirst = 0L;
        USet lastLabel = null;
        for (final Pair<Number, USet> point : observed) {
            final boolean inFirstSet = point.second == USet.SET1;
            if (inFirstSet) {
                firstCount++;
            } else {
                secondCount++;
            }
            if (lastLabel != null) {
                if (inFirstSet) {
                    totalAtFirst += (long)secondCount;
                } else {
                    totalAtSecond += (long)firstCount;
                }
            }
            lastLabel = (USet)point.second;
        }
        if (totalAtSecond < totalAtFirst) {
            return new Pair<Long, USet>(totalAtSecond, USet.SET1);
        }
        return new Pair<Long, USet>(totalAtFirst, USet.SET2);
    }

    /**
     * Counts, over the observations in sorted order, how many points not labelled
     * {@code dominator} precede each point that is labelled {@code dominator}, and sums
     * those counts.
     *
     * @param observed  the sorted observations
     * @param dominator the label whose points collect the counts
     * @return the summed count
     */
    @Requires(value={"observed != null", "dominator != null", "observed.size() > 0"})
    @Ensures(value={"result >= 0"})
    public static long calculateOneSidedU(TreeSet<Pair<Number, USet>> observed, USet dominator) {
        long precedingTotal = 0L;
        int othersSoFar = 0;
        for (final Pair<Number, USet> point : observed) {
            if (point.second == dominator) {
                precedingTotal += (long)othersSoFar;
            } else {
                othersSoFar++;
            }
        }
        return precedingTotal;
    }

    /**
     * Computes a p-value with the recursive calculation and derives a matching z-value.
     *
     * The probability is {@code cpr(n, m, u)} in POINT mode and {@code cumulativeCPR(n, m, u)}
     * otherwise. The z-value is obtained as follows:
     * <ul>
     *   <li>CUMULATIVE mode: the inverse cumulative probability of p under APACHE_NORMAL;</li>
     *   <li>otherwise: with sd = sqrt((1 + 1/(1+n+m)) * n*m * (1+n+m) / 12), z is 0 when
     *       p exceeds 1/sqrt(2*pi*sd^2); else |z| = sqrt(-2 * (ln sd + ln p + ln sqrt(2*pi))),
     *       i.e. the solution of p = exp(-z^2/2) / (sd * sqrt(2*pi)), taken positive when
     *       u &gt;= n*m/2 (integer division) and negative otherwise.</li>
     * </ul>
     *
     * @param n2       size of the first set
     * @param m2       size of the second set
     * @param u2       the U statistic
     * @param twoSided when true the returned p-value is doubled
     * @param mode     selects point or cumulative probability
     * @return the pair (z, p)
     * @throws GATKException if m &gt; 8 and n &gt; 5, or if inverting the probability fails
     */
    @Requires(value={"m > 0", "n > 0", "u >= 0"})
    @Ensures(value={"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    public static Pair<Double, Double> calculatePRecursively(int n2, int m2, long u2, boolean twoSided, ExactMode mode) {
        if (m2 > 8 && n2 > 5) {
            throw new GATKException(String.format("Please use the appropriate (normal or sum of uniform) approximation. Values n: %d, m: %d", n2, m2));
        }

        final double probability;
        if (mode == ExactMode.POINT) {
            probability = cpr(n2, m2, u2);
        } else {
            probability = cumulativeCPR(n2, m2, u2);
        }

        double zScore;
        if (mode == ExactMode.CUMULATIVE) {
            try {
                zScore = APACHE_NORMAL.inverseCumulativeProbability(probability);
            }
            catch (MathException me) {
                throw new GATKException("A math exception occurred in inverting the probability", me);
            }
        } else {
            final double sd = Math.sqrt((1.0 + 1.0 / (double)(1 + n2 + m2)) * (double)(n2 * m2) * (1.0 + (double)n2 + (double)m2) / 12.0);
            final double peakDensity = 1.0 / Math.sqrt(sd * sd * 2.0 * Math.PI);
            if (probability > peakDensity) {
                // p lies above the maximum of the fitted density, so no z satisfies the equation.
                zScore = 0.0;
            } else {
                final double magnitude = Math.sqrt(-2.0 * (Math.log(sd) + Math.log(probability) + LNSQRT2PI));
                zScore = u2 >= (long)(n2 * m2 / 2) ? magnitude : -magnitude;
            }
        }

        final double reported = twoSided ? 2.0 * probability : probability;
        return new Pair<Double, Double>(zScore, reported);
    }

    /**
     * Calls {@code cpr} directly, bypassing the size check that {@code calculatePRecursively}
     * performs.
     *
     * @param n2 size of the first set
     * @param m2 size of the second set
     * @param u2 the U statistic
     * @return {@code cpr(n2, m2, u2)}
     */
    protected static double calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(int n2, int m2, long u2) {
        final double pointProbability = cpr(n2, m2, u2);
        return pointProbability;
    }

    /**
     * Recursive count defined by
     * count(n, m, u) = count(n - 1, m, u - m) + count(n, m - 1, u),
     * with count = 0 for negative u, and, when n or m is 0, count = 1 if u is 0 and 0 otherwise.
     *
     * @param n2 size of the first set
     * @param m2 size of the second set
     * @param u2 the U value
     * @return the count for (n2, m2, u2)
     */
    protected static long countSequences(int n2, int m2, long u2) {
        if (u2 < 0L) {
            return 0L;
        }
        final boolean oneSetEmpty = m2 == 0 || n2 == 0;
        if (oneSetEmpty) {
            if (u2 == 0L) {
                return 1L;
            }
            return 0L;
        }
        final long withoutOneN = countSequences(n2 - 1, m2, u2 - (long)m2);
        final long withoutOneM = countSequences(n2, m2 - 1, u2);
        return withoutOneN + withoutOneM;
    }

    /**
     * Recursive probability defined by
     * p(n, m, u) = n/(n+m) * p(n - 1, m, u - m) + m/(n+m) * p(n, m - 1, u),
     * with p = 0 for negative u, and, when n or m is 0, p = 1 if u is 0 and 0 otherwise.
     *
     * @param n2 size of the first set
     * @param m2 size of the second set
     * @param u2 the U value
     * @return the probability for (n2, m2, u2)
     */
    private static double cpr(int n2, int m2, long u2) {
        if (u2 < 0L) {
            return 0.0;
        }
        if (m2 == 0 || n2 == 0) {
            if (u2 == 0L) {
                return 1.0;
            }
            return 0.0;
        }
        final double total = (double)(n2 + m2);
        final double viaSmallerN = (double)n2 / total * cpr(n2 - 1, m2, u2 - (long)m2);
        final double viaSmallerM = (double)m2 / total * cpr(n2, m2 - 1, u2);
        return viaSmallerN + viaSmallerM;
    }

    /**
     * Sums {@code cpr(n, m, k)} for k from 0 up to, but excluding, a limit.
     *
     * When u &lt;= n*m/2 (integer division) the limit is u and the sum is returned as is.
     * Otherwise the limit is n*m - u and 1 minus the sum is returned, which keeps the number
     * of summed terms small.
     *
     * @param n2 size of the first set
     * @param m2 size of the second set
     * @param u2 the U value
     * @return the cumulative probability as described above
     */
    private static double cumulativeCPR(int n2, int m2, long u2) {
        final boolean inLowerHalf = u2 <= (long)(n2 * m2 / 2);
        final long limit = inLowerHalf ? u2 : (long)n2 * (long)m2 - u2;
        double accumulated = 0.0;
        long k = 0L;
        while (k < limit) {
            accumulated += cpr(n2, m2, k);
            k++;
        }
        if (inLowerHalf) {
            return accumulated;
        }
        return 1.0 - accumulated;
    }

    /**
     * Exposes the internal observation set itself (not a copy).
     *
     * @return the live set of observations
     */
    protected TreeSet<Pair<Number, USet>> getObservations() {
        final TreeSet<Pair<Number, USet>> live = this.observations;
        return live;
    }

    /**
     * Reports the two size counters maintained by {@code add}.
     *
     * @return the pair (sizeSet1, sizeSet2)
     */
    protected Pair<Integer, Integer> getSetSizes() {
        final Integer first = Integer.valueOf(this.sizeSet1);
        final Integer second = Integer.valueOf(this.sizeSet2);
        return new Pair<Integer, Integer>(first, second);
    }

    /**
     * Checks that the observations contain at least one SET1 point and at least one SET2 point,
     * and that no value is NaN or infinite.
     *
     * @param tree the observations to check
     * @return true only if both sets are represented and every value is finite
     */
    protected static boolean validateObservations(TreeSet<Pair<Number, USet>> tree) {
        boolean hasFirst = false;
        boolean hasSecond = false;
        boolean hasNonFinite = false;
        for (final Pair<Number, USet> entry : tree) {
            if (entry.getSecond() == USet.SET1) {
                hasFirst = true;
            }
            if (entry.getSecond() == USet.SET2) {
                hasSecond = true;
            }
            final double value = entry.getFirst().doubleValue();
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                hasNonFinite = true;
            }
        }
        return !hasNonFinite && hasFirst && hasSecond;
    }

    /** Selects how calculatePRecursively obtains its probability and z-value. */
    public static enum ExactMode {
        // Probability comes from cpr(n, m, u); z is derived from the density-style formula.
        POINT,
        // Probability comes from cumulativeCPR(n, m, u); z is the inverse cumulative probability.
        CUMULATIVE;

    }

    /** Label identifying which of the two samples an observation belongs to. */
    public static enum USet {
        // Observations counted by sizeSet1.
        SET1,
        // Observations counted by sizeSet2.
        SET2;

    }

    /**
     * Orders observations by the double value of their number only; the set label is ignored,
     * so two observations with the same value compare as equal.
     */
    private static class NumberedPairComparator
    implements Comparator<Pair<Number, USet>>,
    Serializable {
        /** Never considers this comparator equal to anything, including itself. */
        @Override
        public boolean equals(Object other) {
            return false;
        }

        /** Compares the two values with {@link Double#compare(double, double)}. */
        @Override
        public int compare(Pair<Number, USet> left, Pair<Number, USet> right) {
            final double leftValue = ((Number)left.first).doubleValue();
            final double rightValue = ((Number)right.first).doubleValue();
            return Double.compare(leftValue, rightValue);
        }
    }

    /**
     * Orders observations by the double value of their number and breaks ties at random using
     * the engine's random generator. It never returns 0, so equal values are never treated as
     * duplicates, and repeated comparisons of the same tied pair may disagree.
     */
    private static class DitheringComparator
    implements Comparator<Pair<Number, USet>>,
    Serializable {
        /** Never considers this comparator equal to anything, including itself. */
        @Override
        public boolean equals(Object other) {
            return false;
        }

        /** Returns 1 or -1 by value; on a tie, returns -1 or 1 according to a random boolean. */
        @Override
        public int compare(Pair<Number, USet> left, Pair<Number, USet> right) {
            final int byValue = Double.compare(((Number)left.first).doubleValue(), ((Number)right.first).doubleValue());
            if (byValue != 0) {
                return byValue > 0 ? 1 : -1;
            }
            final boolean leftFirst = GenomeAnalysisEngine.getRandomGenerator().nextBoolean();
            return leftFirst ? -1 : 1;
        }
    }
}

