/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.DiscriminantAnalysis;
import smile.classification.QDA;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Regularized discriminant analysis (RDA).
 *
 * <p>Before the per-class eigendecomposition handed to {@link QDA}, each class covariance matrix
 * (from {@code DiscriminantAnalysis.cov}) is blended with the matrix returned by
 * {@code DiscriminantAnalysis.St}:
 * {@code alpha * cov_i + (1 - alpha) * St}. With {@code alpha = 1} each class keeps its own
 * covariance; with {@code alpha = 0} every class uses {@code St}.
 */
public class RDA extends QDA {
    private static final long serialVersionUID = 2L;

    /**
     * Creates an RDA model whose class labels are derived with {@code IntSet.of(priori.length)}.
     *
     * @param priori  the class priors, forwarded to {@link QDA}
     * @param mu      the class mean vectors, forwarded to {@link QDA}
     * @param eigen   the eigenvalues of each class's (regularized) covariance matrix
     * @param scaling the eigenvectors of each class's (regularized) covariance matrix
     */
    public RDA(double[] priori, double[][] mu, double[][] eigen, DenseMatrix[] scaling) {
        // NOTE(review): presumably IntSet.of(k) builds the labels 0..k-1 — confirm.
        super(priori, mu, eigen, scaling, IntSet.of(priori.length));
    }

    /**
     * Creates an RDA model with explicit class labels.
     *
     * @param priori  the class priors, forwarded to {@link QDA}
     * @param mu      the class mean vectors, forwarded to {@link QDA}
     * @param eigen   the eigenvalues of each class's (regularized) covariance matrix
     * @param scaling the eigenvectors of each class's (regularized) covariance matrix
     * @param labels  the class labels, forwarded to {@link QDA}
     */
    public RDA(double[] priori, double[][] mu, double[][] eigen, DenseMatrix[] scaling, IntSet labels) {
        super(priori, mu, eigen, scaling, labels);
    }

    /**
     * Fits an RDA model on a data frame using the default hyper-parameters
     * (equivalent to passing an empty {@link Properties}).
     *
     * @param formula the model formula selecting the response and predictors
     * @param data    the training data
     * @return the fitted model
     */
    public static RDA fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits an RDA model on a data frame, reading hyper-parameters from {@code prop}.
     *
     * @param formula the model formula; {@code formula.x(data)} supplies the predictors and
     *                {@code formula.y(data)} the integer class labels
     * @param data    the training data
     * @param prop    hyper-parameters; see {@link #fit(double[][], int[], Properties)}
     * @return the fitted model
     */
    public static RDA fit(Formula formula, DataFrame data, Properties prop) {
        double[][] x = formula.x(data).toArray();
        int[] y = formula.y(data).toIntArray();
        return fit(x, y, prop);
    }

    /**
     * Fits an RDA model, reading hyper-parameters from {@code prop}.
     *
     * <p>Recognized keys:
     * <ul>
     *   <li>{@code smile.rda.alpha} — regularization factor, default {@code 0.9}</li>
     *   <li>{@code smile.rda.priori} — class priors, parsed by {@code Strings.parseDoubleArray};
     *       no default is supplied to {@code getProperty}</li>
     *   <li>{@code smile.rda.tolerance} — singularity tolerance, default {@code 1E-4}</li>
     * </ul>
     *
     * @param x    the training samples, one row per sample
     * @param y    the class label of each sample
     * @param prop hyper-parameters
     * @return the fitted model
     * @throws NumberFormatException if {@code smile.rda.alpha} or {@code smile.rda.tolerance}
     *                               is not a valid double
     */
    public static RDA fit(double[][] x, int[] y, Properties prop) {
        double alpha = Double.parseDouble(prop.getProperty("smile.rda.alpha", "0.9"));
        double[] priori = Strings.parseDoubleArray(prop.getProperty("smile.rda.priori"));
        double tol = Double.parseDouble(prop.getProperty("smile.rda.tolerance", "1E-4"));
        return fit(x, y, alpha, priori, tol);
    }

    /**
     * Fits an RDA model with the given regularization factor, no explicit priors
     * ({@code null}) and tolerance {@code 1E-4}.
     *
     * @param x     the training samples, one row per sample
     * @param y     the class label of each sample
     * @param alpha the regularization factor in {@code [0, 1]}
     * @return the fitted model
     */
    public static RDA fit(double[][] x, int[] y, double alpha) {
        return fit(x, y, alpha, null, 1.0E-4);
    }

    /**
     * Fits an RDA model.
     *
     * <p>Each class covariance matrix is replaced by
     * {@code alpha * cov_i + (1 - alpha) * St} and then eigendecomposed; the eigenvalues and
     * eigenvectors become the model's {@code eigen} and {@code scaling} arrays.
     *
     * @param x      the training samples, one row per sample
     * @param y      the class label of each sample
     * @param alpha  the regularization factor in {@code [0, 1]}
     * @param priori the class priors, or {@code null}; forwarded to
     *               {@code DiscriminantAnalysis.fit}
     * @param tol    the tolerance forwarded to {@code DiscriminantAnalysis.fit} and
     *               {@code DiscriminantAnalysis.St}; its square is the threshold below which a
     *               regularized variance or eigenvalue is treated as singular
     * @return the fitted model
     * @throws IllegalArgumentException if {@code alpha} is outside {@code [0, 1]} or NaN, or if
     *         a regularized class covariance matrix has a diagonal entry or an eigenvalue below
     *         {@code tol * tol}. {@code DiscriminantAnalysis.fit} may impose further checks
     *         (not visible here).
     */
    public static RDA fit(double[][] x, int[] y, double alpha, double[] priori, double tol) {
        // Negated range test so that NaN (for which every comparison is false) is rejected too;
        // "alpha < 0 || alpha > 1" would let NaN through and silently produce a NaN model.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Invalid regularization factor: " + alpha);
        }

        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, priori, tol);
        int k = da.k;
        int p = da.mean.length;
        DenseMatrix St = DiscriminantAnalysis.St(x, da.mean, k, tol);
        DenseMatrix[] cov = DiscriminantAnalysis.cov(x, y, da.mu, da.ni);

        double[][] eigen = new double[k][];
        DenseMatrix[] scaling = new DenseMatrix[k];

        // Variances and eigenvalues are compared against the squared tolerance.
        double tol2 = tol * tol;
        for (int i = 0; i < k; i++) {
            DenseMatrix v = cov[i];

            // Blend in place. Only the lower triangles of v and St are read. Entry (r, s) with
            // s <= r has not been overwritten yet when it is read, and the upper triangle is
            // mirrored from it so that v stays symmetric.
            for (int r = 0; r < p; r++) {
                for (int s = 0; s <= r; s++) {
                    v.set(r, s, alpha * v.get(r, s) + (1.0 - alpha) * St.get(r, s));
                    v.set(s, r, v.get(r, s));
                }
            }

            for (int j = 0; j < p; j++) {
                if (v.get(j, j) < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (column %d) is close to singular.", i, j));
                }
            }

            EVD evd = v.eigen();
            double[] lambda = evd.getEigenValues();
            for (double ev : lambda) {
                if (ev < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", i));
                }
            }

            eigen[i] = lambda;
            scaling[i] = evd.getEigenVectors();
        }

        return new RDA(da.priori, da.mu, eigen, scaling, da.labels);
    }
}

