/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.kstest;

import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator;
import org.elasticsearch.xpack.ml.aggs.DoubleArray;
import org.elasticsearch.xpack.ml.aggs.MlAggsHelper;
import org.elasticsearch.xpack.ml.aggs.kstest.Alternative;
import org.elasticsearch.xpack.ml.aggs.kstest.InternalKSTestAggregation;
import org.elasticsearch.xpack.ml.aggs.kstest.SamplingMethod;

/**
 * Sibling pipeline aggregator that runs a two-sample Kolmogorov-Smirnov test comparing the distribution of the
 * values found at the configured bucket path with a reference distribution described by {@code fractions}
 * (a uniform distribution over the buckets when no fractions are supplied).
 */
public class BucketCountKSTestAggregator
extends SiblingPipelineAggregator {
    /** Number of random sub-samples whose p-values are combined when not every CDF point can be sampled. */
    private static final int NUM_ITERATIONS = 20;
    /** When fewer samples than this are available, every alternative reports NaN. */
    private static final int MINIMUM_NUMBER_OF_DOCS = 23;
    private static final KolmogorovSmirnovTest KOLMOGOROV_SMIRNOV_TEST = new KolmogorovSmirnovTest();

    /** Reference distribution weights, one per bucket; {@code null} means uniform. */
    private final double[] fractions;
    private final EnumSet<Alternative> alternatives;
    private final SamplingMethod samplingMethod;

    public BucketCountKSTestAggregator(String name, @Nullable double[] fractions, EnumSet<Alternative> alternatives, String bucketsPath, SamplingMethod samplingMethod, Map<String, Object> metadata) {
        super(name, new String[]{bucketsPath}, metadata);
        this.fractions = fractions;
        this.alternatives = alternatives;
        this.samplingMethod = samplingMethod;
    }

    /**
     * Runs the KS test of the bucket values against the reference fractions for every requested alternative.
     *
     * @param fractions      reference weights, expected to carry the same leading 0.0 entry as the bucket values
     * @param bucketsValue   observed bucket values and doc counts
     * @param alternatives   hypotheses to test
     * @param samplingMethod supplies the CDF points at which both distributions are compared
     * @return map from {@code Alternative#toString()} to p-value; every entry is NaN when the observed values or
     *         the fractions do not sum to a positive number, or when fewer than {@link #MINIMUM_NUMBER_OF_DOCS}
     *         samples are available
     */
    static Map<String, Double> ksTest(double[] fractions, MlAggsHelper.DoubleBucketValues bucketsValue, EnumSet<Alternative> alternatives, SamplingMethod samplingMethod) {
        long bucketsCountSum = LongStream.of(bucketsValue.getDocCounts()).sum();
        // The sample size is capped both by the number of observed docs and by the available CDF points.
        int nSamples = (int) Math.min(bucketsCountSum, samplingMethod.cdfPoints().length);

        double[] fX = DoubleArray.cumulativeSum(bucketsValue.getValues());
        if (fX[fX.length - 1] <= 0.0) {
            return nanResults(alternatives);
        }
        DoubleArray.divMut(fX, fX[fX.length - 1]);

        double[] fY = DoubleArray.cumulativeSum(fractions);
        if (fY[fY.length - 1] <= 0.0 || nSamples < MINIMUM_NUMBER_OF_DOCS) {
            return nanResults(alternatives);
        }
        DoubleArray.divMut(fY, fY[fY.length - 1]);

        // Positions 1..n onto which CDF levels are mapped by inverse linear interpolation of fX and fY.
        double[] monotonic = IntStream.rangeClosed(1, fY.length).asDoubleStream().toArray();

        if (nSamples >= samplingMethod.cdfPoints().length) {
            // Every CDF point is used (sampleOf does not shuffle), so one evaluation is enough.
            return ksTest(nSamples, fX, fY, monotonic, samplingMethod.cdfPoints(), alternatives);
        }

        // Otherwise repeat the test on random subsets of the CDF points and combine the p-values with their
        // geometric mean exp(mean(log p)). Every p-value, including the first one, is log-transformed; a
        // p-value of exactly 0.0 is counted as 0.0 instead of -Infinity.
        Map<String, Double> pValues = new HashMap<>();
        for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
            ksTest(nSamples, fX, fY, monotonic, samplingMethod.cdfPoints(), alternatives).forEach(
                (alternative, pValue) -> pValues.merge(alternative, pValue == 0.0 ? 0.0 : Math.log(pValue), Double::sum)
            );
        }
        pValues.replaceAll((alternative, logSum) -> Math.min(1.0, Math.max(Math.exp(logSum / NUM_ITERATIONS), 0.0)));
        return pValues;
    }

    /** Builds a mutable result map holding NaN for every requested alternative. */
    private static Map<String, Double> nanResults(EnumSet<Alternative> alternatives) {
        return alternatives.stream().map(Alternative::toString).collect(Collectors.toMap(Function.identity(), a -> Double.NaN));
    }

    /**
     * Evaluates the test once on {@code nSamples} CDF points (all of them, or a random subset).
     *
     * @return map from {@code Alternative#toString()} to p-value, one entry per alternative
     * @throws AggregationExecutionException if an alternative is not handled
     */
    private static Map<String, Double> ksTest(int nSamples, double[] fX, double[] fY, double[] monotonic, double[] cdfPoints, EnumSet<Alternative> alternatives) {
        int[] samples = sampleOf(cdfPoints.length, nSamples);
        double[] x = new double[samples.length];
        double[] y = new double[samples.length];
        for (int j = 0; j < samples.length; j++) {
            double f = cdfPoints[samples[j]];
            x[j] = interpolate(fX, monotonic, f);
            y[j] = interpolate(fY, monotonic, f);
        }
        Arrays.sort(x);
        Arrays.sort(y);

        double nX = x.length;
        double nY = y.length;
        double zConstant = nX * nY / (nX + nY);
        // Computed in floating point: the equivalent int product overflows once the sample size reaches 1024.
        double continuityConstant = (nX + 2.0 * nY) / Math.sqrt(nX * nY * (nX + nY));

        Map<String, Double> results = new HashMap<>();
        for (Alternative alternative : alternatives) {
            double statistic = sidedStatistic(x, y, alternative);
            double pValue;
            switch (alternative) {
                case GREATER:
                case LESS: {
                    // One-sided asymptotic approximation exp(-2 z^2 - 2 z c / 3), clamped to [0, 1].
                    double z = Math.sqrt(zConstant) * statistic;
                    double unBounded = Math.exp(-2.0 * Math.pow(z, 2.0) - 2.0 * z * continuityConstant / 3.0);
                    pValue = Math.min(1.0, Math.max(unBounded, 0.0));
                    break;
                }
                case TWO_SIDED: {
                    // Exact two-sample p-value; strict == false means P(D >= statistic).
                    pValue = KOLMOGOROV_SMIRNOV_TEST.exactP(statistic, x.length, y.length, false);
                    break;
                }
                default:
                    throw new AggregationExecutionException("unexpected alternative [" + alternative + "]");
            }
            results.put(alternative.toString(), pValue);
        }
        return results;
    }

    /**
     * Returns {@code n} distinct indices drawn at random from {@code [0, i)}, or all of {@code [0, i)} in order
     * when {@code n >= i}.
     *
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    private static int[] sampleOf(int i, int n) {
        if (i <= 0) {
            throw new IllegalArgumentException("cannot create a range from a non-positive number");
        }
        if (n >= i) {
            return IntStream.range(0, i).toArray();
        }
        List<Integer> toSample = IntStream.range(0, i).boxed().collect(Collectors.toList());
        Collections.shuffle(toSample, Randomness.get());
        return toSample.subList(0, n).stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Linearly interpolates {@code fx} at {@code x}, where {@code xs} holds the abscissae (assumed sorted
     * ascending). Values outside {@code xs} are extrapolated from the nearest segment.
     */
    private static double interpolate(double[] xs, double[] fx, double x) {
        int i = Math.min(bisectRight(xs, x), xs.length - 1);
        return ((x - xs[i - 1]) * fx[i] + (xs[i] - x) * fx[i - 1]) / (xs[i] - xs[i - 1]);
    }

    /**
     * Returns the index of the first element of {@code xs} strictly greater than {@code x} (like Python's
     * {@code bisect_right}), but never less than 1 so callers can safely read {@code xs[i - 1]}.
     * {@code xs} must be sorted ascending, as required by {@link Arrays#binarySearch(double[], double)}.
     */
    private static int bisectRight(double[] xs, double x) {
        int pos = Arrays.binarySearch(xs, x);
        if (pos < 0) {
            // Not found: convert the encoded result to the insertion point.
            pos = nonNegative(pos) - 1;
        }
        // On an exact hit binarySearch may return any of several equal elements; move past all of them
        // before clamping, so that a hit at index 0 followed by duplicates is not reported as 1.
        while (pos < xs.length && xs[pos] <= x) {
            ++pos;
        }
        return Math.max(pos, 1);
    }

    @SuppressForbidden(reason="Math#abs(int) is safe here as we protect against MIN_VALUE")
    private static int nonNegative(int x) {
        if (x == Integer.MIN_VALUE) {
            throw new AggregationExecutionException("unexpected value while interpolating sampled values");
        }
        return Math.abs(x);
    }

    /**
     * Walks the two sorted samples in merge order and returns the largest {@link #sidedKSStat} over the visited
     * pairs of empirical CDF positions. Both arrays must be non-empty and sorted ascending.
     */
    private static double sidedStatistic(double[] xa, double[] xb, Alternative alternative) {
        int ia = xa[0] < xb[0] ? 1 : 0;
        int ib = xa[0] < xb[0] ? 0 : 1;
        double t = 0.0;
        while (ia < xa.length && ib < xb.length) {
            t = Math.max(t, sidedKSStat((double) ia / xa.length, (double) ib / xb.length, alternative));
            if (xa[ia] < xb[ib]) {
                ++ia;
            } else if (xb[ib] < xa[ia]) {
                ++ib;
            } else {
                // Tied (or incomparable) values: advance both samples together.
                ++ia;
                ++ib;
            }
        }
        t = Math.max(t, sidedKSStat((double) ia / xa.length, (double) ib / xb.length, alternative));
        return alternative == Alternative.LESS ? Math.min(Math.max(t, 0.0), 1.0) : t;
    }

    /**
     * Distance between two empirical CDF values: {@code max(b - a, 0)} for LESS, {@code max(a - b, 0)} for
     * GREATER, and {@code |b - a|} otherwise.
     */
    private static double sidedKSStat(double a, double b, Alternative alternative) {
        switch (alternative) {
            case LESS:
                return Math.max(b - a, 0.0);
            case GREATER:
                return Math.max(a - b, 0.0);
            default:
                return Math.abs(b - a);
        }
    }

    /**
     * Extracts the bucket values, prepends a zero bucket so that the cumulative distributions built from them
     * start at 0, and runs the KS test.
     *
     * @throws AggregationExecutionException if no valid bucket values are found at the bucket path
     */
    public InternalAggregation doReduce(Aggregations aggregations, InternalAggregation.ReduceContext context) {
        MlAggsHelper.DoubleBucketValues bucketsValue = MlAggsHelper.extractDoubleBucketedValues(this.bucketsPaths()[0], aggregations).map(bucketValue -> {
            double[] sourceValues = bucketValue.getValues();
            long[] sourceCounts = bucketValue.getDocCounts();
            // Index 0 is the extra zero bucket; the Java array defaults already hold 0.0 / 0L there.
            double[] values = new double[sourceValues.length + 1];
            long[] counts = new long[sourceCounts.length + 1];
            System.arraycopy(sourceValues, 0, values, 1, sourceValues.length);
            System.arraycopy(sourceCounts, 0, counts, 1, sourceCounts.length);
            return new MlAggsHelper.DoubleBucketValues(counts, values);
        }).orElseThrow(() -> new AggregationExecutionException("unable to find valid bucket values in bucket path [" + this.bucketsPaths()[0] + "] for agg [" + this.name() + "]"));

        double[] fractionsArr;
        if (this.fractions == null) {
            // Uniform reference distribution over the real buckets, after the leading zero entry.
            int numBuckets = bucketsValue.getDocCounts().length - 1;
            fractionsArr = new double[numBuckets + 1];
            Arrays.fill(fractionsArr, 1, fractionsArr.length, 1.0 / numBuckets);
        } else {
            fractionsArr = DoubleStream.concat(DoubleStream.of(0.0), Arrays.stream(this.fractions)).toArray();
        }
        return new InternalKSTestAggregation(this.name(), this.metadata(), ksTest(fractionsArr, bucketsValue, this.alternatives, this.samplingMethod));
    }
}

