/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.AbstractClassifier;
import smile.classification.DiscriminantAnalysis;
import smile.math.MathEx;
import smile.sort.QuickSort;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.SVD;
import smile.tensor.ScalarType;
import smile.tensor.Vector;
import smile.util.IntSet;

public class FLD
extends AbstractClassifier<double[]> {
    private static final long serialVersionUID = 2L;
    private final int p;
    private final int k;
    private final DenseMatrix scaling;
    private final Vector mean;
    private final Vector[] mu;

    public FLD(double[] mean, double[][] mu, DenseMatrix scaling) {
        this(mean, mu, scaling, IntSet.of((int)mu.length));
    }

    public FLD(double[] mean, double[][] mu, DenseMatrix scaling, IntSet labels) {
        super(labels);
        this.k = mu.length;
        this.p = mean.length;
        this.scaling = scaling;
        int L = scaling.ncol();
        this.mean = scaling.tv(mean);
        this.mu = new Vector[this.k];
        for (int i = 0; i < this.k; ++i) {
            this.mu[i] = scaling.tv(mu[i]);
        }
    }

    public static FLD fit(double[][] x, int[] y) {
        return FLD.fit(x, y, -1, 1.0E-4);
    }

    public static FLD fit(double[][] x, int[] y, Properties params) {
        int L = Integer.parseInt(params.getProperty("smile.fisher.dimension", "-1"));
        double tol = Double.parseDouble(params.getProperty("smile.fisher.tolerance", "1E-4"));
        return FLD.fit(x, y, L, tol);
    }

    public static FLD fit(double[][] x, int[] y, int L, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, null, tol);
        int n = x.length;
        int k = da.k;
        int p = da.mean.length;
        if (L >= k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", L, k));
        }
        if (L <= 0) {
            L = k - 1;
        }
        double[] mean = da.mean;
        double[][] mu = da.mu;
        DenseMatrix scaling = n - k < p ? FLD.small(L, x, mean, mu, da.priori, tol) : FLD.fld(L, x, mean, mu, tol);
        return new FLD(mean, mu, scaling, da.labels);
    }

    private static DenseMatrix fld(int L, double[][] x, double[] mean, double[][] mu, double tol) {
        int k = mu.length;
        int p = mean.length;
        DenseMatrix St = DiscriminantAnalysis.St(x, mean, k, tol);
        for (double[] mui : mu) {
            for (int j = 0; j < p; ++j) {
                int n = j;
                mui[n] = mui[n] - mean[j];
            }
        }
        DenseMatrix Sb = DenseMatrix.zeros((ScalarType)ScalarType.Float64, (int)p, (int)p);
        for (double[] mui : mu) {
            for (int j = 0; j < p; ++j) {
                for (int i = 0; i <= j; ++i) {
                    Sb.add(i, j, mui[i] * mui[j]);
                }
            }
        }
        for (int j = 0; j < p; ++j) {
            for (int i = 0; i <= j; ++i) {
                Sb.div(i, j, (double)k);
                Sb.set(j, i, Sb.get(i, j));
            }
        }
        DenseMatrix Sw = St.copy();
        Sw.sub(Sb);
        DenseMatrix SwInvSb = Sw.inverse().mm(Sb);
        EVD eig = SwInvSb.eigen();
        double[] w = new double[p];
        for (int i = 0; i < p; ++i) {
            double wri = eig.wr().get(i);
            double wii = eig.wi().get(i);
            w[i] = -(wri * wri + wii * wii);
        }
        int[] index = QuickSort.sort((double[])w);
        DenseMatrix scaling = DenseMatrix.zeros((ScalarType)ScalarType.Float64, (int)p, (int)L);
        for (int j = 0; j < L; ++j) {
            int l = index[j];
            for (int i = 0; i < p; ++i) {
                scaling.set(i, j, eig.Vr().get(i, l));
            }
        }
        return scaling;
    }

    private static DenseMatrix small(int L, double[][] x, double[] mean, double[][] mu, double[] priori, double tol) {
        int k = mu.length;
        int p = mean.length;
        int n = x.length;
        double sqrtn = Math.sqrt(n);
        DenseMatrix X = DenseMatrix.zeros((ScalarType)ScalarType.Float64, (int)p, (int)n);
        for (int i = 0; i < n; ++i) {
            double[] xi = x[i];
            for (int j = 0; j < p; ++j) {
                X.set(j, i, (xi[j] - mean[j]) / sqrtn);
            }
        }
        for (double[] mui : mu) {
            for (int j = 0; j < p; ++j) {
                int n2 = j;
                mui[n2] = mui[n2] - mean[j];
            }
        }
        DenseMatrix M2 = DenseMatrix.zeros((ScalarType)ScalarType.Float64, (int)p, (int)k);
        for (int i = 0; i < k; ++i) {
            double pi = Math.sqrt(priori[i]);
            double[] mui = mu[i];
            for (int j = 0; j < p; ++j) {
                M2.set(j, i, pi * mui[j]);
            }
        }
        SVD svd = X.svd();
        DenseMatrix U = svd.U();
        Vector s = svd.s();
        tol *= tol;
        DenseMatrix UTM = U.tm(M2);
        for (int i = 0; i < n; ++i) {
            double si = s.get(i);
            si = si > tol ? 1.0 / Math.sqrt(si) : 0.0;
            for (int j = 0; j < k; ++j) {
                UTM.mul(i, j, si);
            }
        }
        DenseMatrix StInvM = U.mm(UTM);
        DenseMatrix U2 = U.tm(StInvM.svd().U().submatrix(0, 0, p, L));
        for (int i = 0; i < n; ++i) {
            double si = s.get(i);
            si = si > tol ? 1.0 / Math.sqrt(si) : 0.0;
            for (int j = 0; j < L; ++j) {
                U2.mul(i, j, si);
            }
        }
        return U.mm(U2);
    }

    @Override
    public int predict(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        Vector wx = this.project(x);
        int y = 0;
        double nearest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.k; ++i) {
            double d = MathEx.distance((Vector)wx, (Vector)this.mu[i]);
            if (!(d < nearest)) continue;
            nearest = d;
            y = i;
        }
        return this.classes.valueOf(y);
    }

    public Vector project(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        Vector y = this.scaling.tv(Vector.column((double[])x));
        y.sub((DenseMatrix)this.mean);
        return y;
    }

    public Vector[] project(double[][] x) {
        Vector[] y = new Vector[x.length];
        for (int i = 0; i < x.length; ++i) {
            if (x[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x[i].length, this.p));
            }
            y[i] = this.scaling.vector(this.scaling.ncol());
            this.scaling.tv(Vector.column((double[])x[i]), y[i]);
            y[i].sub((DenseMatrix)this.mean);
        }
        return y;
    }

    public DenseMatrix getProjection() {
        return this.scaling;
    }
}

