/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Locale;
import java.util.Properties;
import smile.base.svm.KernelMachine;
import smile.base.svm.LASVM;
import smile.classification.BinarySparseLinearSVM;
import smile.classification.Classifier;
import smile.classification.LinearSVM;
import smile.classification.OneVersusOne;
import smile.classification.OneVersusRest;
import smile.classification.SparseLinearSVM;
import smile.math.MathEx;
import smile.math.kernel.BinarySparseLinearKernel;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import smile.math.kernel.SparseLinearKernel;
import smile.util.SparseArray;

public class SVM<T>
extends KernelMachine<T>
implements Classifier<T> {
    public SVM(MercerKernel<T> kernel, T[] vectors, double[] weight, double b) {
        super(kernel, vectors, weight, b);
    }

    @Override
    public int numClasses() {
        return 2;
    }

    @Override
    public int[] classes() {
        return new int[]{-1, 1};
    }

    @Override
    public int predict(T x) {
        return this.score(x) > 0.0 ? 1 : -1;
    }

    public static LinearSVM fit(double[][] x, int[] y, Options options) {
        LASVM lasvm = new LASVM(new LinearKernel(), options.C, options.tol);
        KernelMachine<double[]> svm = lasvm.fit((T[])x, y, options.epochs);
        return new LinearSVM(svm);
    }

    public static BinarySparseLinearSVM fit(int[][] x, int[] y, int p, Options options) {
        LASVM lasvm = new LASVM(new BinarySparseLinearKernel(), options.C, options.tol);
        KernelMachine<int[]> svm = lasvm.fit((T[])x, y, options.epochs);
        return new BinarySparseLinearSVM(p, svm);
    }

    public static SparseLinearSVM fit(SparseArray[] x, int[] y, int p, Options options) {
        LASVM<SparseArray> lasvm = new LASVM<SparseArray>((MercerKernel<SparseArray>)new SparseLinearKernel(), options.C, options.tol);
        KernelMachine<SparseArray> svm = lasvm.fit(x, y, options.epochs);
        return new SparseLinearSVM(p, svm);
    }

    public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, Options options) {
        LASVM<T> lasvm = new LASVM<T>(kernel, options.C, options.tol);
        KernelMachine<T> model = lasvm.fit(x, y, options.epochs);
        return new SVM<T>(model.kernel(), model.vectors(), model.weights(), model.intercept());
    }

    public static Classifier<double[]> fit(double[][] x, int[] y, Properties params) {
        String trainer;
        MercerKernel kernel = MercerKernel.of((String)params.getProperty("smile.svm.kernel", "linear"));
        Options options = Options.of(params);
        int[] classes = MathEx.unique((int[])y);
        switch (trainer = params.getProperty("smile.svm.type", classes.length == 2 ? "binary" : "ovr").toLowerCase()) {
            case "ovr": {
                if (kernel instanceof LinearKernel) {
                    return OneVersusRest.fit(x, y, (T[] xi, int[] yi) -> SVM.fit(xi, yi, options));
                }
                return OneVersusRest.fit(x, y, (T[] xi, int[] yi) -> SVM.fit(xi, yi, kernel, options));
            }
            case "ovo": {
                if (kernel instanceof LinearKernel) {
                    return OneVersusOne.fit(x, y, (T[] xi, int[] yi) -> SVM.fit(xi, yi, options));
                }
                return OneVersusOne.fit(x, y, (T[] xi, int[] yi) -> SVM.fit(xi, yi, kernel, options));
            }
            case "binary": {
                Arrays.sort(classes);
                if (classes[0] != -1 || classes[1] != 1) {
                    y = (int[])y.clone();
                    for (int i = 0; i < y.length; ++i) {
                        y[i] = y[i] == classes[0] ? -1 : 1;
                    }
                }
                if (kernel instanceof LinearKernel) {
                    return SVM.fit(x, y, options);
                }
                return SVM.fit(x, y, kernel, options);
            }
        }
        throw new IllegalArgumentException("Unknown SVM type: " + trainer);
    }

    public record Options(double C, double tol, int epochs) {
        public Options {
            if (C < 0.0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + C);
            }
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid epochs: " + epochs);
            }
        }

        public Options(double C) {
            this(C, 0.001, 1);
        }

        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.svm.C", Double.toString(this.C));
            props.setProperty("smile.svm.tolerance", Double.toString(this.tol));
            props.setProperty("smile.svm.epochs", Integer.toString(this.epochs));
            return props;
        }

        public static Options of(Properties props) {
            double C = Double.parseDouble(props.getProperty("smile.svm.C", "1.0"));
            double tol = Double.parseDouble(props.getProperty("smile.svm.tolerance", "1E-3"));
            int epochs = Integer.parseInt(props.getProperty("smile.svm.epochs", "1"));
            return new Options(C, tol, epochs);
        }
    }
}

