/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.statistics.inference;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.DoublePredicate;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import org.apache.commons.numbers.combinatorics.LogBinomialCoefficient;
import org.apache.commons.statistics.inference.AlternativeHypothesis;
import org.apache.commons.statistics.inference.Arguments;
import org.apache.commons.statistics.inference.BaseSignificanceResult;
import org.apache.commons.statistics.inference.BracketFinder;
import org.apache.commons.statistics.inference.BrentOptimizer;
import org.apache.commons.statistics.inference.Hypergeom;
import org.apache.commons.statistics.inference.InferenceException;
import org.apache.commons.statistics.inference.Searches;

public final class UnconditionedExactTest {
    private static final UnconditionedExactTest DEFAULT = new UnconditionedExactTest(AlternativeHypothesis.TWO_SIDED, Method.BOSCHLOO, 33, true);
    private static final double LOWER_BOUND = 1.0E-5;
    private static final double SOLVER_RELATIVE_EPS = 1.4901161193847656E-8;
    private static final double INC_FRACTION = 0.125;
    private static final int MAX_CANDIDATES = 3;
    private static final double MINIMA_EPS = 0.02;
    private static final int MAX_TABLES = 0x7FFFFFF7;
    private static final String COLUMN_SUM = "Column sum";
    private final AlternativeHypothesis alternative;
    private final Method method;
    private final int points;
    private final boolean optimize;

    private UnconditionedExactTest(AlternativeHypothesis alternative, Method method, int points, boolean optimize) {
        this.alternative = alternative;
        this.method = method;
        this.points = points;
        this.optimize = optimize;
    }

    public static UnconditionedExactTest withDefaults() {
        return DEFAULT;
    }

    public UnconditionedExactTest with(AlternativeHypothesis v) {
        return new UnconditionedExactTest(Objects.requireNonNull(v), this.method, this.points, this.optimize);
    }

    public UnconditionedExactTest with(Method v) {
        return new UnconditionedExactTest(this.alternative, Objects.requireNonNull(v), this.points, this.optimize);
    }

    public UnconditionedExactTest withInitialPoints(int v) {
        if (v <= 1) {
            throw new InferenceException("%s < %s", v, 2);
        }
        return new UnconditionedExactTest(this.alternative, this.method, v, this.optimize);
    }

    public UnconditionedExactTest withOptimize(boolean v) {
        return new UnconditionedExactTest(this.alternative, this.method, this.points, v);
    }

    public double statistic(int[][] table) {
        UnconditionedExactTest.checkTable(table);
        int a = table[0][0];
        int b = table[0][1];
        int c = table[1][0];
        int d = table[1][1];
        int m = a + c;
        int n = b + d;
        switch (this.method) {
            case Z_POOLED: {
                return UnconditionedExactTest.statisticZ(a, b, m, n, true);
            }
            case Z_UNPOOLED: {
                return UnconditionedExactTest.statisticZ(a, b, m, n, false);
            }
            case BOSCHLOO: {
                return this.statisticBoschloo(a, b, m, n);
            }
        }
        throw new IllegalStateException(String.valueOf((Object)this.method));
    }

    public Result test(int[][] table) {
        UnconditionedExactTest.checkTable(table);
        int a = table[0][0];
        int b = table[0][1];
        int c = table[1][0];
        int d = table[1][1];
        int m = a + c;
        int n = b + d;
        XYList tableList = new XYList(m, n);
        double statistic = this.findExtremeTables(a, b, tableList);
        if (tableList.isEmpty() || tableList.isFull()) {
            return new Result(statistic);
        }
        double[] opt = this.computePValue(tableList);
        return new Result(statistic, opt[0], opt[1]);
    }

    private double findExtremeTables(int a, int b, XYList tableList) {
        int m = tableList.getMaxX();
        int n = tableList.getMaxY();
        switch (this.method) {
            case Z_POOLED: {
                return this.findExtremeTablesZ(a, b, m, n, true, tableList);
            }
            case Z_UNPOOLED: {
                return this.findExtremeTablesZ(a, b, m, n, false, tableList);
            }
            case BOSCHLOO: {
                return this.findExtremeTablesBoschloo(a, b, m, n, tableList);
            }
        }
        throw new IllegalStateException(String.valueOf((Object)this.method));
    }

