package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.math.BFGS;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;
import smile.util.IntSet;

/* loaded from: input_file:smile/classification/Maxent.class */
public class Maxent implements SoftClassifier<int[]>, OnlineClassifier<int[]> {
    private static final long serialVersionUID = 2;

    /**
     * History-size argument passed to {@code BFGS.minimize(func, m, x)}
     * (presumably the number of L-BFGS corrections kept; confirm against BFGS docs).
     */
    private static final int BFGS_HISTORY = 5;

    /** Number of distinct feature indices; valid feature indices are {@code [0, p)}. */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Value returned by {@link #loglikelihood()}. */
    private final double L;
    /** Binary-case weights: {@code p} feature weights followed by the intercept. {@code null} when {@code k > 2}. */
    private final double[] w;
    /**
     * Multi-class weights: {@code k - 1} rows of {@code p} feature weights followed by an intercept.
     * Class {@code k - 1} has an implicit all-zero score. {@code null} when {@code k == 2}.
     */
    private final double[][] W;
    /** Learning rate used by {@link #update(int[], int)}. */
    private double eta;
    /** Converts between caller-facing class labels and internal indices {@code 0..k-1}. */
    private final IntSet labels;

    /**
     * Negative log-likelihood (plus optional L2 penalty) of a binary model.
     * Labels are expected to be 0/1 and the weight vector has length {@code p + 1}, with the intercept last.
     * The intercept is excluded from the penalty.
     */
    public static class BinaryObjectiveFunction implements DifferentiableMultivariateFunction {
        int[][] x;
        int[] y;
        int p;
        double lambda;
        /** Number of samples handled by one parallel task in {@link #g}. */
        int partitionSize;
        int partitions;
        /** Per-partition gradient scratch buffers, so parallel tasks never share writes. */
        double[][] gradients;

        BinaryObjectiveFunction(int[][] x, int[] y, int p, double lambda) {
            this.x = x;
            this.y = y;
            this.p = p;
            this.lambda = lambda;
            this.partitionSize = Maxent.readPartitionSize();
            this.partitions = x.length / partitionSize + (x.length % partitionSize == 0 ? 0 : 1);
            this.gradients = new double[partitions][p + 1];
        }

        @Override
        public double f(double[] w) {
            double f = IntStream.range(0, x.length).parallel().mapToDouble(i -> {
                double wx = Maxent.dot(x[i], w);
                return MathEx.log1pe(wx) - y[i] * wx;
            }).sum();

            if (lambda > 0.0) {
                double wnorm = 0.0;
                for (int j = 0; j < p; j++) {
                    wnorm += w[j] * w[j];
                }
                f += 0.5 * lambda * wnorm;
            }
            return f;
        }

        @Override
        public double g(double[] w, double[] grad) {
            double f = IntStream.range(0, partitions).parallel().mapToDouble(r -> {
                double[] gradient = gradients[r];
                Arrays.fill(gradient, 0.0);

                int begin = r * partitionSize;
                // Computed in long so a huge partition size cannot overflow.
                int end = (int) Math.min(x.length, (long) (r + 1) * partitionSize);

                return IntStream.range(begin, end).sequential().mapToDouble(i -> {
                    double wx = Maxent.dot(x[i], w);
                    double err = y[i] - MathEx.logistic(wx);
                    for (int j : x[i]) {
                        gradient[j] -= err;
                    }
                    gradient[p] -= err;
                    return MathEx.log1pe(wx) - y[i] * wx;
                }).sum();
            }).sum();

            // Reduce the per-partition gradients into the output buffer.
            Arrays.fill(grad, 0.0);
            for (double[] gradient : gradients) {
                for (int j = 0; j < grad.length; j++) {
                    grad[j] += gradient[j];
                }
            }

            if (lambda > 0.0) {
                double wnorm = 0.0;
                for (int j = 0; j < p; j++) {
                    wnorm += w[j] * w[j];
                    grad[j] += lambda * w[j];
                }
                f += 0.5 * lambda * wnorm;
            }
            return f;
        }
    }

    /**
     * Negative log-likelihood (plus optional L2 penalty) of a softmax model with {@code k} classes.
     * The weights are a flat vector of {@code k - 1} blocks of length {@code p + 1}, each with the intercept last.
     * The last class has a fixed score of 0. Intercepts are excluded from the penalty.
     */
    public static class MultiClassObjectiveFunction implements DifferentiableMultivariateFunction {
        int[][] x;
        int[] y;
        int k;
        int p;
        double lambda;
        /** Number of samples handled by one parallel task. */
        int partitionSize;
        int partitions;
        /** Per-partition gradient scratch buffers. */
        double[][] gradients;
        /** Per-partition posterior scratch buffers of length {@code k}. */
        double[][] posterioris;

        MultiClassObjectiveFunction(int[][] x, int[] y, int k, int p, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.p = p;
            this.lambda = lambda;
            this.partitionSize = Maxent.readPartitionSize();
            this.partitions = x.length / partitionSize + (x.length % partitionSize == 0 ? 0 : 1);
            this.gradients = new double[partitions][(k - 1) * (p + 1)];
            this.posterioris = new double[partitions][k];
        }

        @Override
        public double f(double[] w) {
            double f = IntStream.range(0, partitions).parallel().mapToDouble(r -> {
                double[] posteriori = posterioris[r];

                int begin = r * partitionSize;
                int end = (int) Math.min(x.length, (long) (r + 1) * partitionSize);

                return IntStream.range(begin, end).sequential().mapToDouble(i -> {
                    posteriori[k - 1] = 0.0;
                    for (int j = 0; j < k - 1; j++) {
                        posteriori[j] = Maxent.dot(x[i], w, j, p);
                    }
                    MathEx.softmax(posteriori);
                    return -MathEx.log(posteriori[y[i]]);
                }).sum();
            }).sum();

            if (lambda > 0.0) {
                double wnorm = 0.0;
                for (int j = 0; j < k - 1; j++) {
                    int offset = j * (p + 1);
                    for (int l = 0; l < p; l++) {
                        double wjl = w[offset + l];
                        wnorm += wjl * wjl;
                    }
                }
                f += 0.5 * lambda * wnorm;
            }
            return f;
        }

        @Override
        public double g(double[] w, double[] grad) {
            double f = IntStream.range(0, partitions).parallel().mapToDouble(r -> {
                double[] posteriori = posterioris[r];
                double[] gradient = gradients[r];
                Arrays.fill(gradient, 0.0);

                int begin = r * partitionSize;
                int end = (int) Math.min(x.length, (long) (r + 1) * partitionSize);

                return IntStream.range(begin, end).sequential().mapToDouble(i -> {
                    posteriori[k - 1] = 0.0;
                    for (int j = 0; j < k - 1; j++) {
                        posteriori[j] = Maxent.dot(x[i], w, j, p);
                    }
                    MathEx.softmax(posteriori);

                    for (int j = 0; j < k - 1; j++) {
                        double err = (y[i] == j ? 1.0 : 0.0) - posteriori[j];
                        int offset = j * (p + 1);
                        for (int feature : x[i]) {
                            gradient[offset + feature] -= err;
                        }
                        gradient[offset + p] -= err;
                    }
                    return -MathEx.log(posteriori[y[i]]);
                }).sum();
            }).sum();

            // Reduce the per-partition gradients into the output buffer.
            Arrays.fill(grad, 0.0);
            for (double[] gradient : gradients) {
                for (int j = 0; j < grad.length; j++) {
                    grad[j] += gradient[j];
                }
            }

            if (lambda > 0.0) {
                double wnorm = 0.0;
                for (int j = 0; j < k - 1; j++) {
                    int offset = j * (p + 1);
                    for (int l = 0; l < p; l++) {
                        double wjl = w[offset + l];
                        wnorm += wjl * wjl;
                        grad[offset + l] += lambda * wjl;
                    }
                }
                f += 0.5 * lambda * wnorm;
            }
            return f;
        }
    }

    /**
     * Creates a binary model whose labels are the internal indices 0 and 1.
     *
     * @param L the log-likelihood of the model.
     * @param w the {@code p} feature weights followed by the intercept.
     */
    public Maxent(double L, double[] w) {
        this(L, w, IntSet.of(2));
    }

    /**
     * Creates a binary model.
     *
     * @param L the log-likelihood of the model.
     * @param w the {@code p} feature weights followed by the intercept.
     * @param labels the class label encoder.
     */
    public Maxent(double L, double[] w, IntSet labels) {
        this.eta = 0.1;
        this.p = w.length - 1;
        this.k = 2;
        this.L = L;
        this.w = w;
        this.W = null;
        this.labels = labels;
    }

    /**
     * Creates a multi-class model whose labels are the internal indices {@code 0..W.length}.
     *
     * @param L the log-likelihood of the model.
     * @param W {@code k - 1} rows of {@code p} feature weights followed by the intercept.
     */
    public Maxent(double L, double[][] W) {
        this(L, W, IntSet.of(W.length + 1));
    }

    /**
     * Creates a multi-class model.
     *
     * @param L the log-likelihood of the model.
     * @param W {@code k - 1} rows of {@code p} feature weights followed by the intercept.
     * @param labels the class label encoder.
     */
    public Maxent(double L, double[][] W, IntSet labels) {
        this.eta = 0.1;
        this.p = W[0].length - 1;
        this.k = W.length + 1;
        this.L = L;
        this.w = null;
        this.W = W;
        this.labels = labels;
    }

    /**
     * Fits a model with the default hyper-parameters.
     *
     * @param p the number of distinct feature indices.
     * @param x each sample given as the indices of its active (value 1) features.
     * @param y the class labels.
     * @return the fitted model.
     */
    public static Maxent fit(int p, int[][] x, int[] y) {
        return fit(p, x, y, new Properties());
    }

    /**
     * Fits a model with hyper-parameters read from {@code properties}.
     * The keys are {@code smile.maxent.lambda} (default 0.1), {@code smile.maxent.tolerance} (default 1E-5),
     * and {@code smile.maxent.max.iterations} (default 500).
     *
     * @param p the number of distinct feature indices.
     * @param x each sample given as the indices of its active (value 1) features.
     * @param y the class labels.
     * @param properties the hyper-parameters.
     * @return the fitted model.
     */
    public static Maxent fit(int p, int[][] x, int[] y, Properties properties) {
        double lambda = Double.parseDouble(properties.getProperty("smile.maxent.lambda", "0.1"));
        double tol = Double.parseDouble(properties.getProperty("smile.maxent.tolerance", "1E-5"));
        int maxIter = Integer.parseInt(properties.getProperty("smile.maxent.max.iterations", "500"));
        return fit(p, x, y, lambda, tol, maxIter);
    }

    /**
     * Fits a model by minimizing the L2-regularized negative log-likelihood with BFGS.
     * After fitting, the online learning rate is set to {@code 0.1 / x.length}.
     *
     * @param p the number of distinct feature indices; every index in {@code x} must be in {@code [0, p)}.
     * @param x each sample given as the indices of its active (value 1) features.
     * @param y the class labels.
     * @param lambda the L2 regularization factor (intercepts are not penalized).
     * @param tol the BFGS convergence tolerance.
     * @param maxIter the maximum number of BFGS iterations.
     * @return the fitted model.
     * @throws IllegalArgumentException if the sizes of x and y differ, any hyper-parameter is out of range,
     *     or a feature index lies outside {@code [0, p)}.
     */
    public static Maxent fit(int p, int[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (p < 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        // An index equal to p would silently alias the intercept.
        checkFeatureIndices(x, p);

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int[] yi = codec.y;
        BFGS bfgs = new BFGS(tol, maxIter);

        Maxent model;
        if (k == 2) {
            BinaryObjectiveFunction func = new BinaryObjectiveFunction(x, yi, p, lambda);
            double[] w = new double[p + 1];
            double logLikelihood = -bfgs.minimize(func, BFGS_HISTORY, w);
            model = new Maxent(logLikelihood, w, codec.labels);
        } else {
            MultiClassObjectiveFunction func = new MultiClassObjectiveFunction(x, yi, k, p, lambda);
            double[] w = new double[(k - 1) * (p + 1)];
            double logLikelihood = -bfgs.minimize(func, BFGS_HISTORY, w);

            // Unpack the flat optimizer vector into one row per non-reference class.
            double[][] W = new double[k - 1][p + 1];
            for (int c = 0; c < k - 1; c++) {
                System.arraycopy(w, c * (p + 1), W[c], 0, p + 1);
            }
            model = new Maxent(logLikelihood, W, codec.labels);
        }

        model.setLearningRate(0.1 / x.length);
        return model;
    }

    /** Rejects any feature index outside {@code [0, p)}. */
    private static void checkFeatureIndices(int[][] x, int p) {
        for (int i = 0; i < x.length; i++) {
            for (int j : x[i]) {
                if (j < 0 || j >= p) {
                    throw new IllegalArgumentException(String.format("Invalid feature index %d in sample %d, expected [0, %d)", j, i, p));
                }
            }
        }
    }

    /** Reads the {@code smile.data.partition.size} system property (default 1000) and requires it to be positive. */
    private static int readPartitionSize() {
        int size = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid smile.data.partition.size: " + size);
        }
        return size;
    }

    /**
     * Returns the number of distinct feature indices of the model.
     *
     * @return the dimension of the feature space.
     */
    public int dimension() {
        return p;
    }

    /**
     * Returns the linear score of a sparse binary sample.
     * The score is the sum of the weights of its active features plus the intercept stored in the last slot of {@code w}.
     */
    public static double dot(int[] x, double[] w) {
        double sum = w[w.length - 1];
        for (int j : x) {
            sum += w[j];
        }
        return sum;
    }

    /**
     * Same as {@link #dot(int[], double[])}, but reads class {@code j}'s block of length {@code p + 1}
     * from a flat weight vector.
     */
    public static double dot(int[] x, double[] w, int j, int p) {
        int offset = j * (p + 1);
        double sum = w[offset + p];
        for (int l : x) {
            sum += w[offset + l];
        }
        return sum;
    }

    /**
     * Performs one stochastic gradient step on a single sample, scaled by the current learning rate.
     *
     * @param x the indices of the sample's active features.
     * @param y the class label.
     */
    @Override
    public void update(int[] x, int y) {
        int label = labels.indexOf(y);

        if (k == 2) {
            double err = label - MathEx.logistic(dot(x, w));
            w[p] += eta * err;
            for (int j : x) {
                w[j] += eta * err;
            }
            return;
        }

        // The last class keeps a zero score (new arrays are zero-filled).
        double[] prob = new double[k];
        for (int c = 0; c < k - 1; c++) {
            prob[c] = dot(x, W[c]);
        }
        MathEx.softmax(prob);

        for (int c = 0; c < k - 1; c++) {
            double[] wc = W[c];
            double err = (label == c ? 1.0 : 0.0) - prob[c];
            wc[p] += eta * err;
            for (int j : x) {
                wc[j] += eta * err;
            }
        }
    }

    /**
     * Sets the learning rate used by {@link #update(int[], int)}.
     *
     * @param rate the learning rate.
     * @throws IllegalArgumentException if {@code rate <= 0}.
     */
    public void setLearningRate(double rate) {
        if (rate <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + rate);
        }
        this.eta = rate;
    }

    /**
     * Returns the learning rate used by {@link #update(int[], int)}.
     *
     * @return the learning rate.
     */
    public double getLearningRate() {
        return eta;
    }

    /**
     * Returns the log-likelihood passed at construction.
     * For models produced by {@code fit}, this is the negative of the minimized objective,
     * which includes the L2 penalty when lambda &gt; 0.
     *
     * @return the log-likelihood.
     */
    public double loglikelihood() {
        return L;
    }

    @Override
    public int predict(int[] x) {
        if (k == 2) {
            double prob = 1.0 / (1.0 + Math.exp(-dot(x, w)));
            return labels.valueOf(prob < 0.5 ? 0 : 1);
        }
        return predict(x, new double[k]);
    }

    /**
     * Predicts the class label and writes the class posteriors into {@code posteriori}.
     *
     * @param x the indices of the sample's active features.
     * @param posteriori the output array of length {@code k}. If {@code null}, a scratch array is used
     *     and the posteriors are discarded.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code posteriori} is non-null and its length is not {@code k}.
     */
    @Override
    public int predict(int[] x, double[] posteriori) {
        if (posteriori == null) {
            // Previously this fell through to an NPE; use a scratch buffer instead.
            posteriori = new double[k];
        } else if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        if (k == 2) {
            double prob = 1.0 / (1.0 + Math.exp(-dot(x, w)));
            posteriori[0] = 1.0 - prob;
            posteriori[1] = prob;
            return labels.valueOf(prob < 0.5 ? 0 : 1);
        }

        posteriori[k - 1] = 0.0;
        for (int c = 0; c < k - 1; c++) {
            posteriori[c] = dot(x, W[c]);
        }
        MathEx.softmax(posteriori);
        return labels.valueOf(MathEx.whichMax(posteriori));
    }
}
