/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Locale;
import java.util.Properties;
import smile.base.svm.KernelMachine;
import smile.base.svm.LASVM;
import smile.base.svm.LinearKernelMachine;
import smile.classification.AbstractClassifier;
import smile.classification.Classifier;
import smile.classification.OneVersusOne;
import smile.classification.OneVersusRest;
import smile.math.MathEx;
import smile.math.kernel.BinarySparseLinearKernel;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import smile.math.kernel.SparseLinearKernel;
import smile.util.IntSet;
import smile.util.SparseArray;

/**
 * Binary support vector machine classifier built on {@link KernelMachine} and trained
 * with {@link LASVM}.
 *
 * <p>{@link #predict} returns {@code +1} when the machine's score is positive and
 * {@code -1} otherwise, so the class labels of an {@code SVM} are always {@code {-1, +1}}.
 * The static {@code fit} helpers train either a general kernel machine or, for linear
 * kernels, a {@link LinearKernelMachine}-backed classifier. The {@link Properties}-based
 * overload additionally handles multi-class data via one-versus-rest / one-versus-one.
 *
 * @param <T> the type of input objects.
 */
public class SVM<T> extends KernelMachine<T> implements Classifier<T> {
    /**
     * Creates an SVM from an already-trained kernel machine description. All arguments
     * are forwarded unchanged to the {@link KernelMachine} constructor.
     *
     * @param kernel  the Mercer kernel.
     * @param vectors the vectors the kernel is evaluated against.
     * @param weight  the weight associated with each vector.
     * @param b       the intercept.
     */
    public SVM(MercerKernel<T> kernel, T[] vectors, double[] weight, double b) {
        super(kernel, vectors, weight, b);
    }

    /** Returns 2: an SVM is always a binary classifier. */
    @Override
    public int numClasses() {
        return 2;
    }

    /** Returns a fresh array {@code {-1, +1}}, the only labels {@link #predict} produces. */
    @Override
    public int[] classes() {
        return new int[]{-1, +1};
    }

    /**
     * Predicts the label of {@code x}.
     *
     * @return {@code +1} if the kernel machine's score is strictly positive, else {@code -1}.
     */
    @Override
    public int predict(T x) {
        return score(x) > 0.0 ? +1 : -1;
    }

    /** Returns a new label set {@code {-1, +1}} shared by all binary classifiers built here. */
    private static IntSet binaryLabels() {
        return new IntSet(new int[]{-1, +1});
    }

    /**
     * Trains a linear SVM on dense vectors with a single LASVM epoch.
     *
     * @see #fit(double[][], int[], double, double, int)
     */
    public static Classifier<double[]> fit(double[][] x, int[] y, double C, double tol) {
        return SVM.fit(x, y, C, tol, 1);
    }

    /**
     * Trains a linear SVM on dense vectors.
     *
     * @param x      the training samples.
     * @param y      the training labels; presumably expected to be -1/+1 (the
     *               {@link Properties} overload remaps to that range before calling here).
     * @param C      the soft-margin penalty parameter passed to LASVM.
     * @param tol    the tolerance passed to LASVM.
     * @param epochs the number of LASVM training epochs.
     * @return a classifier predicting {@code +1} when the linear decision value is positive,
     *         otherwise {@code -1}.
     */
    public static Classifier<double[]> fit(double[][] x, int[] y, double C, double tol, int epochs) {
        LASVM<double[]> lasvm = new LASVM<>(new LinearKernel(), C, tol);
        KernelMachine<double[]> svm = lasvm.fit(x, y, epochs);
        return new AbstractClassifier<double[]>(binaryLabels()) {
            // Built once from the trained kernel machine.
            final LinearKernelMachine model = LinearKernelMachine.of(svm);

            @Override
            public int predict(double[] sample) {
                return model.f(sample) > 0.0 ? +1 : -1;
            }
        };
    }