    private static double statisticZ(int a, int b, int m, int n, boolean pooled) {
        double p0 = (double)a / (double)m;
        double p1 = (double)b / (double)n;
        if (p0 != p1) {
            double variance;
            if (pooled) {
                double p = (double)(a + b) / (double)(m + n);
                variance = p * (1.0 - p) * (1.0 / (double)m + 1.0 / (double)n);
            } else {
                variance = p0 * (1.0 - p0) / (double)m + p1 * (1.0 - p1) / (double)n;
            }
            return (p0 - p1) / Math.sqrt(variance);
        }
        return 0.0;
    }

    private double findExtremeTablesZ(int a, int b, int m, int n, boolean pooled, XYList tableList) {
        DoublePredicate test;
        double statistic = UnconditionedExactTest.statisticZ(a, b, m, n, pooled);
        if (this.alternative == AlternativeHypothesis.GREATER_THAN) {
            test = z -> z >= statistic;
        } else if (this.alternative == AlternativeHypothesis.LESS_THAN) {
            test = z -> z <= statistic;
        } else {
            if (statistic == 0.0) {
                return 0.0;
            }
            double za = Math.abs(statistic);
            test = z -> Math.abs(z) >= za;
        }
        double mn = (double)m + (double)n;
        double norm = 1.0 / (double)m + 1.0 / (double)n;
        for (int i = 0; i <= m; ++i) {
            double p0 = (double)i / (double)m;
            double vp0 = p0 * (1.0 - p0) / (double)m;
            for (int j = 0; j <= n; ++j) {
                double z2;
                double p1 = (double)j / (double)n;
                if (p0 == p1) {
                    z2 = 0.0;
                } else {
                    double variance;
                    if (pooled) {
                        double p = (double)(i + j) / mn;
                        variance = p * (1.0 - p) * norm;
                    } else {
                        variance = vp0 + p1 * (1.0 - p1) / (double)n;
                    }
                    z2 = (p0 - p1) / Math.sqrt(variance);
                }
                if (!test.test(z2)) continue;
                tableList.add(i, j);
            }
        }
        return statistic;
    }

    private double statisticBoschloo(int a, int b, int m, int n) {
        int nn = m + n;
        int k = a + b;
        Hypergeom dist = new Hypergeom(nn, k, m);
        if (this.alternative == AlternativeHypothesis.GREATER_THAN) {
            return dist.sf(a - 1);
        }
        if (this.alternative == AlternativeHypothesis.LESS_THAN) {
            return dist.cdf(a);
        }
        return UnconditionedExactTest.statisticBoschlooTwoSided(dist, a);
    }

