package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.BFGS;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.validation.ModelSelection;

/* loaded from: input_file:smile/classification/LogisticRegression.class */
public abstract class LogisticRegression implements SoftClassifier<double[]>, OnlineClassifier<double[]> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 2;
    /** Expected length of an input vector; the predict/update methods reject vectors of another size. */
    int p;
    /** Number of classes, taken from {@code labels.size()} in the constructor. */
    int k;
    /** Value reported by {@link #loglikelihood()}; the factory methods pass the negated BFGS minimum. */
    double L;
    /** Regularization factor; a value greater than zero enables the weight-shrinkage step in the online updates. */
    double lambda;
    /** Learning rate of the online updates; changed through {@link #setLearningRate(double)}. */
    double eta = 0.1d;
    /** Translates between class labels and internal class indices through {@code valueOf} / {@code indexOf}. */
    final IntSet labels;

    /* loaded from: input_file:smile/classification/LogisticRegression$Binomial.class */
    /**
     * Two-class logistic regression model backed by a single weight vector.
     */
    public static class Binomial extends LogisticRegression {
        /** Linear weights: {@code w[0..p-1]} multiply the input features, {@code w[p]} is the intercept. */
        private double[] w;

        /**
         * Wraps an already-estimated weight vector.
         *
         * @param dArr weight vector of length {@code p + 1} whose last element is the intercept.
         *             The array is stored by reference, not copied.
         * @param d value later reported by {@code loglikelihood()}.
         * @param d2 regularization factor applied by {@link #update(double[], int)}.
         * @param intSet the class labels.
         */
        public Binomial(double[] dArr, double d, double d2, IntSet intSet) {
            super(dArr.length - 1, d, d2, intSet);
            this.w = dArr;
        }

        /**
         * Returns the weight vector. This is the internal array, so changes made
         * by the caller or by {@link #update(double[], int)} are visible on both sides.
         *
         * @return the weights, intercept last.
         */
        public double[] coefficients() {
            return w;
        }

        /**
         * Predicts the class label of an input vector.
         *
         * @param dArr the input vector of length {@code p}.
         * @return the label mapped from class index 1 when the logistic score is at least 0.5,
         *         otherwise the label mapped from class index 0.
         * @throws IllegalArgumentException if the input vector does not have length {@code p}.
         */
        @Override // smile.classification.Classifier
        public int predict(double[] dArr) {
            // Without this check a short vector would make dot() read a feature weight as the
            // intercept and return a silently wrong answer.
            if (dArr.length != p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", dArr.length, p));
            }
            double posterior = 1.0d / (1.0d + Math.exp(-LogisticRegression.dot(dArr, w)));
            return labels.valueOf(posterior < 0.5d ? 0 : 1);
        }

        /**
         * Predicts the class label and fills in the posterior probabilities.
         *
         * @param dArr the input vector of length {@code p}.
         * @param dArr2 output array of length {@code k}; receives the probabilities of class index 0 and 1.
         * @return the predicted class label.
         * @throws IllegalArgumentException if either array has an unexpected length.
         */
        @Override // smile.classification.SoftClassifier
        public int predict(double[] dArr, double[] dArr2) {
            if (dArr.length != p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", dArr.length, p));
            }
            if (dArr2.length != k) {
                throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", dArr2.length, k));
            }
            double posterior = 1.0d / (1.0d + Math.exp(-LogisticRegression.dot(dArr, w)));
            dArr2[0] = 1.0d - posterior;
            dArr2[1] = posterior;
            return labels.valueOf(posterior < 0.5d ? 0 : 1);
        }

        /**
         * Performs one online gradient step on a single labelled sample.
         *
         * @param dArr the input vector of length {@code p}.
         * @param i the class label of the sample.
         * @throws IllegalArgumentException if the input vector does not have length {@code p}.
         */
        @Override // smile.classification.OnlineClassifier
        public void update(double[] dArr, int i) {
            if (dArr.length != p) {
                throw new IllegalArgumentException("Invalid input vector size: " + dArr.length);
            }

            // Residual between the class index of the label and the predicted probability.
            double err = labels.indexOf(i) - MathEx.logistic(LogisticRegression.dot(dArr, w));

            // The intercept is updated first, then the feature weights.
            w[p] += eta * err;
            for (int j = 0; j < p; j++) {
                w[j] += eta * err * dArr[j];
            }

            // Shrink the feature weights only; the intercept is not regularized.
            if (lambda > 0.0d) {
                for (int j = 0; j < p; j++) {
                    w[j] -= eta * lambda * w[j];
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/classification/LogisticRegression$BinomialObjective.class */
    /**
     * Objective function (loss plus optional ridge penalty) minimized when
     * fitting the two-class model. The gradient is accumulated per partition of
     * the training rows so that partitions can be processed in parallel.
     */
    public static class BinomialObjective implements DifferentiableMultivariateFunction {
        /** Training samples, one row per sample. */
        double[][] x;
        /** Class index of each sample; used directly as the numeric target in the loss. */
        int[] y;
        /** Number of features, i.e. the length of {@code x[0]}. */
        int p;
        /** Regularization factor; the penalty is applied only when it is greater than zero. */
        double lambda;
        /** Number of rows per partition, read from the system property {@code smile.data.partition.size}. */
        int partitionSize = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
        /** Number of partitions needed to cover all rows. */
        int partitions;
        /** One gradient buffer of length {@code p + 1} per partition, reused by every call to {@code g}. */
        double[][] gradients;

        /**
         * Constructor.
         *
         * @param dArr the training samples; must contain at least one row.
         * @param iArr the class index of each sample.
         * @param d the regularization factor.
         */
        BinomialObjective(double[][] dArr, int[] iArr, double d) {
            this.x = dArr;
            this.y = iArr;
            this.lambda = d;
            this.p = dArr[0].length;
            // Ceiling division: a trailing partial partition counts as one more.
            this.partitions = (dArr.length / partitionSize) + (dArr.length % partitionSize == 0 ? 0 : 1);
            this.gradients = new double[partitions][p + 1];
        }

        /**
         * Evaluates the objective at the given weights.
         *
         * @param dArr weight vector of length {@code p + 1}, intercept last.
         * @return the summed per-sample loss plus {@code 0.5 * lambda * ||w||^2} over the feature weights.
         */
        public double f(double[] dArr) {
            double objective = IntStream.range(0, x.length).parallel().mapToDouble(row -> {
                double wx = LogisticRegression.dot(x[row], dArr);
                return MathEx.log1pe(wx) - (y[row] * wx);
            }).sum();

            if (lambda > 0.0d) {
                double norm = 0.0d;
                for (int j = 0; j < p; j++) {
                    norm += dArr[j] * dArr[j];
                }
                objective += 0.5d * lambda * norm;
            }
            return objective;
        }

        /**
         * Evaluates the objective and its gradient at the given weights.
         * The per-partition buffers in {@code gradients} are overwritten, so
         * concurrent calls on the same instance would interfere with each other.
         *
         * @param dArr weight vector of length {@code p + 1}, intercept last.
         * @param dArr2 output array that receives the gradient.
         * @return the objective value, computed with the same formula as {@link #f(double[])}.
         */
        public double g(double[] dArr, double[] dArr2) {
            double objective = IntStream.range(0, partitions).parallel().mapToDouble(part -> {
                double[] gradient = gradients[part];
                Arrays.fill(gradient, 0.0d);

                int begin = part * partitionSize;
                int end = Math.min((part + 1) * partitionSize, x.length);

                // Rows inside a partition are visited sequentially because they share one buffer.
                return IntStream.range(begin, end).sequential().mapToDouble(row -> {
                    double[] xi = x[row];
                    double wx = LogisticRegression.dot(xi, dArr);
                    double err = y[row] - MathEx.logistic(wx);
                    for (int j = 0; j < p; j++) {
                        gradient[j] -= err * xi[j];
                    }
                    gradient[p] -= err;
                    return MathEx.log1pe(wx) - (y[row] * wx);
                }).sum();
            }).sum();

            // Reduce the per-partition gradients into the caller's array.
            Arrays.fill(dArr2, 0.0d);
            for (double[] partial : gradients) {
                for (int j = 0; j < dArr2.length; j++) {
                    dArr2[j] += partial[j];
                }
            }

            // Ridge penalty on the feature weights only; the intercept is excluded.
            if (lambda > 0.0d) {
                double norm = 0.0d;
                for (int j = 0; j < p; j++) {
                    norm += dArr[j] * dArr[j];
                    dArr2[j] += lambda * dArr[j];
                }
                objective += 0.5d * lambda * norm;
            }
            return objective;
        }
    }

    /* loaded from: input_file:smile/classification/LogisticRegression$Multinomial.class */
    /**
     * Multi-class logistic regression model. Weights are kept for the first
     * {@code k - 1} class indices; the score of the last class is fixed at zero
     * before the softmax is applied.
     */
    public static class Multinomial extends LogisticRegression {
        /** {@code w[c]} holds the {@code p} feature weights of class index {@code c} followed by its intercept. */
        private double[][] w;

        /**
         * Wraps an already-estimated weight matrix.
         *
         * @param dArr weight matrix with one row of length {@code p + 1} per class except the last.
         *             The array is stored by reference, not copied.
         * @param d value later reported by {@code loglikelihood()}.
         * @param d2 regularization factor applied by {@link #update(double[], int)}.
         * @param intSet the class labels.
         */
        public Multinomial(double[][] dArr, double d, double d2, IntSet intSet) {
            super(dArr[0].length - 1, d, d2, intSet);
            w = dArr;
        }

        /**
         * Returns the weight matrix (the internal array, not a copy).
         *
         * @return the weights, one row per class index, intercept last in each row.
         */
        public double[][] coefficients() {
            return w;
        }

        /**
         * Predicts the class label of an input vector by delegating to the
         * soft prediction with a scratch posterior array.
         *
         * @param dArr the input vector of length {@code p}.
         * @return the predicted class label.
         */
        @Override // smile.classification.Classifier
        public int predict(double[] dArr) {
            double[] scratch = new double[k];
            return predict(dArr, scratch);
        }

        /**
         * Predicts the class label and fills in the posterior probabilities.
         *
         * @param dArr the input vector of length {@code p}.
         * @param dArr2 output array of length {@code k} that receives the softmax probabilities.
         * @return the label of the class index with the largest probability.
         * @throws IllegalArgumentException if either array has an unexpected length.
         */
        @Override // smile.classification.SoftClassifier
        public int predict(double[] dArr, double[] dArr2) {
            final int inputSize = dArr.length;
            if (inputSize != p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", inputSize, p));
            }
            final int outputSize = dArr2.length;
            if (outputSize != k) {
                throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", outputSize, k));
            }

            final int last = k - 1;
            // The reference class has a fixed score of zero.
            dArr2[last] = 0.0d;
            for (int c = 0; c < last; c++) {
                dArr2[c] = LogisticRegression.dot(dArr, w[c]);
            }
            MathEx.softmax(dArr2);

            int best = MathEx.whichMax(dArr2);
            return labels.valueOf(best);
        }

        /**
         * Performs one online gradient step on a single labelled sample.
         *
         * @param dArr the input vector of length {@code p}.
         * @param i the class label of the sample.
         * @throws IllegalArgumentException if the input vector does not have length {@code p}.
         */
        @Override // smile.classification.OnlineClassifier
        public void update(double[] dArr, int i) {
            if (dArr.length != p) {
                throw new IllegalArgumentException("Invalid input vector size: " + dArr.length);
            }

            final int target = labels.indexOf(i);
            final int last = k - 1;

            // prob[last] keeps its initial zero, i.e. the reference class score.
            final double[] prob = new double[k];
            for (int c = 0; c < last; c++) {
                prob[c] = LogisticRegression.dot(dArr, w[c]);
            }
            MathEx.softmax(prob);

            for (int c = 0; c < last; c++) {
                final double[] wc = w[c];
                final double indicator;
                if (target == c) {
                    indicator = 1.0d;
                } else {
                    indicator = 0.0d;
                }
                final double err = indicator - prob[c];

                // Intercept first, then the feature weights.
                wc[p] += eta * err;
                for (int j = 0; j < p; j++) {
                    wc[j] += eta * err * dArr[j];
                }

                // Shrink the feature weights only.
                if (lambda > 0.0d) {
                    for (int j = 0; j < p; j++) {
                        wc[j] -= eta * lambda * wc[j];
                    }
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/classification/LogisticRegression$MultinomialObjective.class */
    /**
     * Objective function (loss plus optional ridge penalty) minimized when
     * fitting the multi-class model. The weights are passed as one flat vector
     * holding {@code k - 1} consecutive groups of {@code p + 1} values
     * (feature weights followed by the intercept).
     */
    public static class MultinomialObjective implements DifferentiableMultivariateFunction {
        /** Training samples, one row per sample. */
        double[][] x;
        /** Class index of each sample. */
        int[] y;
        /** Number of classes. */
        int k;
        /** Number of features, i.e. the length of {@code x[0]}. */
        int p;
        /** Regularization factor; the penalty is applied only when it is greater than zero. */
        double lambda;
        /** Number of rows per partition, read from the system property {@code smile.data.partition.size}. */
        int partitionSize = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
        /** Number of partitions needed to cover all rows. */
        int partitions;
        /** One flat gradient buffer per partition, reused by every call to {@code g}. */
        double[][] gradients;
        /** One posterior buffer of length {@code k} per partition, reused by {@code f} and {@code g}. */
        double[][] posterioris;

        /**
         * Constructor.
         *
         * @param dArr the training samples; must contain at least one row.
         * @param iArr the class index of each sample.
         * @param i the number of classes.
         * @param d the regularization factor.
         */
        MultinomialObjective(double[][] dArr, int[] iArr, int i, double d) {
            this.x = dArr;
            this.y = iArr;
            this.k = i;
            this.lambda = d;
            this.p = dArr[0].length;
            // Ceiling division: a trailing partial partition counts as one more.
            this.partitions = (dArr.length / partitionSize) + (dArr.length % partitionSize == 0 ? 0 : 1);
            this.gradients = new double[partitions][(i - 1) * (p + 1)];
            this.posterioris = new double[partitions][i];
        }

        /**
         * Evaluates the objective at the given weights. The per-partition
         * posterior buffers are overwritten, so concurrent calls on the same
         * instance would interfere with each other.
         *
         * @param dArr flat weight vector of length {@code (k - 1) * (p + 1)}.
         * @return the summed negative log posterior of each sample's own class, plus
         *         {@code 0.5 * lambda * ||w||^2} over the feature weights.
         */
        public double f(double[] dArr) {
            double objective = IntStream.range(0, partitions).parallel().mapToDouble(part -> {
                double[] posteriori = posterioris[part];

                int begin = part * partitionSize;
                int end = Math.min((part + 1) * partitionSize, x.length);

                // Rows inside a partition are visited sequentially because they share one buffer.
                return IntStream.range(begin, end).sequential().mapToDouble(row -> {
                    double[] xi = x[row];
                    // The reference class has a fixed score of zero.
                    posteriori[k - 1] = 0.0d;
                    for (int c = 0; c < k - 1; c++) {
                        posteriori[c] = LogisticRegression.dot(xi, dArr, c, p);
                    }
                    MathEx.softmax(posteriori);
                    return -MathEx.log(posteriori[y[row]]);
                }).sum();
            }).sum();

            if (lambda > 0.0d) {
                objective += 0.5d * lambda * squaredFeatureNorm(dArr);
            }
            return objective;
        }

        /**
         * Evaluates the objective and its gradient at the given weights. The
         * per-partition buffers are overwritten, so concurrent calls on the
         * same instance would interfere with each other.
         *
         * @param dArr flat weight vector of length {@code (k - 1) * (p + 1)}.
         * @param dArr2 output array that receives the gradient, laid out like {@code dArr}.
         * @return the objective value, computed with the same formula as {@link #f(double[])}.
         */
        public double g(double[] dArr, double[] dArr2) {
            double objective = IntStream.range(0, partitions).parallel().mapToDouble(part -> {
                double[] posteriori = posterioris[part];
                double[] gradient = gradients[part];
                Arrays.fill(gradient, 0.0d);

                int begin = part * partitionSize;
                int end = Math.min((part + 1) * partitionSize, x.length);

                return IntStream.range(begin, end).sequential().mapToDouble(row -> {
                    double[] xi = x[row];
                    posteriori[k - 1] = 0.0d;
                    for (int c = 0; c < k - 1; c++) {
                        posteriori[c] = LogisticRegression.dot(xi, dArr, c, p);
                    }
                    MathEx.softmax(posteriori);

                    for (int c = 0; c < k - 1; c++) {
                        double err = (y[row] == c ? 1.0d : 0.0d) - posteriori[c];
                        int pos = c * (p + 1);
                        for (int j = 0; j < p; j++) {
                            gradient[pos + j] -= err * xi[j];
                        }
                        gradient[pos + p] -= err;
                    }
                    return -MathEx.log(posteriori[y[row]]);
                }).sum();
            }).sum();

            // Reduce the per-partition gradients into the caller's array.
            Arrays.fill(dArr2, 0.0d);
            for (double[] partial : gradients) {
                for (int j = 0; j < dArr2.length; j++) {
                    dArr2[j] += partial[j];
                }
            }

            // Ridge penalty on the feature weights only; the intercepts are excluded.
            if (lambda > 0.0d) {
                double norm = 0.0d;
                for (int c = 0; c < k - 1; c++) {
                    int pos = c * (p + 1);
                    for (int j = 0; j < p; j++) {
                        double wj = dArr[pos + j];
                        norm += wj * wj;
                        dArr2[pos + j] += lambda * wj;
                    }
                }
                objective += 0.5d * lambda * norm;
            }
            return objective;
        }

        /** Sums the squares of all feature weights, skipping the intercept of each class. */
        private double squaredFeatureNorm(double[] weights) {
            double norm = 0.0d;
            for (int c = 0; c < k - 1; c++) {
                int pos = c * (p + 1);
                for (int j = 0; j < p; j++) {
                    double wj = weights[pos + j];
                    norm += wj * wj;
                }
            }
            return norm;
        }
    }

    /**
     * Initializes the state shared by the binomial and multinomial models.
     *
     * @param i the expected length of an input vector.
     * @param d the value to report from {@link #loglikelihood()}.
     * @param d2 the regularization factor.
     * @param intSet the class labels; its size becomes the number of classes.
     */
    public LogisticRegression(int i, double d, double d2, IntSet intSet) {
        p = i;
        L = d;
        lambda = d2;
        labels = intSet;
        k = intSet.size();
    }

    /**
     * Fits a two-class model from a data frame using default settings
     * (an empty {@link Properties} object).
     *
     * @param formula the model formula used to extract predictors and response.
     * @param dataFrame the training data.
     * @return the fitted model.
     */
    public static Binomial binomial(Formula formula, DataFrame dataFrame) {
        Properties defaults = new Properties();
        return binomial(formula, dataFrame, defaults);
    }

    /**
     * Fits a two-class model from a data frame. Predictors are converted to a
     * numeric matrix with {@code CategoricalEncoder.DUMMY} and the response to
     * an int array before delegating to the array-based overload.
     *
     * @param formula the model formula used to extract predictors and response.
     * @param dataFrame the training data.
     * @param properties the training settings.
     * @return the fitted model.
     */
    public static Binomial binomial(Formula formula, DataFrame dataFrame, Properties properties) {
        double[][] predictors = formula.x(dataFrame).toArray(false, CategoricalEncoder.DUMMY);
        int[] response = formula.y(dataFrame).toIntArray();
        return binomial(predictors, response, properties);
    }

    /**
     * Fits a two-class model using default settings (an empty {@link Properties} object).
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @return the fitted model.
     */
    public static Binomial binomial(double[][] dArr, int[] iArr) {
        Properties defaults = new Properties();
        return binomial(dArr, iArr, defaults);
    }

    /**
     * Fits a two-class model with settings read from properties:
     * {@code smile.logit.lambda} (default 0.1), {@code smile.logit.tolerance}
     * (default 1E-5) and {@code smile.logit.max.iterations} (default 500).
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @param properties the training settings.
     * @return the fitted model.
     */
    public static Binomial binomial(double[][] dArr, int[] iArr, Properties properties) {
        double lambda = Double.parseDouble(properties.getProperty("smile.logit.lambda", "0.1"));
        double tolerance = Double.parseDouble(properties.getProperty("smile.logit.tolerance", "1E-5"));
        int maxIterations = Integer.parseInt(properties.getProperty("smile.logit.max.iterations", "500"));
        return binomial(dArr, iArr, lambda, tolerance, maxIterations);
    }

    /**
     * Fits a two-class model by minimizing {@link BinomialObjective} with BFGS.
     * The online learning rate of the returned model is set to
     * {@code 0.1 / number of samples}.
     *
     * @param dArr the training samples, one row per sample; must not be empty.
     * @param iArr the class labels, one per sample.
     * @param d the regularization factor; must not be negative.
     * @param d2 the convergence tolerance; must be positive.
     * @param i the maximum number of iterations; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if an argument is invalid, the training set is
     *         empty, or the labels do not contain exactly two classes.
     */
    public static Binomial binomial(double[][] dArr, int[] iArr, double d, double d2, int i) {
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", dArr.length, iArr.length));
        }
        if (d < 0.0d) {
            throw new IllegalArgumentException("Invalid regularization factor: " + d);
        }
        if (d2 <= 0.0d) {
            throw new IllegalArgumentException("Invalid tolerance: " + d2);
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + i);
        }
        // Reading dArr[0] on an empty set would otherwise fail with a bare index error.
        if (dArr.length == 0) {
            throw new IllegalArgumentException("Empty training data.");
        }

        int features = dArr[0].length;
        ClassLabels codec = ClassLabels.fit(iArr);
        if (codec.k != 2) {
            throw new IllegalArgumentException("Fits binomial model on multi-class data.");
        }

        BinomialObjective objective = new BinomialObjective(dArr, codec.y, d);
        // BFGS updates the weights in place, starting from all zeros (intercept last).
        double[] weights = new double[features + 1];
        // The negated minimum of the objective is stored as the model's log-likelihood.
        double logLikelihood = -BFGS.minimize(objective, 5, weights, d2, i);

        Binomial model = new Binomial(weights, logLikelihood, d, codec.labels);
        model.setLearningRate(0.1d / dArr.length);
        return model;
    }

    /**
     * Fits a multi-class model from a data frame using default settings
     * (an empty {@link Properties} object).
     *
     * @param formula the model formula used to extract predictors and response.
     * @param dataFrame the training data.
     * @return the fitted model.
     */
    public static Multinomial multinomial(Formula formula, DataFrame dataFrame) {
        Properties defaults = new Properties();
        return multinomial(formula, dataFrame, defaults);
    }

    /**
     * Fits a multi-class model from a data frame. Predictors are converted to a
     * numeric matrix with {@code CategoricalEncoder.DUMMY} and the response to
     * an int array before delegating to the array-based overload.
     *
     * @param formula the model formula used to extract predictors and response.
     * @param dataFrame the training data.
     * @param properties the training settings.
     * @return the fitted model.
     */
    public static Multinomial multinomial(Formula formula, DataFrame dataFrame, Properties properties) {
        double[][] predictors = formula.x(dataFrame).toArray(false, CategoricalEncoder.DUMMY);
        int[] response = formula.y(dataFrame).toIntArray();
        return multinomial(predictors, response, properties);
    }

    /**
     * Fits a multi-class model using default settings (an empty {@link Properties} object).
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @return the fitted model.
     */
    public static Multinomial multinomial(double[][] dArr, int[] iArr) {
        Properties defaults = new Properties();
        return multinomial(dArr, iArr, defaults);
    }

    /**
     * Fits a multi-class model with settings read from properties:
     * {@code smile.logit.lambda} (default 0.1), {@code smile.logit.tolerance}
     * (default 1E-5) and {@code smile.logit.max.iterations} (default 500).
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @param properties the training settings.
     * @return the fitted model.
     */
    public static Multinomial multinomial(double[][] dArr, int[] iArr, Properties properties) {
        // The former read of "smile.logit.standard.error" was dropped: its value was never used.
        double lambda = Double.parseDouble(properties.getProperty("smile.logit.lambda", "0.1"));
        double tolerance = Double.parseDouble(properties.getProperty("smile.logit.tolerance", "1E-5"));
        int maxIterations = Integer.parseInt(properties.getProperty("smile.logit.max.iterations", "500"));
        return multinomial(dArr, iArr, lambda, tolerance, maxIterations);
    }

    /**
     * Fits a multi-class model by minimizing {@link MultinomialObjective} with
     * BFGS. The online learning rate of the returned model is set to
     * {@code 0.1 / number of samples}.
     *
     * @param dArr the training samples, one row per sample; must not be empty.
     * @param iArr the class labels, one per sample.
     * @param d the regularization factor; must not be negative.
     * @param d2 the convergence tolerance; must be positive.
     * @param i the maximum number of iterations; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if an argument is invalid, the training set is
     *         empty, or the labels contain two or fewer classes.
     */
    public static Multinomial multinomial(double[][] dArr, int[] iArr, double d, double d2, int i) {
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", dArr.length, iArr.length));
        }
        if (d < 0.0d) {
            throw new IllegalArgumentException("Invalid regularization factor: " + d);
        }
        if (d2 <= 0.0d) {
            throw new IllegalArgumentException("Invalid tolerance: " + d2);
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + i);
        }
        // Reading dArr[0] on an empty set would otherwise fail with a bare index error.
        if (dArr.length == 0) {
            throw new IllegalArgumentException("Empty training data.");
        }

        int features = dArr[0].length;
        ClassLabels codec = ClassLabels.fit(iArr);
        int classes = codec.k;
        if (classes <= 2) {
            throw new IllegalArgumentException("Fits multinomial model on binary class data.");
        }

        MultinomialObjective objective = new MultinomialObjective(dArr, codec.y, classes, d);
        // Flat layout: (classes - 1) consecutive groups of (features + 1) weights, intercept last.
        int stride = features + 1;
        double[] flat = new double[(classes - 1) * stride];
        // The negated minimum of the objective is stored as the model's log-likelihood.
        double logLikelihood = -BFGS.minimize(objective, 5, flat, d2, i);

        // Unpack the flat solution into one weight row per class index.
        double[][] weights = new double[classes - 1][stride];
        for (int c = 0; c < weights.length; c++) {
            System.arraycopy(flat, c * stride, weights[c], 0, stride);
        }

        Multinomial model = new Multinomial(weights, logLikelihood, d, codec.labels);
        model.setLearningRate(0.1d / dArr.length);
        return model;
    }

    /**
     * Fits a model from a data frame using default settings
     * (an empty {@link Properties} object).
     *
     * @param formula the model formula used to extract predictors and response.
     * @param dataFrame the training data.
     * @return the fitted model.
     */
    public static LogisticRegression fit(Formula formula, DataFrame dataFrame) {
        Properties defaults = new Properties();
        return fit(formula, dataFrame, defaults);
    }

    /**
     * Fits a model from a data frame. Predictors are converted to a numeric
     * matrix with {@code CategoricalEncoder.DUMMY} and the response to an int
     * array before delegating to the array-based overload.
     *
     * @param formula the model formula used to extract predictors and response.
     * @param dataFrame the training data.
     * @param properties the training settings.
     * @return the fitted model.
     */
    public static LogisticRegression fit(Formula formula, DataFrame dataFrame, Properties properties) {
        double[][] predictors = formula.x(dataFrame).toArray(false, CategoricalEncoder.DUMMY);
        int[] response = formula.y(dataFrame).toIntArray();
        return fit(predictors, response, properties);
    }

    /**
     * Fits a model using default settings (an empty {@link Properties} object).
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @return the fitted model.
     */
    public static LogisticRegression fit(double[][] dArr, int[] iArr) {
        Properties defaults = new Properties();
        return fit(dArr, iArr, defaults);
    }

    /**
     * Fits a model with settings read from properties. Each setting is looked
     * up under {@code smile.logistic.*} first and, when absent, under the
     * {@code smile.logit.*} key that the {@code binomial} and
     * {@code multinomial} factories read, so one properties object configures
     * all entry points. Settings: {@code lambda} (default 0.1),
     * {@code tolerance} (default 1E-5), {@code max.iterations} (default 500).
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @param properties the training settings.
     * @return the fitted model.
     */
    public static LogisticRegression fit(double[][] dArr, int[] iArr, Properties properties) {
        double lambda = Double.parseDouble(logisticProperty(properties, "lambda", "0.1"));
        double tolerance = Double.parseDouble(logisticProperty(properties, "tolerance", "1E-5"));
        int maxIterations = Integer.parseInt(logisticProperty(properties, "max.iterations", "500"));
        return fit(dArr, iArr, lambda, tolerance, maxIterations);
    }

    /** Reads {@code smile.logistic.<suffix>}, falling back to {@code smile.logit.<suffix>} and then to the default. */
    private static String logisticProperty(Properties properties, String suffix, String defaultValue) {
        String fallback = properties.getProperty("smile.logit." + suffix, defaultValue);
        return properties.getProperty("smile.logistic." + suffix, fallback);
    }

    /**
     * Fits a binomial model when the labels contain exactly two classes and a
     * multinomial model otherwise.
     *
     * @param dArr the training samples.
     * @param iArr the class labels.
     * @param d the regularization factor.
     * @param d2 the convergence tolerance.
     * @param i the maximum number of iterations.
     * @return the fitted model.
     */
    public static LogisticRegression fit(double[][] dArr, int[] iArr, double d, double d2, int i) {
        int classes = ClassLabels.fit(iArr).k;
        if (classes == 2) {
            return binomial(dArr, iArr, d, d2, i);
        }
        return multinomial(dArr, iArr, d, d2, i);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes the linear score of an input vector: the intercept plus the dot
     * product of the input with the leading weights.
     *
     * @param dArr the input vector.
     * @param dArr2 the weights; the element at index {@code dArr.length} is the intercept.
     * @return the intercept plus the sum of {@code dArr[j] * dArr2[j]}.
     */
    public static double dot(double[] dArr, double[] dArr2) {
        final int n = dArr.length;
        // Accumulation starts from the intercept stored right after the n feature weights.
        double acc = dArr2[n];
        int j = 0;
        while (j < n) {
            acc += dArr[j] * dArr2[j];
            j++;
        }
        return acc;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes the linear score of an input vector against one group of a flat
     * weight vector made of consecutive groups of {@code i2 + 1} values
     * (feature weights followed by the intercept).
     *
     * @param dArr the input vector; its first {@code i2} elements are used.
     * @param dArr2 the flat weight vector.
     * @param i the index of the weight group.
     * @param i2 the number of features.
     * @return the group's intercept plus the sum of input times weight over the features.
     */
    public static double dot(double[] dArr, double[] dArr2, int i, int i2) {
        final int offset = i * (i2 + 1);
        // Accumulation starts from the group's intercept, stored after its i2 feature weights.
        double acc = dArr2[offset + i2];
        for (int j = 0, pos = offset; j < i2; j++, pos++) {
            acc += dArr[j] * dArr2[pos];
        }
        return acc;
    }

    /**
     * Sets the learning rate used by the online {@code update} methods.
     *
     * @param d the new learning rate.
     * @throws IllegalArgumentException if the rate is zero or negative.
     */
    public void setLearningRate(double d) {
        final boolean rejected = d <= 0.0d;
        if (rejected) {
            throw new IllegalArgumentException("Invalid learning rate: " + d);
        }
        eta = d;
    }

    /**
     * Returns the learning rate used by the online {@code update} methods.
     *
     * @return the current learning rate.
     */
    public double getLearningRate() {
        return eta;
    }

    /**
     * Returns the log-likelihood value supplied when the model was constructed.
     *
     * @return the stored log-likelihood.
     */
    public double loglikelihood() {
        return L;
    }

    /**
     * Computes the AIC score by passing the stored log-likelihood and the
     * count {@code (k - 1) * (p + 1)} to {@code ModelSelection.AIC}; that count
     * is one weight vector of length {@code p + 1} for each class but the last.
     *
     * @return the value returned by {@code ModelSelection.AIC}.
     */
    public double AIC() {
        int weightCount = (k - 1) * (p + 1);
        return ModelSelection.AIC(L, weightCount);
    }
}