    /**
     * Trains a linear SVM on binary sparse vectors with a single LASVM epoch.
     *
     * @see #fit(int[][], int[], int, double, double, int)
     */
    public static Classifier<int[]> fit(int[][] x, int[] y, int p, double C, double tol) {
        return SVM.fit(x, y, p, C, tol, 1);
    }

    /**
     * Trains a linear SVM on binary sparse vectors using {@link BinarySparseLinearKernel}.
     * NOTE(review): each {@code int[]} sample presumably lists the indices of its nonzero
     * features — confirm against {@code BinarySparseLinearKernel}.
     *
     * @param x      the training samples.
     * @param y      the training labels (presumably -1/+1).
     * @param p      passed to {@link LinearKernelMachine#binary}; presumably the dimension
     *               of the input space.
     * @param C      the soft-margin penalty parameter passed to LASVM.
     * @param tol    the tolerance passed to LASVM.
     * @param epochs the number of LASVM training epochs.
     * @return a classifier predicting {@code +1} or {@code -1}.
     */
    public static Classifier<int[]> fit(int[][] x, int[] y, final int p, double C, double tol, int epochs) {
        LASVM<int[]> lasvm = new LASVM<>(new BinarySparseLinearKernel(), C, tol);
        KernelMachine<int[]> svm = lasvm.fit(x, y, epochs);
        return new AbstractClassifier<int[]>(binaryLabels()) {
            final LinearKernelMachine model = LinearKernelMachine.binary(p, svm);

            @Override
            public int predict(int[] sample) {
                return model.f(sample) > 0.0 ? +1 : -1;
            }
        };
    }

    /**
     * Trains a linear SVM on sparse vectors with a single LASVM epoch.
     *
     * @see #fit(SparseArray[], int[], int, double, double, int)
     */
    public static Classifier<SparseArray> fit(SparseArray[] x, int[] y, int p, double C, double tol) {
        return SVM.fit(x, y, p, C, tol, 1);
    }

    /**
     * Trains a linear SVM on {@link SparseArray} vectors using {@link SparseLinearKernel}.
     *
     * @param x      the training samples.
     * @param y      the training labels (presumably -1/+1).
     * @param p      passed to {@link LinearKernelMachine#sparse}; presumably the dimension
     *               of the input space.
     * @param C      the soft-margin penalty parameter passed to LASVM.
     * @param tol    the tolerance passed to LASVM.
     * @param epochs the number of LASVM training epochs.
     * @return a classifier predicting {@code +1} or {@code -1}.
     */
    public static Classifier<SparseArray> fit(SparseArray[] x, int[] y, final int p, double C, double tol, int epochs) {
        LASVM<SparseArray> lasvm = new LASVM<>((MercerKernel<SparseArray>) new SparseLinearKernel(), C, tol);
        KernelMachine<SparseArray> svm = lasvm.fit(x, y, epochs);
        return new AbstractClassifier<SparseArray>(binaryLabels()) {
            final LinearKernelMachine model = LinearKernelMachine.sparse(p, svm);

            @Override
            public int predict(SparseArray sample) {
                return model.f(sample) > 0.0 ? +1 : -1;
            }
        };
    }

