/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.AbstractClassifier;
import smile.classification.DiscriminantAnalysis;
import smile.math.MathEx;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.Vector;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Linear discriminant analysis (LDA) classifier.
 *
 * <p>The model stores the class prior probabilities, one mean vector per class, and an
 * eigen-decomposition (eigenvalues {@link #eigen}, eigenvectors {@link #scaling}) of the
 * matrix produced by {@code DiscriminantAnalysis.St}. NOTE(review): that matrix is presumably
 * the shared (total/pooled) covariance estimate; confirm against {@code DiscriminantAnalysis}.
 *
 * <p>For an input {@code x}, the score of class {@code i} is
 * {@code log(priori[i]) - 0.5 * sum_j ux[j]^2 / eigen[j]}, where
 * {@code ux = scaling.tv(x - mu[i])}. The scores are then passed through
 * {@code MathEx.softmax}, and the winning index is mapped to a class label.
 */
public class LDA
extends AbstractClassifier<double[]> {
    private static final long serialVersionUID = 2L;
    /** The dimensionality of the input vectors. */
    private final int p;
    /** The number of classes. */
    private final int k;
    /** The natural logarithm of each class prior, precomputed in the constructor. */
    private final double[] logppriori;
    /** The class prior probabilities. */
    private final double[] priori;
    /** The class means; {@code mu[i]} is the mean vector of class {@code i}. */
    private final double[][] mu;
    /** The eigenvalues used to scale each transformed coordinate in {@link #predict}. */
    private final Vector eigen;
    /** The eigenvector matrix used to transform {@code x - mu[i]} in {@link #predict}. */
    private final DenseMatrix scaling;

    /**
     * Constructs an LDA model with the default class labels for {@code priori.length} classes.
     * NOTE(review): {@code IntSet.of(k)} presumably yields the labels {@code 0..k-1}; confirm.
     *
     * @param priori the class prior probabilities.
     * @param mu the class means; {@code mu[i]} belongs to class {@code i}.
     * @param eigen the eigenvalues of the decomposed matrix.
     * @param scaling the eigenvectors of the decomposed matrix.
     */
    public LDA(double[] priori, double[][] mu, Vector eigen, DenseMatrix scaling) {
        this(priori, mu, eigen, scaling, IntSet.of(priori.length));
    }

    /**
     * Constructs an LDA model.
     *
     * @param priori the class prior probabilities; one entry per class.
     * @param mu the class means; {@code mu[i]} belongs to class {@code i}, and all must have
     *     the same length, which becomes the input dimensionality.
     * @param eigen the eigenvalues of the decomposed matrix.
     * @param scaling the eigenvectors of the decomposed matrix.
     * @param labels the class labels.
     * @throws IllegalArgumentException if {@code priori} is empty, if {@code mu} does not have
     *     one mean per prior, or if the means differ in length.
     */
    public LDA(double[] priori, double[][] mu, Vector eigen, DenseMatrix scaling, IntSet labels) {
        super(labels);
        if (priori.length == 0) {
            throw new IllegalArgumentException("At least one class prior is required.");
        }
        if (mu.length != priori.length) {
            throw new IllegalArgumentException(String.format("Invalid number of class means: %d, expected: %d", mu.length, priori.length));
        }
        this.k = priori.length;
        this.p = mu[0].length;
        // predict() reads mean[j] for every j < p, so every class mean must have length p.
        for (int i = 1; i < this.k; ++i) {
            if (mu[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid size of mu[%d]: %d, expected: %d", i, mu[i].length, this.p));
            }
        }
        this.priori = priori;
        this.mu = mu;
        this.eigen = eigen;
        this.scaling = scaling;
        this.logppriori = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            this.logppriori[i] = Math.log(priori[i]);
        }
    }

    /**
     * Fits an LDA model with priors estimated by {@code DiscriminantAnalysis} and a
     * tolerance of {@code 1E-4}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted model.
     */
    public static LDA fit(double[][] x, int[] y) {
        return LDA.fit(x, y, null, 1.0E-4);
    }

    /**
     * Fits an LDA model configured from properties.
     *
     * <p>Recognized keys are {@code smile.lda.priori} (a double array parsed by
     * {@code Strings.parseDoubleArray}) and {@code smile.lda.tolerance} (default {@code 1E-4}).
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param params the hyper-parameters.
     * @return the fitted model.
     */
    public static LDA fit(double[][] x, int[] y, Properties params) {
        double[] priori = Strings.parseDoubleArray(params.getProperty("smile.lda.priori"));
        double tol = Double.parseDouble(params.getProperty("smile.lda.tolerance", "1E-4"));
        return LDA.fit(x, y, priori, tol);
    }

    /**
     * Fits an LDA model.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param priori the class priors, or {@code null}; passed through to
     *     {@code DiscriminantAnalysis.fit}.
     * @param tol the tolerance; any eigenvalue below {@code tol * tol} is rejected.
     * @return the fitted model.
     * @throws IllegalArgumentException if an eigenvalue of the decomposed matrix is below
     *     {@code tol * tol}.
     */
    public static LDA fit(double[][] x, int[] y, double[] priori, double tol) {
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, priori, tol);
        DenseMatrix St = DiscriminantAnalysis.St(x, da.mean, da.k, tol);
        EVD eigen = St.eigen().sort();
        Vector values = eigen.wr();
        // NOTE(review): eigenvalues are compared against tol^2, presumably because tol is a
        // standard-deviation-scale bound while eigenvalues are variance-scale; confirm.
        double threshold = tol * tol;
        for (int j = 0; j < values.size(); ++j) {
            if (values.get(j) < threshold) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
        }
        return new LDA(da.priori, da.mu, values, eigen.Vr(), da.labels);
    }

    /**
     * Returns the class prior probabilities. This is the internal array, not a copy, so
     * callers must not modify it. The cached log priors used by {@link #predict} are not
     * recomputed if it changes.
     *
     * @return the class prior probabilities.
     */
    public double[] priori() {
        return this.priori;
    }

    /**
     * Predicts the class label of {@code x}.
     *
     * @param x the input vector of length {@code p}.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, new double[this.k]);
    }

    /** Returns {@code true}: this classifier produces posterior probabilities. */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of {@code x} and writes the per-class posterior
     * probabilities into {@code posteriori}.
     *
     * @param x the input vector of length {@code p}.
     * @param posteriori the output array of length {@code k}.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x.length != p} or {@code posteriori.length != k}.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        // softmax runs over the whole array, so a longer buffer would let stale entries
        // distort the result (and even be selected as the winning index).
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        // Scratch vectors reused for every class: d = x - mu[i], ux = scaling.tv(d).
        Vector d = this.scaling.vector(this.p);
        Vector ux = this.scaling.vector(this.p);
        for (int i = 0; i < this.k; ++i) {
            double[] mean = this.mu[i];
            for (int j = 0; j < this.p; ++j) {
                d.set(j, x[j] - mean[j]);
            }
            // NOTE(review): tv presumably computes scaling^T * d into ux; confirm.
            this.scaling.tv(d, ux);
            double f = 0.0;
            for (int j = 0; j < this.p; ++j) {
                double uxj = ux.get(j);
                f += uxj * uxj / this.eigen.get(j);
            }
            posteriori[i] = this.logppriori[i] - 0.5 * f;
        }
        // NOTE(review): MathEx.softmax presumably normalizes the log-scores in place and
        // returns the index of the largest; the index is mapped to a class label.
        return this.classes.valueOf(MathEx.softmax(posteriori));
    }
}

