/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.AbstractClassifier;
import smile.classification.DiscriminantAnalysis;
import smile.math.MathEx;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.Vector;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Quadratic discriminant analysis classifier over dense {@code double[]}
 * samples.
 *
 * <p>Every class {@code i} is described by a prior, a mean vector, and the
 * eigen decomposition (eigenvalues plus eigenvector matrix) of its own
 * covariance matrix. A sample {@code x} is scored per class as
 * <pre>
 *   log(priori[i]) - 0.5 * sum_j log(eigen[i][j]) - 0.5 * sum_j (ux[j]^2 / eigen[i][j])
 * </pre>
 * where {@code ux} is the result of {@code scaling[i].tv(x - mu[i], ux)}.
 */
public class QDA extends AbstractClassifier<double[]> {
    private static final long serialVersionUID = 2L;

    /** Dimensionality of the input space, taken from {@code mu[0].length}. */
    private final int p;
    /** Number of classes, taken from {@code priori.length}. */
    private final int k;
    /** Per class constant term: {@code log(priori[i]) - 0.5 * sum_j log(eigen[i][j])}. */
    private final double[] logppriori;
    /** A priori probability of each class. */
    private final double[] priori;
    /** Mean vector of each class. */
    private final double[][] mu;
    /** Eigenvalues of each class covariance matrix. */
    private final Vector[] eigen;
    /** Eigenvector matrix of each class covariance matrix. */
    private final DenseMatrix[] scaling;

    /**
     * Creates a model whose class labels are {@code IntSet.of(priori.length)}.
     *
     * @param priori  the a priori probability of each class.
     * @param mu      the mean vector of each class.
     * @param eigen   the eigenvalues of each class covariance matrix.
     * @param scaling the eigenvector matrix of each class covariance matrix.
     */
    public QDA(double[] priori, double[][] mu, Vector[] eigen, DenseMatrix[] scaling) {
        this(priori, mu, eigen, scaling, IntSet.of(priori.length));
    }

    /**
     * Creates a model with explicit class labels and precomputes the part of
     * the discriminant score that does not depend on the sample.
     *
     * @param priori  the a priori probability of each class.
     * @param mu      the mean vector of each class.
     * @param eigen   the eigenvalues of each class covariance matrix.
     * @param scaling the eigenvector matrix of each class covariance matrix.
     * @param labels  the class label encoder handed to the superclass.
     */
    public QDA(double[] priori, double[][] mu, Vector[] eigen, DenseMatrix[] scaling, IntSet labels) {
        super(labels);
        this.k = priori.length;
        this.p = mu[0].length;
        this.priori = priori;
        this.mu = mu;
        this.eigen = eigen;
        this.scaling = scaling;

        this.logppriori = new double[k];
        for (int i = 0; i < k; i++) {
            // Sum of log eigenvalues, i.e. the log of their product.
            double logev = 0.0;
            for (int j = 0; j < p; j++) {
                logev += Math.log(eigen[i].get(j));
            }
            logppriori[i] = Math.log(priori[i]) - 0.5 * logev;
        }
    }

    /**
     * Fits a model with no user supplied priors ({@code null}) and a
     * tolerance of {@code 1E-4}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted model.
     */
    public static QDA fit(double[][] x, int[] y) {
        return fit(x, y, null, 1.0E-4);
    }

    /**
     * Fits a model with hyperparameters read from {@code params}.
     *
     * <p>{@code smile.qda.priori} is parsed with
     * {@code Strings.parseDoubleArray}; when the property is absent the
     * {@code null} returned by {@link Properties#getProperty(String)} is
     * handed to the parser unchanged. {@code smile.qda.tolerance} defaults
     * to {@code 1E-4}.
     *
     * @param x      the training samples.
     * @param y      the training labels.
     * @param params the hyperparameters.
     * @return the fitted model.
     */
    public static QDA fit(double[][] x, int[] y, Properties params) {
        double[] priori = Strings.parseDoubleArray(params.getProperty("smile.qda.priori"));
        double tol = Double.parseDouble(params.getProperty("smile.qda.tolerance", "1E-4"));
        return fit(x, y, priori, tol);
    }

