/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import smile.classification.ClassLabels;
import smile.linalg.UPLO;
import smile.math.MathEx;
import smile.tensor.DenseMatrix;
import smile.tensor.ScalarType;
import smile.util.IntSet;

/**
 * Common state and fitting helpers shared by discriminant analysis classifiers.
 *
 * <p>An instance holds the class encoding produced by {@link ClassLabels}
 * together with the class prior probabilities, the overall column means and
 * the per-class mean vectors. The static helpers compute the scatter and
 * per-class covariance matrices from the training data.
 */
class DiscriminantAnalysis {
    /** The number of classes. */
    final int k;
    /** The encoded class index of each training sample (used to index arrays of length {@code k}). */
    final int[] y;
    /** The class labels as reported by {@link ClassLabels} ({@code codec.classes}). */
    final IntSet labels;
    /** The number of training samples in each class. */
    final int[] ni;
    /** The prior probability of each class. */
    final double[] priori;
    /** The mean of each column over all training samples. */
    final double[] mean;
    /** The mean vector of each class; {@code mu[i]} belongs to class index {@code i}. */
    final double[][] mu;

    /**
     * Creates an instance from precomputed statistics. Arrays are stored by
     * reference, not copied.
     *
     * @param codec the class encoding; supplies {@code k}, {@code ni}, {@code y} and the labels.
     * @param priori the prior probability of each class.
     * @param mean the column means over all samples.
     * @param mu the mean vector of each class.
     */
    public DiscriminantAnalysis(ClassLabels codec, double[] priori, double[] mean, double[][] mu) {
        this.k = codec.k;
        this.ni = codec.ni;
        this.y = codec.y;
        this.labels = codec.classes;
        this.priori = priori;
        this.mean = mean;
        this.mu = mu;
    }

    /**
     * Encodes the class labels and computes the class priors, the overall
     * column means and the per-class mean vectors.
     *
     * @param x the training samples, one row per sample; every row must have
     *          the same length as {@code x[0]}.
     * @param y the class label of each sample.
     * @param priori the prior probability of each class, or {@code null} to use
     *               the priors provided by {@link ClassLabels} (presumably the
     *               empirical class frequencies). When given, every value must
     *               lie strictly inside (0, 1) and the values must sum to one.
     *               The array is copied.
     * @param tol a strictly positive tolerance. It is validated here but not
     *            otherwise used by this method.
     * @return the fitted statistics.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y}
     *         differ, {@code tol} is not positive (including NaN), the sample
     *         size does not exceed the number of classes, the priors are
     *         invalid (including NaN), or the rows of {@code x} are ragged.
     */
    public static DiscriminantAnalysis fit(double[][] x, int[] y, double[] priori, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        // Negated comparison so that NaN is rejected too (every comparison with NaN is false).
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }

