/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.Classifier;
import smile.classification.DiscriminantAnalysis;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;
import smile.math.matrix.SVD;
import smile.projection.Projection;
import smile.util.IntSet;

/**
 * Fisher's linear discriminant.
 *
 * <p>The model projects a p-dimensional input onto at most {@code k - 1}
 * discriminant directions (k = number of classes) and classifies a sample by
 * the nearest projected class mean.
 */
public class FLD implements Classifier<double[]>, Projection<double[]> {
    private static final long serialVersionUID = 2L;
    /** Dimension of the input space. */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Projection matrix (p x L); inputs are mapped by {@code scaling' x}. */
    private final DenseMatrix scaling;
    /** Projection of the grand mean, {@code scaling' mean}. */
    private final double[] mean;
    /** Projections of the class means, one row per class. */
    private final double[][] mu;
    /** Maps internal class indices 0..k-1 to the labels returned by predict. */
    private final IntSet labels;

    /**
     * Constructor using the default label set {@code IntSet.of(mu.length)}.
     *
     * @param mean the grand mean of the training data (length p).
     * @param mu the class means, one row per class (each of length p).
     * @param scaling the p x L projection matrix.
     */
    public FLD(double[] mean, double[][] mu, DenseMatrix scaling) {
        this(mean, mu, scaling, IntSet.of(mu.length));
    }

    /**
     * Constructor.
     *
     * <p>NOTE: {@link #predict} compares {@code scaling' (x - mean)} with
     * {@code scaling' mu[i]}, so {@code mu} is expected to be already centered
     * by {@code mean}; {@code fit} passes means centered that way.
     *
     * @param mean the grand mean of the training data (length p).
     * @param mu the class means, one row per class (each of length p).
     * @param scaling the p x L projection matrix.
     * @param labels the mapping from class index to class label.
     */
    public FLD(double[] mean, double[][] mu, DenseMatrix scaling, IntSet labels) {
        this.k = mu.length;
        this.p = mean.length;
        this.scaling = scaling;
        this.labels = labels;

        int L = scaling.ncols();
        this.mean = new double[L];
        scaling.atx(mean, this.mean);

        this.mu = new double[k][L];
        for (int i = 0; i < k; i++) {
            scaling.atx(mu[i], this.mu[i]);
        }
    }

    /**
     * Fits Fisher's linear discriminant with default hyperparameters.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the model.
     */
    public static FLD fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits Fisher's linear discriminant.
     *
     * <p>Recognized properties: {@code smile.fld.dimension} (default -1, which
     * selects k - 1) and {@code smile.fld.tolerance} (default 1E-4).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param prop the hyperparameters.
     * @return the model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static FLD fit(Formula formula, DataFrame data, Properties prop) {
        int L = Integer.parseInt(prop.getProperty("smile.fld.dimension", "-1"));
        double tol = Double.parseDouble(prop.getProperty("smile.fld.tolerance", "1E-4"));
        double[][] x = formula.x(data).toArray();
        int[] y = formula.y(data).toIntArray();
        return fit(x, y, L, tol);
    }

    /**
     * Fits Fisher's linear discriminant with L = k - 1 and tol = 1E-4.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the model.
     */
    public static FLD fit(double[][] x, int[] y) {
        return fit(x, y, -1, 1.0E-4);
    }

    /**
     * Fits Fisher's linear discriminant.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param L the dimension of the projected space; values {@code <= 0}
     *          select k - 1.
     * @param tol the tolerance used to detect a (near-)singular scatter matrix.
     * @return the model.
     * @throws IllegalArgumentException if x and y differ in length, if
     *         {@code L >= k}, or if the total scatter is close to singular.
     */
    public static FLD fit(double[][] x, int[] y, int L, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, null, tol);
        int n = x.length;
        int k = da.k;
        int p = da.mean.length;

        if (L >= k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", L, k));
        }
        if (L <= 0) {
            L = k - 1;
        }

        double[] mean = da.mean;
        // Both fld() and small() center these class means by `mean` in place.
        double[][] mu = da.mu;
        DenseMatrix scaling = n - k < p
                ? small(L, x, mean, mu, da.priori, tol)
                : fld(L, x, mean, mu, tol);
        return new FLD(mean, mu, scaling, da.labels);
    }

