/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;

/**
 * Output aggregator that combines per-model values into a weighted sum and turns that sum
 * into a per-class result vector via {@link Statistics#softMax} or {@link Statistics#sigmoid}.
 *
 * <p>Weights are optional. When no weights are configured every model contributes with a
 * weight of {@code 1.0}.
 *
 * <p>The weights array handed to {@link #LogisticRegression(double[])} is stored as-is (no
 * defensive copy), so callers should not mutate it afterwards.
 */
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
    public static final ParseField NAME = new ParseField("logistic_regression");
    public static final ParseField WEIGHTS = new ParseField("weights");
    private static final ConstructingObjectParser<LogisticRegression, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<LogisticRegression, Void> STRICT_PARSER = createParser(false);

    /** Per-model weights, or {@code null} when all models are weighted equally. */
    private final double[] weights;

    /**
     * Builds the XContent parser for this aggregator.
     *
     * @param lenient value forwarded as the parser's second constructor argument; the lenient
     *                and strict parser constants differ only in this flag
     * @return a parser that reads the optional {@code weights} array
     */
    private static ConstructingObjectParser<LogisticRegression, Void> createParser(boolean lenient) {
        // The only declared constructor argument is the double array registered below, so the
        // parser hands us either null or a List<Double> in slot 0.
        @SuppressWarnings("unchecked")
        ConstructingObjectParser<LogisticRegression, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new LogisticRegression((List<Double>) a[0])
        );
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    /**
     * Parses an instance using the strict parser.
     *
     * @param parser the parser positioned at the aggregator object
     * @return the parsed aggregator
     */
    public static LogisticRegression fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /**
     * Parses an instance using the lenient parser.
     *
     * @param parser the parser positioned at the aggregator object
     * @return the parsed aggregator
     */
    public static LogisticRegression fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /** Creates an aggregator without weights. */
    LogisticRegression() {
        this((List<Double>) null);
    }

    /**
     * Creates an aggregator from parsed weights.
     *
     * @param weights the boxed weights, or {@code null} for no weights
     */
    private LogisticRegression(List<Double> weights) {
        this(toPrimitiveArray(weights));
    }

    /**
     * Creates an aggregator with the given weights.
     *
     * @param weights one weight per model, or {@code null} for no weights; the array is not copied
     */
    public LogisticRegression(double[] weights) {
        this.weights = weights;
    }

    /**
     * Reads an aggregator from a stream: a presence flag followed by the weights array when
     * the flag is set. Mirrors {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public LogisticRegression(StreamInput in) throws IOException {
        this.weights = in.readBoolean() ? in.readDoubleArray() : null;
    }

    /** Unboxes a list of weights into a primitive array, passing {@code null} through. */
    private static double[] toPrimitiveArray(List<Double> boxed) {
        if (boxed == null) {
            return null;
        }
        return boxed.stream().mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * @return the number of configured weights, or {@code null} when no weights are configured
     */
    @Override
    public Integer expectedValueSize() {
        return this.weights == null ? null : Integer.valueOf(this.weights.length);
    }

    /**
     * Computes the weighted, column-wise sum of {@code values} and converts it into a result
     * vector.
     *
     * <p>When the rows have more than one column the sums are passed through
     * {@link Statistics#softMax}. When they have exactly one column the sum is passed through
     * {@link Statistics#sigmoid} and the result is returned as
     * {@code {1 - p, p}}.
     *
     * @param values one row per model; the first row determines the number of columns
     * @return the transformed sums
     * @throws NullPointerException     if {@code values} is {@code null}
     * @throws IllegalArgumentException if {@code values} or its first row is empty, if weights
     *                                  are configured and their count differs from the number
     *                                  of rows, or if a row is longer than the first row
     */
    @Override
    public double[] processValues(double[][] values) {
        Objects.requireNonNull(values, "values must not be null");
        if (this.weights != null && values.length != this.weights.length) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        // Fail with a descriptive argument error instead of an ArrayIndexOutOfBoundsException
        // on values[0] / sumOnAxis1[0] further down.
        if (values.length == 0) {
            throw new IllegalArgumentException("values must not be empty");
        }
        double[] sumOnAxis1 = new double[values[0].length];
        if (sumOnAxis1.length == 0) {
            throw new IllegalArgumentException("value entries must not be empty");
        }
        for (int j = 0; j < values.length; ++j) {
            double[] value = values[j];
            // NOTE(review): rows shorter than the first row are accepted and only contribute to
            // the leading columns; only longer rows are rejected. Confirm this is intended.
            if (value.length > sumOnAxis1.length) {
                throw new IllegalArgumentException("value entries must have the same dimensions");
            }
            double weight = this.weights == null ? 1.0 : this.weights[j];
            for (int i = 0; i < value.length; ++i) {
                sumOnAxis1[i] += value[i] * weight;
            }
        }
        if (sumOnAxis1.length > 1) {
            return Statistics.softMax(sumOnAxis1);
        }
        double probOfClassOne = Statistics.sigmoid(sumOnAxis1[0]);
        assert (0.0 <= probOfClassOne && probOfClassOne <= 1.0);
        return new double[] { 1.0 - probOfClassOne, probOfClassOne };
    }

    /**
     * Returns the index of the largest entry in {@code values}.
     *
     * <p>Ties resolve to the lowest index. An empty array, or one in which no entry is greater
     * than negative infinity (for example all {@code NaN}), yields {@code 0}.
     *
     * @param values the values to inspect
     * @return the index of the maximum, as a double
     * @throws NullPointerException if {@code values} is {@code null}
     */
    @Override
    public double aggregate(double[] values) {
        Objects.requireNonNull(values, "values must not be null");
        int bestValue = 0;
        double bestProb = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; ++i) {
            // Strict comparison keeps the first maximum and skips NaN entries.
            if (values[i] > bestProb) {
                bestProb = values[i];
                bestValue = i;
            }
        }
        return bestValue;
    }

    /** @return the preferred name of {@link #NAME} */
    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * @param targetType the target type to check
     * @return always {@code true}
     */
    @Override
    public boolean compatibleWith(TargetType targetType) {
        return true;
    }

    /** @return the preferred name of {@link #NAME} */
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /**
     * Writes a presence flag followed by the weights array when weights are configured.
     * Mirrors {@link #LogisticRegression(StreamInput)}.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(this.weights != null);
        if (this.weights != null) {
            out.writeDoubleArray(this.weights);
        }
    }

    /**
     * Renders this aggregator as an object containing the {@code weights} field when weights
     * are configured, and as an empty object otherwise.
     *
     * @param builder the builder to write to
     * @param params  rendering parameters (unused)
     * @return the same builder
     * @throws IOException if writing to the builder fails
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.weights != null) {
            // The Object cast is kept so overload resolution on the builder stays unchanged.
            builder.field(WEIGHTS.getPreferredName(), (Object) this.weights);
        }
        builder.endObject();
        return builder;
    }

    /** Two aggregators are equal when their weights arrays have equal contents. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        LogisticRegression that = (LogisticRegression) o;
        return Arrays.equals(this.weights, that.weights);
    }

    /** Hash code derived from the contents of the weights array. */
    @Override
    public int hashCode() {
        return Arrays.hashCode(this.weights);
    }

    /**
     * @return {@link #SHALLOW_SIZE} plus the estimated size of the weights array, if any
     */
    public long ramBytesUsed() {
        long weightSize = this.weights == null ? 0L : RamUsageEstimator.sizeOf(this.weights);
        return SHALLOW_SIZE + weightSize;
    }
}

