/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.DiscriminantAnalysis;
import smile.classification.SoftClassifier;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Linear discriminant analysis classifier.
 * <p>
 * The model keeps one mean vector per class ({@code mu}) and a single
 * matrix shared by all classes, represented by {@code eigen} and
 * {@code scaling}. When built through {@link #fit(double[][], int[], double[], double)},
 * these are the eigenvalues and right eigenvectors ({@code Vr}) of the matrix
 * returned by {@code DiscriminantAnalysis.St}.
 * <p>
 * A sample {@code x} is scored against class {@code i} as
 * <pre>
 *     log(priori[i]) - 0.5 * sum_j (scaling' * (x - mu[i]))[j]^2 / eigen[j]
 * </pre>
 * The scores are passed to {@link MathEx#softmax(double[])}. The index it
 * returns is translated to a class label through {@code labels}.
 */
public class LDA
implements SoftClassifier<double[]> {
    private static final long serialVersionUID = 2L;
    /** Dimension of the input space, taken from the length of the first class mean. */
    private final int p;
    /** Number of classes, taken from the length of the prior array. */
    private final int k;
    /** {@code Math.log(priori[i])} for every class, computed once in the constructor. */
    private final double[] logppriori;
    /** Class prior probabilities, stored as passed to the constructor (not copied). */
    private final double[] priori;
    /** Per-class mean vectors; every {@code mu[i]} has length {@code p}. */
    private final double[][] mu;
    /** Divisors applied to each squared projected coordinate in {@code predict}; length {@code p}. */
    private final double[] eigen;
    /** Matrix whose transpose is applied (via {@code tv}) to each centered sample in {@code predict}. */
    private final Matrix scaling;
    /** Translates the winning class index into the class label returned by {@code predict}. */
    private final IntSet labels;

    /**
     * Creates a model whose labels are {@code IntSet.of(priori.length)}.
     * These are presumably the indices {@code 0..k-1}; confirm against IntSet.
     *
     * @param priori  the class prior probabilities; one entry per class.
     * @param mu      the class mean vectors; {@code mu[i]} belongs to class {@code i}.
     * @param eigen   the divisors of the projected coordinates; length must equal the dimension.
     * @param scaling the projection matrix applied as {@code scaling' * (x - mu[i])}.
     * @throws IllegalArgumentException if the array shapes are inconsistent.
     */
    public LDA(double[] priori, double[][] mu, double[] eigen, Matrix scaling) {
        this(priori, mu, eigen, scaling, IntSet.of(priori.length));
    }

    /**
     * Creates a model from precomputed parameters.
     *
     * @param priori  the class prior probabilities; one entry per class.
     * @param mu      the class mean vectors; {@code mu[i]} belongs to class {@code i}.
     * @param eigen   the divisors of the projected coordinates; length must equal the dimension.
     * @param scaling the projection matrix applied as {@code scaling' * (x - mu[i])}.
     * @param labels  maps class indices to the labels returned by {@code predict}.
     * @throws IllegalArgumentException if {@code priori} is empty, if the number of
     *         means differs from the number of priors, or if the means and
     *         {@code eigen} do not all share the same dimension.
     */
    public LDA(double[] priori, double[][] mu, double[] eigen, Matrix scaling, IntSet labels) {
        // Validate shapes up front. predict() indexes mu[i] for i < k, and
        // mu[i][j] / eigen[j] for j < p, so inconsistencies would otherwise
        // surface later as out-of-bounds errors or silently ignored classes.
        if (priori.length == 0) {
            throw new IllegalArgumentException("At least one class prior is required.");
        }
        if (mu.length != priori.length) {
            throw new IllegalArgumentException(String.format("Number of class means %d doesn't match number of priori %d", mu.length, priori.length));
        }
        this.k = priori.length;
        this.p = mu[0].length;
        for (int i = 1; i < this.k; ++i) {
            if (mu[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Class mean %d has dimension %d, expected: %d", i, mu[i].length, this.p));
            }
        }
        if (eigen.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid eigenvalue vector size: %d, expected: %d", eigen.length, this.p));
        }
        this.priori = priori;
        this.mu = mu;
        this.eigen = eigen;
        this.scaling = scaling;
        this.labels = labels;
        // Cache log priors so predict() does not recompute them per sample.
        this.logppriori = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            this.logppriori[i] = Math.log(priori[i]);
        }
    }

    /**
     * Fits a model on a data frame using default hyper-parameters
     * (an empty {@link Properties}).
     *
     * @param formula the model formula selecting predictors and response.
     * @param data    the training data.
     * @return the fitted model.
     */
    public static LDA fit(Formula formula, DataFrame data) {
        return LDA.fit(formula, data, new Properties());
    }

    /**
     * Fits a model on a data frame.
     * <p>
     * Predictors are converted with {@code CategoricalEncoder.DUMMY}, and the
     * response is converted to an int array.
     *
     * @param formula the model formula selecting predictors and response.
     * @param data    the training data.
     * @param prop    hyper-parameters; see {@link #fit(double[][], int[], Properties)}.
     * @return the fitted model.
     */
    public static LDA fit(Formula formula, DataFrame data, Properties prop) {
        double[][] x = formula.x(data).toArray(false, CategoricalEncoder.DUMMY);
        int[] y = formula.y(data).toIntArray();
        return LDA.fit(x, y, prop);
    }

    /**
     * Fits a model with no explicit priors ({@code null}) and tolerance {@code 1E-4}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted model.
     */
    public static LDA fit(double[][] x, int[] y) {
        return LDA.fit(x, y, null, 1.0E-4);
    }

    /**
     * Fits a model using hyper-parameters read from {@code prop}.
     * <ul>
     * <li>{@code smile.lda.priori}: class priors, parsed by
     *     {@code Strings.parseDoubleArray} (absent property yields whatever
     *     that helper returns for {@code null}; presumably {@code null}).</li>
     * <li>{@code smile.lda.tolerance}: tolerance, default {@code "1E-4"}.</li>
     * </ul>
     *
     * @param x    the training samples.
     * @param y    the training labels.
     * @param prop the hyper-parameters; must not be {@code null}.
     * @return the fitted model.
     * @throws NumberFormatException if the tolerance is not a valid number.
     */
    public static LDA fit(double[][] x, int[] y, Properties prop) {
        double[] priori = Strings.parseDoubleArray(prop.getProperty("smile.lda.priori"));
        double tol = Double.parseDouble(prop.getProperty("smile.lda.tolerance", "1E-4"));
        return LDA.fit(x, y, priori, tol);
    }

    /**
     * Fits a model.
     *
     * @param x      the training samples.
     * @param y      the training labels.
     * @param priori the class priors, or {@code null}. When {@code null}, they are
     *               presumably estimated by {@code DiscriminantAnalysis.fit}; confirm.
     * @param tol    the tolerance passed to {@code DiscriminantAnalysis}. Its square is
     *               the smallest eigenvalue accepted for the shared matrix.
     * @return the fitted model.
     * @throws IllegalArgumentException if any eigenvalue of the shared matrix is
     *         smaller than {@code tol * tol}.
     */
    public static LDA fit(double[][] x, int[] y, double[] priori, double tol) {
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, priori, tol);
        Matrix St = DiscriminantAnalysis.St(x, da.mean, da.k, tol);
        Matrix.EVD eigen = St.eigen(false, true, true).sort();
        // Eigenvalues below tol^2 mean predict() would divide by a (near-)zero value.
        double minEigenvalue = tol * tol;
        for (double s : eigen.wr) {
            if (s < minEigenvalue) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
        }
        return new LDA(da.priori, da.mu, eigen.wr, eigen.Vr, da.labels);
    }

    /**
     * Returns the class prior probabilities.
     * <p>
     * This is the array held by the model, not a copy. Modifying it does not
     * change predictions, because {@code predict} uses log priors cached at
     * construction time.
     *
     * @return the class prior probabilities.
     */
    public double[] priori() {
        return this.priori;
    }

    /**
     * Predicts the class label of {@code x}.
     *
     * @param x the sample; its length must equal the model dimension.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, new double[this.k]);
    }

    /**
     * Predicts the class label of {@code x} and fills {@code posteriori}.
     * <p>
     * Before the softmax step, {@code posteriori[i]} holds the class score
     * {@code log(priori[i]) - 0.5 * sum_j (scaling' * (x - mu[i]))[j]^2 / eigen[j]}.
     * The array is then passed to {@code MathEx.softmax}, which presumably
     * normalizes it in place into probabilities and returns the arg-max index;
     * confirm against MathEx.
     *
     * @param x          the sample; its length must equal the model dimension.
     * @param posteriori the output array; its length must equal the number of classes.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x} or {@code posteriori} has the wrong length.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        // softmax runs over the whole array: a longer array would mix stale
        // trailing values into the result, and a shorter one would overflow below.
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        double[] d = new double[this.p];
        double[] ux = new double[this.p];
        for (int i = 0; i < this.k; ++i) {
            double[] mean = this.mu[i];
            for (int j = 0; j < this.p; ++j) {
                d[j] = x[j] - mean[j];
            }
            // ux = scaling' * (x - mu[i])
            this.scaling.tv(d, ux);
            double f = 0.0;
            for (int j = 0; j < this.p; ++j) {
                f += ux[j] * ux[j] / this.eigen[j];
            }
            posteriori[i] = this.logppriori[i] - 0.5 * f;
        }
        return this.labels.valueOf(MathEx.softmax(posteriori));
    }
}

