package smile.classification;

import java.util.Arrays;
import java.util.Locale;
import java.util.Properties;
import smile.base.svm.KernelMachine;
import smile.base.svm.LASVM;
import smile.base.svm.LinearKernelMachine;
import smile.math.MathEx;
import smile.math.kernel.BinarySparseLinearKernel;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import smile.math.kernel.SparseLinearKernel;
import smile.util.IntSet;
import smile.util.SparseArray;

/* loaded from: input_file:smile/classification/SVM.class */
/**
 * Binary support vector machine classifier with class labels {@code -1} and {@code +1}.
 *
 * <p>An instance wraps a trained kernel expansion: the kernel, support vectors, their
 * weights and the intercept. {@link #predict(Object)} returns {@code +1} when the decision
 * value from {@code score} is strictly positive and {@code -1} otherwise.
 *
 * <p>The static {@code fit} methods train with {@link LASVM}. The linear-kernel overloads for
 * dense ({@code double[]}), binary sparse ({@code int[]}) and {@link SparseArray} inputs
 * collapse the trained expansion into a {@link LinearKernelMachine} before predicting.
 *
 * @param <T> the type of input objects.
 */
public class SVM<T> extends KernelMachine<T> implements Classifier<T> {
    /**
     * Constructor.
     *
     * @param kernel the Mercer kernel.
     * @param vectors the support vectors.
     * @param weight the weights of the support vectors.
     * @param b the intercept.
     */
    public SVM(MercerKernel<T> kernel, T[] vectors, double[] weight, double b) {
        super(kernel, vectors, weight, b);
    }

    @Override
    public int numClasses() {
        return 2;
    }

    @Override
    public int[] classes() {
        return new int[]{-1, 1};
    }

    @Override
    public int predict(T x) {
        return score(x) > 0.0 ? 1 : -1;
    }

    /** Returns a fresh label set {@code {-1, +1}} for the anonymous linear classifiers. */
    private static IntSet binaryLabels() {
        return new IntSet(new int[]{-1, 1});
    }

    /**
     * Fits a linear SVM on dense data with a single training epoch.
     *
     * @param x the training samples.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @return the trained classifier.
     */
    public static Classifier<double[]> fit(double[][] x, int[] y, double C, double tol) {
        return fit(x, y, C, tol, 1);
    }

    /**
     * Fits a linear SVM on dense data.
     *
     * @param x the training samples.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @param epochs the number of training epochs passed to {@link LASVM}.
     * @return the trained classifier.
     */
    public static Classifier<double[]> fit(double[][] x, int[] y, double C, double tol, int epochs) {
        KernelMachine<double[]> machine = new LASVM<>(new LinearKernel(), C, tol).fit(x, y, epochs);
        // Build the compact linear model up front so the classifier does not retain
        // the full support-vector expansion.
        final LinearKernelMachine model = LinearKernelMachine.of(machine);
        return new AbstractClassifier<double[]>(binaryLabels()) {
            @Override
            public int predict(double[] xi) {
                return model.f(xi) > 0.0 ? 1 : -1;
            }
        };
    }

    /**
     * Fits a linear SVM on binary sparse data with a single training epoch.
     *
     * @param x the training samples; presumably each row lists the indices of the
     *          nonzero (one-valued) features.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param p the dimension of the input space, passed to {@link LinearKernelMachine}.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @return the trained classifier.
     */
    public static Classifier<int[]> fit(int[][] x, int[] y, int p, double C, double tol) {
        return fit(x, y, p, C, tol, 1);
    }

    /**
     * Fits a linear SVM on binary sparse data.
     *
     * @param x the training samples; presumably each row lists the indices of the
     *          nonzero (one-valued) features.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param p the dimension of the input space, passed to {@link LinearKernelMachine}.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @param epochs the number of training epochs passed to {@link LASVM}.
     * @return the trained classifier.
     */
    public static Classifier<int[]> fit(int[][] x, int[] y, int p, double C, double tol, int epochs) {
        KernelMachine<int[]> machine = new LASVM<>(new BinarySparseLinearKernel(), C, tol).fit(x, y, epochs);
        final LinearKernelMachine model = LinearKernelMachine.binary(p, machine);
        return new AbstractClassifier<int[]>(binaryLabels()) {
            @Override
            public int predict(int[] xi) {
                return model.f(xi) > 0.0 ? 1 : -1;
            }
        };
    }

    /**
     * Fits a linear SVM on sparse data with a single training epoch.
     *
     * @param x the training samples.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param p the dimension of the input space, passed to {@link LinearKernelMachine}.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @return the trained classifier.
     */
    public static Classifier<SparseArray> fit(SparseArray[] x, int[] y, int p, double C, double tol) {
        return fit(x, y, p, C, tol, 1);
    }

    /**
     * Fits a linear SVM on sparse data.
     *
     * @param x the training samples.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param p the dimension of the input space, passed to {@link LinearKernelMachine}.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @param epochs the number of training epochs passed to {@link LASVM}.
     * @return the trained classifier.
     */
    public static Classifier<SparseArray> fit(SparseArray[] x, int[] y, int p, double C, double tol, int epochs) {
        KernelMachine<SparseArray> machine = new LASVM<>(new SparseLinearKernel(), C, tol).fit(x, y, epochs);
        final LinearKernelMachine model = LinearKernelMachine.sparse(p, machine);
        return new AbstractClassifier<SparseArray>(binaryLabels()) {
            @Override
            public int predict(SparseArray xi) {
                return model.f(xi) > 0.0 ? 1 : -1;
            }
        };
    }