    private static double statisticBoschlooTwoSided(Hypergeom distribution, int k) {
        double pk = distribution.pmf(k);
        int m1 = distribution.getLowerMode();
        int m2 = distribution.getUpperMode();
        if (k < m1) {
            int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk, distribution::pmf);
            return distribution.cdf(k) + distribution.sf(i - 1);
        }
        if (k > m2) {
            int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk, distribution::pmf);
            return distribution.cdf(i) + distribution.sf(k - 1);
        }
        double pm = distribution.pmf(k == m1 ? m2 : m1);
        return pm > pk ? 1.0 - pm : 1.0;
    }

    private double findExtremeTablesBoschloo(int a, int b, int m, int n, XYList tableList) {
        double statistic = this.statisticBoschloo(a, b, m, n);
        BoschlooStatistic func = this.alternative == AlternativeHypothesis.GREATER_THAN ? (dist, x) -> dist.sf(x - 1) : (this.alternative == AlternativeHypothesis.LESS_THAN ? Hypergeom::cdf : UnconditionedExactTest::statisticBoschlooTwoSided);
        int mn = m + n;
        for (int k = 0; k <= mn; ++k) {
            Hypergeom dist2 = new Hypergeom(mn, k, m);
            int lo = dist2.getSupportLowerBound();
            int hi = dist2.getSupportUpperBound();
            for (int i = lo; i <= hi; ++i) {
                if (!(func.value(dist2, i) <= statistic)) continue;
                tableList.add(i, k - i);
            }
        }
        return statistic;
    }

    private double[] computePValue(XYList tableList) {
        DoubleUnaryOperator func = UnconditionedExactTest.createBinomialModel(tableList);
        Candidates minima = new Candidates(3, 0.02);
        int n = this.points - 1;
        double inc = 0.99998 / (double)n;
        double v2 = 0.0;
        double v3 = func.applyAsDouble(1.0E-5);
        double px = 1.0E-5;
        for (int i = 1; i < n; ++i) {
            double x = 1.0E-5 + (double)i * inc;
            double v1 = v2;
            v2 = v3;
            v3 = func.applyAsDouble(x);
            this.addCandidate(minima, v1, v2, v3, px);
            px = x;
        }
        double x = 0.99999;
        double vn = func.applyAsDouble(0.99999);
        this.addCandidate(minima, v2, v3, vn, px);
        this.addCandidate(minima, v3, vn, 0.0, 0.99999);
        double[] min = minima.getMinimum();
        if (this.optimize && min[1] > -1.0) {
            BrentOptimizer opt = new BrentOptimizer(1.4901161193847656E-8, Double.MIN_VALUE);
            BracketFinder bf = new BracketFinder();
            minima.forEach(candidate -> {
                double fa;
                double a = candidate[0];
                double b = a - Math.copySign(inc * 0.125, a - 0.5);
                if (bf.search(func, a, b, 0.0, 1.0)) {
                    BrentOptimizer.PointValuePair p = opt.optimize(func, bf.getLo(), bf.getHi(), bf.getMid(), bf.getFMid());
                    a = p.getPoint();
                    fa = p.getValue();
                } else {
                    a = bf.getMid();
                    fa = bf.getFMid();
                }
                if (fa < min[1]) {
                    min[0] = a;
                    min[1] = fa;
                }
            });
        }
        min[1] = -Math.max(-1.0, min[1]);
        return min;
    }

    private static DoubleUnaryOperator createBinomialModel(XYList tableList) {
        IntToDoubleFunction binomN;
        IntToDoubleFunction binomM;
        int m = tableList.getMaxX();
        int n = tableList.getMaxY();
        int mn = m + n;
        double[] c = new double[tableList.size()];
        int[] ij = new int[tableList.size()];
        int width = tableList.getWidth();
        if (tableList.size() < mn) {
            binomM = k -> LogBinomialCoefficient.value((int)m, (int)k);
            binomN = k -> LogBinomialCoefficient.value((int)n, (int)k);
        } else {
            binomM = UnconditionedExactTest.createLogBinomialCoefficients(m);
            binomN = m == n ? binomM : UnconditionedExactTest.createLogBinomialCoefficients(n);
        }
        int flag = 0;
        int j = 0;
        for (int i = 0; i < c.length; ++i) {
            int y;
            int index = tableList.get(i);
            int x = index % width;
            int xy = x + (y = index / width);
            if (xy == 0) {
                flag |= 1;
                continue;
            }
            if (xy == mn) {
                flag |= 2;
                continue;
            }
            ij[j] = xy;
            c[j] = binomM.applyAsDouble(x) + binomN.applyAsDouble(y);
            ++j;
        }
        int size = j;
        boolean ij0 = flag & true;
        boolean ijmn = (flag & 2) != 0;
        return pi -> {
            double logp = Math.log(pi);
            double log1mp = Math.log1p(-pi);
            double sum = 0.0;
            for (int i = 0; i < size; ++i) {
                sum += Math.exp((double)ij[i] * logp + (double)(mn - ij[i]) * log1mp + c[i]);
            }
            if (ij0) {
                sum += Math.exp((double)mn * log1mp);
            }
            if (ijmn) {
                sum += Math.exp((double)mn * logp);
            }
            return -sum;
        };
    }

    private static IntToDoubleFunction createLogBinomialCoefficients(int n) {
        double[] binom = new double[n + 1];
        int j = n - 1;
        for (int i = 1; i <= j; ++i, --j) {
            binom[i] = binom[j] = LogBinomialCoefficient.value((int)n, (int)i);
        }
        return k -> binom[k];
    }

    private void addCandidate(Candidates minima, double v1, double v2, double v3, double x2) {
        double min;
        double d = min = v1 < v3 ? v1 : v3;
        if (min < v2) {
            return;
        }
        minima.add(x2, v2);
    }

    private static void checkTable(int[][] table) {
        Arguments.checkTable(table);
        int a = table[0][0];
        int c = table[1][0];
        int m = a + c;
        if (m == 0) {
            throw new InferenceException("%s[%s] is zero", COLUMN_SUM, 0);
        }
        int b = table[0][1];
        int d = table[1][1];
        int n = b + d;
        if (n == 0) {
            throw new InferenceException("%s[%s] is zero", COLUMN_SUM, 1);
        }
        long size = ((long)m + 1L) * ((long)n + 1L);
        if (size > 0x7FFFFFF7L) {
            throw new InferenceException("%s > %s", size, 0x7FFFFFF7);
        }
    }

    private static interface BoschlooStatistic {
        public double value(Hypergeom var1, int var2);
    }

    static class Candidates {
        private final int max;
        private final double eps;
        private double[][] data;
        private int size;
        private double min = Double.POSITIVE_INFINITY;
        private double threshold = Double.POSITIVE_INFINITY;

        Candidates(int max, double eps) {
            this.max = Math.max(1, max);
            this.eps = eps;
            this.data = new double[Math.min(this.max, 4)][];
        }

        void add(double k, double v) {
            if (Double.isNaN(v)) {
                if (this.size == 0) {
                    this.data[this.size++] = new double[]{k, v};
                }
                return;
            }
            if (v > this.threshold) {
                return;
            }
            if (v < this.min) {
                this.min = v;
                this.threshold = v + Math.abs(v) * this.eps;
                int s = 0;
                for (int i = 0; i < this.size; ++i) {
                    if (!(this.data[i][1] <= this.threshold)) continue;
                    this.data[s++] = this.data[i];
                }
                this.size = s;
            }
            this.addPair(k, v);
        }

        private void addPair(double k, double v) {
            if (this.size == this.data.length) {
                if (this.size == this.max) {
                    this.replaceWorst(k, v);
                    return;
                }
                this.data = (double[][])Arrays.copyOfRange(this.data, 0, (int)Math.min((long)this.max, (long)this.size * 2L));
            }
            this.data[this.size++] = new double[]{k, v};
        }

        private void replaceWorst(double k, double v) {
            double[] worst = this.data[0];
            for (int i = 1; i < this.size; ++i) {
                if (!(worst[1] < this.data[i][1])) continue;
                worst = this.data[i];
            }
            worst[0] = k;
            worst[1] = v;
        }

        double[] getMinimum() {
            double[] best = this.data[0];
            for (int i = 1; i < this.size; ++i) {
                if (!(best[1] > this.data[i][1])) continue;
                best = this.data[i];
            }
            return best;
        }

        void forEach(Consumer<double[]> action) {
            for (int i = 0; i < this.size; ++i) {
                action.accept(this.data[i]);
            }
        }
    }

    private static class XYList {
        private final int max;
        private final int width;
        private int size;
        private int[] data = new int[10];

        XYList(int maxx, int maxy) {
            this.width = maxx + 1;
            this.max = this.width * (maxy + 1);
        }

        int getWidth() {
            return this.width;
        }

        int getMaxX() {
            return this.width - 1;
        }

        int getMaxY() {
            return this.max / this.width - 1;
        }

        void add(int x, int y) {
            if (this.size == this.data.length) {
                this.data = Arrays.copyOf(this.data, (int)Math.min((long)this.max, (long)this.size * 2L));
            }
            this.data[this.size++] = this.width * y + x;
        }

        int get(int index) {
            return this.data[index];
        }

        int size() {
            return this.size;
        }

        boolean isEmpty() {
            return this.size == 0;
        }

        boolean isFull() {
            return this.size == this.max;
        }
    }

    public static final class Result
    extends BaseSignificanceResult {
        private final double pi;

        Result(double statistic) {
            super(statistic, 1.0);
            this.pi = 0.5;
        }

        Result(double statistic, double pi, double p) {
            super(statistic, p);
            this.pi = pi;
        }

        @Override
        public double getStatistic() {
            return super.getStatistic();
        }

        public double getNuisanceParameter() {
            return this.pi;
        }
    }

    public static enum Method {
        Z_POOLED,
        Z_UNPOOLED,
        BOSCHLOO;

    }
}