        int n = x.length;
        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int[] code = codec.y;
        int[] ni = codec.ni;
        if (n <= k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, k));
        }

        if (priori == null) {
            priori = codec.priori;
        } else {
            if (priori.length != k) {
                throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
            }
            double sum = 0.0;
            for (double pr : priori) {
                // Negated range test so that NaN priors are rejected as well.
                if (!(pr > 0.0 && pr < 1.0)) {
                    throw new IllegalArgumentException("Invalid priori probability: " + pr);
                }
                sum += pr;
            }
            if (Math.abs(sum - 1.0) > 1.0E-10) {
                throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
            }
            // Defensive copy: later changes to the caller's array must not alter the model.
            priori = priori.clone();
        }

        int p = x[0].length;
        for (int i = 1; i < n; i++) {
            if (x[i].length != p) {
                throw new IllegalArgumentException(String.format("Row %d has %d columns, expected %d", i, x[i].length, p));
            }
        }

        double[] mean = MathEx.colMeans(x);

        // Sum the samples of each class, then divide by the class sizes.
        double[][] mu = new double[k][p];
        for (int i = 0; i < n; i++) {
            double[] xi = x[i];
            double[] mui = mu[code[i]];
            for (int j = 0; j < p; j++) {
                mui[j] += xi[j];
            }
        }
        for (int i = 0; i < k; i++) {
            double[] mui = mu[i];
            double m = ni[i];
            for (int j = 0; j < p; j++) {
                mui[j] /= m;
            }
        }
        return new DiscriminantAnalysis(codec, priori, mean, mu);
    }

    /**
     * Computes the scatter of all samples around {@code mean}, divided by
     * {@code n - k}, as a symmetric {@code p x p} matrix.
     *
     * <p>Only the lower triangle is accumulated; it is then mirrored into the
     * upper triangle. The matrix is flagged with {@link UPLO#LOWER}.
     *
     * @param x the samples, one row per sample.
     * @param mean the column means to center on.
     * @param k the number of classes; the divisor is {@code x.length - k}.
     * @param tol a column whose diagonal entry falls below {@code tol * tol}
     *            is treated as near-singular. Must not be NaN.
     * @return the scaled scatter matrix.
     * @throws IllegalArgumentException if {@code x.length <= k}, {@code tol}
     *         is NaN, or any diagonal entry is below {@code tol * tol}.
     */
    public static DenseMatrix St(double[][] x, double[] mean, int k, double tol) {
        int n = x.length;
        // Without this guard, n == k divides by zero and n < k yields a negative diagonal.
        if (n <= k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, k));
        }
        // A NaN tolerance would silently disable the singularity check below.
        if (Double.isNaN(tol)) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }

        int p = x[0].length;
        DenseMatrix scatter = DenseMatrix.zeros(ScalarType.Float64, p, p);
        scatter.withUplo(UPLO.LOWER);

        // Deviation of the current sample from the mean, computed once per sample.
        double[] d = new double[p];
        for (double[] xi : x) {
            for (int j = 0; j < p; j++) {
                d[j] = xi[j] - mean[j];
            }
            for (int j = 0; j < p; j++) {
                double dj = d[j];
                for (int l = 0; l <= j; l++) {
                    scatter.add(j, l, dj * d[l]);
                }
            }
        }

        double scale = n - k;
        double tol2 = tol * tol;
        for (int j = 0; j < p; j++) {
            for (int l = 0; l <= j; l++) {
                scatter.div(j, l, scale);
                scatter.set(l, j, scatter.get(j, l));
            }
            // Diagonal entry j is column j's summed squared deviation divided by n - k.
            if (scatter.get(j, j) < tol2) {
                throw new IllegalArgumentException(String.format("Covariance matrix (column %d) is close to singular.", j));
            }
        }
        return scatter;
    }

    /**
     * Computes the sample covariance matrix of each class around its class
     * mean, dividing each class's scatter by {@code ni[i] - 1}.
     *
     * <p>Only the lower triangle is accumulated and then mirrored; each
     * matrix is flagged with {@link UPLO#LOWER}.
     *
     * @param x the samples, one row per sample.
     * @param y the encoded class index of each sample, in {@code [0, mu.length)}.
     * @param mu the mean vector of each class.
     * @param ni the number of samples in each class.
     * @return one {@code p x p} covariance matrix per class.
     * @throws IllegalArgumentException if some class has no more samples than
     *         there are columns.
     */
    public static DenseMatrix[] cov(double[][] x, int[] y, double[][] mu, int[] ni) {
        int n = x.length;
        int p = x[0].length;
        int k = mu.length;

        DenseMatrix[] cov = new DenseMatrix[k];
        for (int i = 0; i < k; i++) {
            if (ni[i] <= p) {
                throw new IllegalArgumentException(String.format("The sample size of class %d is too small.", i));
            }
            cov[i] = DenseMatrix.zeros(ScalarType.Float64, p, p);
            cov[i].withUplo(UPLO.LOWER);
        }

        // Deviation of the current sample from its class mean, computed once per sample.
        double[] d = new double[p];
        for (int i = 0; i < n; i++) {
            DenseMatrix v = cov[y[i]];
            double[] mui = mu[y[i]];
            double[] xi = x[i];
            for (int j = 0; j < p; j++) {
                d[j] = xi[j] - mui[j];
            }
            for (int j = 0; j < p; j++) {
                double dj = d[j];
                for (int l = 0; l <= j; l++) {
                    v.add(j, l, dj * d[l]);
                }
            }
        }

        for (int i = 0; i < k; i++) {
            DenseMatrix v = cov[i];
            double m = ni[i] - 1;
            for (int j = 0; j < p; j++) {
                for (int l = 0; l <= j; l++) {
                    v.div(j, l, m);
                    v.set(l, j, v.get(j, l));
                }
            }
        }
        return cov;
    }
}