    /**
     * Fits a kernel SVM with a single training epoch.
     *
     * @param x the training samples.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param kernel the Mercer kernel.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @param <T> the type of input objects.
     * @return the trained model.
     */
    public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, double C, double tol) {
        return fit(x, y, kernel, C, tol, 1);
    }

    /**
     * Fits a kernel SVM.
     *
     * @param x the training samples.
     * @param y the class labels, expected in {@code {-1, +1}}.
     * @param kernel the Mercer kernel.
     * @param C the soft-margin penalty parameter passed to {@link LASVM}.
     * @param tol the convergence tolerance passed to {@link LASVM}.
     * @param epochs the number of training epochs passed to {@link LASVM}.
     * @param <T> the type of input objects.
     * @return the trained model.
     */
    public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, double C, double tol, int epochs) {
        KernelMachine<T> machine = new LASVM<>(kernel, C, tol).fit(x, y, epochs);
        return new SVM<>(machine.kernel(), machine.vectors(), machine.weights(), machine.intercept());
    }

    /**
     * Fits an SVM on dense data configured by properties.
     *
     * <p>Recognized keys (defaults in parentheses):
     * <ul>
     *   <li>{@code smile.svm.kernel} ({@code "linear"}): kernel spec for {@link MercerKernel#of}.</li>
     *   <li>{@code smile.svm.C} ({@code 1.0}): soft-margin penalty.</li>
     *   <li>{@code smile.svm.tolerance} ({@code 1E-3}): convergence tolerance.</li>
     *   <li>{@code smile.svm.epochs} ({@code 1}): number of training epochs.</li>
     *   <li>{@code smile.svm.type} ({@code "binary"} if there are exactly two distinct labels,
     *       otherwise {@code "ovr"}): one of {@code binary}, {@code ovr} (one-vs-rest) or
     *       {@code ovo} (one-vs-one), case-insensitive.</li>
     * </ul>
     * For {@code binary}, the smaller label is mapped to {@code -1} and the larger to {@code +1}
     * when the labels are not already {@code {-1, +1}}. The caller's array is never modified.
     *
     * @param x the training samples.
     * @param y the class labels.
     * @param params the hyper-parameters.
     * @return the trained classifier.
     * @throws IllegalArgumentException if the SVM type is unknown, or if the type is
     *         {@code binary} and {@code y} does not contain exactly two distinct labels.
     */
    public static Classifier<double[]> fit(double[][] x, int[] y, Properties params) {
        final MercerKernel<double[]> kernel = MercerKernel.of(params.getProperty("smile.svm.kernel", "linear"));
        final double C = Double.parseDouble(params.getProperty("smile.svm.C", "1.0"));
        final double tol = Double.parseDouble(params.getProperty("smile.svm.tolerance", "1E-3"));
        final int epochs = Integer.parseInt(params.getProperty("smile.svm.epochs", "1"));
        int[] labels = MathEx.unique(y);
        // Locale.ROOT keeps the keyword match independent of the default locale (e.g. Turkish 'I').
        String type = params.getProperty("smile.svm.type", labels.length == 2 ? "binary" : "ovr")
                .toLowerCase(Locale.ROOT);
        boolean linear = kernel instanceof LinearKernel;

        switch (type) {
            case "ovr":
                if (linear) {
                    return OneVersusRest.fit(x, y, (xi, yi) -> fit(xi, yi, C, tol, epochs));
                }
                return OneVersusRest.fit(x, y, (xi, yi) -> fit(xi, yi, kernel, C, tol, epochs));
            case "ovo":
                if (linear) {
                    return OneVersusOne.fit(x, y, (xi, yi) -> fit(xi, yi, C, tol, epochs));
                }
                return OneVersusOne.fit(x, y, (xi, yi) -> fit(xi, yi, kernel, C, tol, epochs));
            case "binary":
                if (labels.length != 2) {
                    throw new IllegalArgumentException(
                            "Binary SVM requires exactly 2 classes, but found " + labels.length);
                }
                int[] binaryY = toBinaryLabels(y, labels);
                if (linear) {
                    return fit(x, binaryY, C, tol, epochs);
                }
                return fit(x, binaryY, kernel, C, tol, epochs);
            default:
                throw new IllegalArgumentException("Unknown SVM type: " + type);
        }
    }

    /**
     * Maps a two-class label vector onto {@code {-1, +1}}. The smaller label becomes
     * {@code -1} and the other becomes {@code +1}. Returns {@code y} itself when it
     * already uses {@code {-1, +1}}; otherwise returns a new array.
     *
     * @param y the class labels.
     * @param labels the two distinct labels of {@code y}; sorted in place.
     * @return labels in {@code {-1, +1}}.
     */
    private static int[] toBinaryLabels(int[] y, int[] labels) {
        Arrays.sort(labels);
        if (labels[0] == -1 && labels[1] == 1) {
            return y;
        }
        int[] mapped = new int[y.length];
        for (int i = 0; i < y.length; i++) {
            mapped[i] = y[i] == labels[0] ? -1 : 1;
        }
        return mapped;
    }
}