    /**
     * Fits a model from training data.
     *
     * <p>Priors, means and labels come from
     * {@code DiscriminantAnalysis.fit}; the per class covariance matrices
     * come from {@code DiscriminantAnalysis.cov}. Each covariance matrix is
     * eigen decomposed, and the sorted decomposition is kept.
     *
     * @param x      the training samples.
     * @param y      the training labels.
     * @param priori the a priori probability of each class, forwarded to
     *               {@code DiscriminantAnalysis.fit}.
     * @param tol    the tolerance, forwarded as is to
     *               {@code DiscriminantAnalysis.fit}; its square is the
     *               threshold below which a diagonal covariance entry or an
     *               eigenvalue is treated as singular.
     * @return the fitted model.
     * @throws IllegalArgumentException if a diagonal entry or an eigenvalue
     *         of some class covariance matrix is below {@code tol * tol}.
     */
    public static QDA fit(double[][] x, int[] y, double[] priori, double tol) {
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, priori, tol);
        DenseMatrix[] cov = DiscriminantAnalysis.cov(x, y, da.mu, da.ni);

        int k = cov.length;
        int p = cov[0].nrow();
        Vector[] eigen = new Vector[k];
        DenseMatrix[] scaling = new DenseMatrix[k];

        // Variances and eigenvalues are compared against the squared tolerance.
        double threshold = tol * tol;

        for (int i = 0; i < k; i++) {
            // Cheap test on the diagonal before paying for the decomposition.
            for (int j = 0; j < p; j++) {
                if (cov[i].get(j, j) < threshold) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (column %d) is close to singular.", i, j));
                }
            }

            EVD eig = cov[i].eigen().sort();
            Vector values = eig.wr();
            int n = values.size();
            for (int j = 0; j < n; j++) {
                if (values.get(j) < threshold) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", i));
                }
            }

            eigen[i] = values;
            scaling[i] = eig.Vr();
        }

        return new QDA(da.priori, da.mu, eigen, scaling, da.labels);
    }

    /**
     * Returns the a priori probabilities of the classes.
     *
     * @return the internal array itself (not a copy).
     */
    public double[] priori() {
        return priori;
    }

    /**
     * Predicts the class label of a sample, discarding the posteriori
     * probabilities.
     *
     * @param x the sample.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        return predict(x, new double[k]);
    }

    /**
     * Reports that this classifier supports soft classification.
     *
     * @return always {@code true}.
     */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of a sample and fills in the per class scores.
     *
     * @param x          the sample; its length must equal the model dimension.
     * @param posteriori the output buffer; its length must equal the number
     *                   of classes. It is overwritten with the discriminant
     *                   scores and then passed to {@code MathEx.softmax}.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x} or {@code posteriori}
     *         has the wrong length.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }

        // A short buffer would overflow half way through the loop below, and
        // a long one would leak stale trailing entries into the softmax.
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        // Work vectors, created through the matrix so they match its type.
        Vector d = scaling[0].vector(p);
        Vector ux = scaling[0].vector(p);

        for (int i = 0; i < k; i++) {
            double[] mean = mu[i];
            for (int j = 0; j < p; j++) {
                d.set(j, x[j] - mean[j]);
            }

            // NOTE(review): tv presumably stores transpose(scaling[i]) * d in ux; confirm against DenseMatrix.
            scaling[i].tv(d, ux);

            // Quadratic form of the centered sample in the eigen basis.
            double f = 0.0;
            Vector ev = eigen[i];
            for (int j = 0; j < p; j++) {
                double v = ux.get(j);
                f += v * v / ev.get(j);
            }

            posteriori[i] = logppriori[i] - 0.5 * f;
        }

        // NOTE(review): MathEx.softmax is expected to normalize the array in place
        // and return the index of its largest entry; confirm against MathEx.
        return classes.valueOf(MathEx.softmax(posteriori));
    }
}