    /**
     * Computes the discriminant directions from the eigen decomposition of the
     * total scatter matrix St.
     *
     * <p>Fisher's directions are the leading eigenvectors of {@code St^-1 Sb}.
     * That product is not symmetric, so the symmetric, similar matrix
     * {@code St^-1/2 Sb St^-1/2} is diagonalized instead and its eigenvectors
     * are mapped back through {@code St^-1/2}. The resulting columns satisfy
     * {@code v' St v = 1}, matching the construction used by {@link #small}.
     *
     * <p>Side effect: each row of {@code mu} is centered by {@code mean} in place.
     *
     * @throws IllegalArgumentException if an eigenvalue of St is below tol^2.
     */
    private static DenseMatrix fld(int L, double[][] x, double[] mean, double[][] mu, double tol) {
        int k = mu.length;
        int p = mean.length;

        // Total scatter St = U diag(lambda) U'.
        DenseMatrix St = DiscriminantAnalysis.St(x, mean, k, tol);
        EVD eigen = St.eigen();
        double[] lambda = eigen.getEigenValues();
        double tol2 = tol * tol;

        // d = lambda^(-1/2): St^(-1/2) expressed in the eigenbasis U.
        double[] d = new double[lambda.length];
        for (int i = 0; i < lambda.length; i++) {
            if (lambda[i] < tol2) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
            d[i] = 1.0 / Math.sqrt(lambda[i]);
        }

        centerMeans(mu, mean);
        DenseMatrix Sb = betweenScatter(mu, p);

        // W = D U' Sb U D. Symmetrize explicitly so rounding noise does not
        // make the matrix handed to the symmetric eigen solver asymmetric.
        DenseMatrix U = eigen.getEigenVectors();
        DenseMatrix W = U.atbmm(Sb).abmm(U);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i <= j; i++) {
                double w = 0.5 * (W.get(i, j) + W.get(j, i)) * d[i] * d[j];
                W.set(i, j, w);
                W.set(j, i, w);
            }
        }
        W.setSymmetric(true);

        // Leading L eigenvectors of W mapped back by U D. Assumes eigen()
        // orders eigenvectors by decreasing eigenvalue, as the previous
        // implementation did as well.
        DenseMatrix V = W.eigen().getEigenVectors().submat(0, 0, p, L);
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < L; j++) {
                V.mul(i, j, d[i]);
            }
        }
        return U.abmm(V);
    }

    /**
     * Computes the discriminant directions when {@code n - k < p}, working
     * from an SVD of the centered data instead of inverting St (which may be
     * rank deficient in this regime).
     *
     * <p>Side effect: each row of {@code mu} is centered by {@code mean} in place.
     */
    private static DenseMatrix small(int L, double[][] x, double[] mean, double[][] mu, double[] priori, double tol) {
        int k = mu.length;
        int p = mean.length;
        int n = x.length;
        double sqrtn = Math.sqrt(n);

        // Centered data, one sample per column, scaled by 1/sqrt(n).
        DenseMatrix X = Matrix.zeros(p, n);
        for (int i = 0; i < n; i++) {
            double[] xi = x[i];
            for (int j = 0; j < p; j++) {
                X.set(j, i, (xi[j] - mean[j]) / sqrtn);
            }
        }

        centerMeans(mu, mean);

        // Centered class means, one per column, weighted by sqrt(prior).
        DenseMatrix M = Matrix.zeros(p, k);
        for (int i = 0; i < k; i++) {
            double pi = Math.sqrt(priori[i]);
            double[] mui = mu[i];
            for (int j = 0; j < p; j++) {
                M.set(j, i, pi * mui[j]);
            }
        }

        SVD svd = X.svd(true);
        DenseMatrix U = svd.getU();
        double[] s = svd.getSingularValues();
        double tol2 = tol * tol;

        DenseMatrix UTM = U.atbmm(M);
        scaleRows(UTM, s, tol2);
        DenseMatrix StInvM = U.abmm(UTM);

        // Take all p rows of the left singular vectors (same submat bounds as fld()).
        DenseMatrix U2 = U.atbmm(StInvM.svd(true).getU().submat(0, 0, p, L));
        scaleRows(U2, s, tol2);
        return U.abmm(U2);
    }

    /**
     * Multiplies row i of {@code a} by {@code 1 / sqrt(s[i])}. A row is zeroed
     * when {@code s[i] <= tol} or when it has no corresponding entry in
     * {@code s}; such directions are treated as lying in the null space.
     */
    private static void scaleRows(DenseMatrix a, double[] s, double tol) {
        int rows = a.nrows();
        int cols = a.ncols();
        for (int i = 0; i < rows; i++) {
            double si = (i < s.length && s[i] > tol) ? 1.0 / Math.sqrt(s[i]) : 0.0;
            for (int j = 0; j < cols; j++) {
                a.mul(i, j, si);
            }
        }
    }

    /** Subtracts {@code mean} from every row of {@code mu}, in place. */
    private static void centerMeans(double[][] mu, double[] mean) {
        for (double[] mui : mu) {
            for (int j = 0; j < mean.length; j++) {
                mui[j] -= mean[j];
            }
        }
    }

    /**
     * Returns the p x p between-class scatter {@code (1/k) sum_c mu_c mu_c'}
     * of the (already centered) class means, with equal class weights.
     */
    private static DenseMatrix betweenScatter(double[][] mu, int p) {
        int k = mu.length;
        DenseMatrix Sb = Matrix.zeros(p, p);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i <= j; i++) {
                double sum = 0.0;
                for (double[] muc : mu) {
                    sum += muc[i] * muc[j];
                }
                double v = sum / k;
                Sb.set(i, j, v);
                Sb.set(j, i, v);
            }
        }
        return Sb;
    }

    /**
     * Predicts the label whose projected class mean is nearest, according to
     * {@code MathEx.distance} (presumably Euclidean), to the projected input.
     *
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public int predict(double[] x) {
        // project() performs the input-length validation.
        double[] wx = project(x);
        int best = 0;
        double nearest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < k; i++) {
            double d = MathEx.distance(wx, mu[i]);
            if (d < nearest) {
                nearest = d;
                best = i;
            }
        }
        return labels.valueOf(best);
    }

    /**
     * Projects a sample: {@code scaling' x - scaling' mean}.
     *
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public double[] project(double[] x) {
        checkDimension(x);
        double[] y = new double[scaling.ncols()];
        scaling.atx(x, y);
        MathEx.sub(y, mean);
        return y;
    }

    /**
     * Projects each sample; see {@link #project(double[])}.
     *
     * @throws IllegalArgumentException if any row has length other than p.
     */
    public double[][] project(double[][] x) {
        double[][] y = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            y[i] = project(x[i]);
        }
        return y;
    }

    /**
     * Returns the p x L projection matrix. The internal instance is returned,
     * not a copy.
     */
    public DenseMatrix getProjection() {
        return scaling;
    }

    /** Throws if {@code x} does not have the input dimension p. */
    private void checkDimension(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }
    }
}