    /**
     * Trains a kernel SVM with a single LASVM epoch.
     *
     * @see #fit(Object[], int[], MercerKernel, double, double, int)
     */
    public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, double C, double tol) {
        return SVM.fit(x, y, kernel, C, tol, 1);
    }

    /**
     * Trains a kernel SVM with an arbitrary Mercer kernel.
     *
     * @param x      the training samples.
     * @param y      the training labels (presumably -1/+1).
     * @param kernel the Mercer kernel.
     * @param C      the soft-margin penalty parameter passed to LASVM.
     * @param tol    the tolerance passed to LASVM.
     * @param epochs the number of LASVM training epochs.
     * @return an {@code SVM} copying the kernel, vectors, weights and intercept of the
     *         trained kernel machine.
     */
    public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, double C, double tol, int epochs) {
        LASVM<T> lasvm = new LASVM<>(kernel, C, tol);
        KernelMachine<T> model = lasvm.fit(x, y, epochs);
        return new SVM<>(model.kernel(), model.vectors(), model.weights(), model.intercept());
    }

    /**
     * Trains an SVM on dense vectors, configured through properties.
     *
     * <p>Recognized properties (defaults in parentheses):
     * <ul>
     *   <li>{@code smile.svm.kernel} ({@code linear}) — parsed by {@link MercerKernel#of}.</li>
     *   <li>{@code smile.svm.C} ({@code 1.0}) — soft-margin penalty.</li>
     *   <li>{@code smile.svm.tolerance} ({@code 1E-3}) — LASVM tolerance.</li>
     *   <li>{@code smile.svm.epochs} ({@code 1}) — LASVM epochs.</li>
     *   <li>{@code smile.svm.type} — case-insensitive {@code binary}, {@code ovr} or
     *       {@code ovo}; defaults to {@code binary} for exactly two distinct labels and
     *       {@code ovr} otherwise.</li>
     * </ul>
     *
     * <p>In {@code binary} mode the smaller label is mapped to {@code -1} and the larger to
     * {@code +1} before training. NOTE(review): the returned classifier then predicts in
     * {@code {-1, +1}}, not in the caller's original label space.
     *
     * @param x      the training samples.
     * @param y      the training labels (not modified).
     * @param params the hyper-parameters.
     * @return the trained classifier.
     * @throws IllegalArgumentException if {@code smile.svm.type} is unknown, or if
     *         {@code binary} is requested for data without exactly two distinct labels.
     */
    public static Classifier<double[]> fit(double[][] x, int[] y, Properties params) {
        MercerKernel<double[]> kernel = MercerKernel.of(params.getProperty("smile.svm.kernel", "linear"));
        double C = Double.parseDouble(params.getProperty("smile.svm.C", "1.0"));
        double tol = Double.parseDouble(params.getProperty("smile.svm.tolerance", "1E-3"));
        int epochs = Integer.parseInt(params.getProperty("smile.svm.epochs", "1"));

        int[] classes = MathEx.unique(y);
        // Locale.ROOT: the default locale could mangle e.g. "BINARY" (Turkish dotless i).
        String trainer = params.getProperty("smile.svm.type", classes.length == 2 ? "binary" : "ovr")
                .toLowerCase(Locale.ROOT);
        boolean linear = kernel instanceof LinearKernel;

        switch (trainer) {
            case "ovr":
                if (linear) {
                    return OneVersusRest.fit(x, y, (xi, yi) -> SVM.fit(xi, yi, C, tol, epochs));
                }
                return OneVersusRest.fit(x, y, (xi, yi) -> SVM.fit(xi, yi, kernel, C, tol, epochs));

            case "ovo":
                if (linear) {
                    return OneVersusOne.fit(x, y, (xi, yi) -> SVM.fit(xi, yi, C, tol, epochs));
                }
                return OneVersusOne.fit(x, y, (xi, yi) -> SVM.fit(xi, yi, kernel, C, tol, epochs));

            case "binary": {
                // Guard: with 1 label classes[1] would be out of bounds; with >2 labels the
                // remapping below would silently merge every non-minimal class into +1.
                if (classes.length != 2) {
                    throw new IllegalArgumentException(
                            "Binary SVM requires exactly 2 classes, but found " + classes.length);
                }
                Arrays.sort(classes);
                int[] labels = y;
                if (classes[0] != -1 || classes[1] != +1) {
                    // Work on a copy so the caller's label array is left untouched.
                    labels = new int[y.length];
                    for (int i = 0; i < y.length; i++) {
                        labels[i] = y[i] == classes[0] ? -1 : +1;
                    }
                }
                if (linear) {
                    return SVM.fit(x, labels, C, tol, epochs);
                }
                return SVM.fit(x, labels, kernel, C, tol, epochs);
            }

            default:
                break;
        }
        throw new IllegalArgumentException("Unknown SVM type: " + trainer);
    }
}

