/*
 * Decompiled with CFR 0.152.
 */
package smile.base.svm;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.svm.KernelMachine;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/**
 * Sequential minimal optimization (SMO) trainer for a one-class SVM.
 *
 * <p>The dual variables {@code alpha} are kept in the box {@code [0, C]} with
 * {@code C = 1 / round(nu * n)}. They are initialized so that they sum to one, and every SMO
 * step moves weight between two variables while preserving their sum. {@code O[i]} caches
 * {@code sum_j K[i][j] * alpha[j]}, i.e. the gradient of {@code 0.5 * alpha' K alpha} for a
 * symmetric kernel matrix {@code K}.
 *
 * <p>Instances are not thread-safe: {@link #fit} mutates the internal optimization state.
 *
 * @param <T> the type of the training samples.
 */
public class OCSVM<T> {
    private static final Logger logger = LoggerFactory.getLogger(OCSVM.class);
    /** Lower bound substituted for non-positive curvature to avoid division by zero. */
    private static final double TAU = 1.0E-12;
    private final MercerKernel<T> kernel;
    private final double nu;
    private final double tol;
    /** Upper bound of every dual variable, {@code 1 / vl}. */
    private double C;
    private T[] x;
    /** Offset of the decision function; the model intercept is {@code -(rho - tol)}. */
    private double rho;
    private double[] alpha;
    /** Cached {@code O[i] = sum_j K[i][j] * alpha[j]}. */
    private double[] O;
    /** Full kernel matrix; only alive while {@link #fit} is running. */
    private double[][] K;
    /** Index with the smallest O among variables with alpha < C, or -1. */
    private int svmin = -1;
    /** Index with the largest O among variables with alpha > 0, or -1. */
    private int svmax = -1;
    private double omin = Double.MAX_VALUE;
    private double omax = -Double.MAX_VALUE;

    /**
     * Constructor.
     *
     * @param kernel the Mercer kernel used to build the kernel matrix.
     * @param nu the parameter in (0, 1] that sets the number of initially active dual variables
     *           {@code vl = round(nu * n)} and their upper bound {@code C = 1 / vl}.
     * @param tol the tolerance of the convergence test on {@code omax - omin}; it is also
     *            subtracted from {@code rho} in the intercept of the returned model.
     * @throws IllegalArgumentException if {@code nu} is not in (0, 1] or {@code tol} is not
     *         positive (NaN is rejected for both).
     */
    public OCSVM(MercerKernel<T> kernel, double nu, double tol) {
        // Written as negated positive conditions so that NaN is rejected too.
        if (!(nu > 0.0 && nu <= 1.0)) {
            throw new IllegalArgumentException("Invalid nu: " + nu);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
        }
        this.kernel = kernel;
        this.nu = nu;
        this.tol = tol;
    }

    /**
     * Trains the one-class SVM on the given samples.
     *
     * <p>The full {@code n x n} kernel matrix is computed up front (rows in parallel), so memory
     * use is quadratic in the number of samples. The matrix is released once training finishes.
     *
     * @param x the training samples.
     * @return a kernel machine built from the samples with {@code alpha > 0}, using those alphas
     *         as weights and {@code -(rho - tol)} as the intercept.
     * @throws IllegalArgumentException if {@code x} is empty.
     */
    public KernelMachine<T> fit(T[] x) {
        int n = x.length;
        if (n == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        this.x = x;

        // Precompute the kernel matrix; each row is independent, so rows are filled in parallel.
        K = new double[n][n];
        IntStream.range(0, n).parallel().forEach(i -> {
            T xi = x[i];
            double[] Ki = K[i];
            for (int j = 0; j < n; j++) {
                Ki[j] = kernel.k(xi, x[j]);
            }
        });

        // Feasible starting point: vl variables (picked through MathEx.permutate) sit at the
        // upper bound C = 1 / vl, so the alphas sum to one. At least one variable must be
        // active; otherwise C would be infinite (nu * n < 0.5) and the model degenerate.
        int vl = Math.max(1, (int) Math.round(nu * n));
        C = 1.0 / vl;
        int[] index = MathEx.permutate(n);
        alpha = new double[n];
        for (int i = 0; i < vl; i++) {
            alpha[index[i]] = C;
        }

        // Initialize the gradient cache; rho starts at the largest O over active variables.
        O = new double[n];
        rho = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            double[] Ki = K[i];
            double oi = 0.0;
            for (int j = 0; j < n; j++) {
                oi += Ki[j] * alpha[j];
            }
            O[i] = oi;
            if (alpha[i] > 0.0 && rho < oi) {
                rho = oi;
            }
        }

        minmax();
        int phase = Math.min(n, 1000);
        int count = 1;
        while (smo(tol)) {
            if (count % phase == 0) {
                logger.info("{} SMO iterations", count);
            }
            count++;
        }

        // Support vectors have alpha > 0; bounded ones sit exactly at C (clipping assigns C).
        int nsv = 0;
        int bsv = 0;
        for (double a : alpha) {
            if (a > 0.0) {
                nsv++;
                if (a == C) {
                    bsv++;
                }
            }
        }

        // Arrays.copyOf is used only to obtain a T[] of the right runtime type and length.
        T[] vectors = Arrays.copyOf(x, nsv);
        double[] weight = new double[nsv];
        for (int i = 0, j = 0; i < n; i++) {
            if (alpha[i] > 0.0) {
                vectors[j] = x[i];
                weight[j++] = alpha[i];
            }
        }
        double b = -(rho - tol);

        // The O(n^2) kernel matrix is only needed during optimization; do not keep it alive.
        K = null;

        logger.info("{} samples, {} support vectors, {} bounded", n, nsv, bsv);
        return new KernelMachine<>(kernel, vectors, weight, b);
    }

    /**
     * Finds the working-set candidates and the current optimality gap.
     *
     * <p>{@code svmin}/{@code omin} become the index/value of the smallest {@code O} among
     * variables that can still increase ({@code alpha < C}). {@code svmax}/{@code omax} become
     * the index/value of the largest {@code O} among variables that can still decrease
     * ({@code alpha > 0}). An index stays -1 (with its sentinel value) when no candidate
     * qualifies. On ties, the first index wins.
     */
    private void minmax() {
        svmin = -1;
        svmax = -1;
        omin = Double.MAX_VALUE;
        omax = -Double.MAX_VALUE;
        int n = x.length;
        for (int i = 0; i < n; i++) {
            double oi = O[i];
            double ai = alpha[i];
            if (oi < omin && ai < C) {
                svmin = i;
                omin = oi;
            }
            if (oi > omax && ai > 0.0) {
                svmax = i;
                omax = oi;
            }
        }
    }

    /**
     * Performs one SMO step and re-runs the optimality test.
     *
     * <p>The pair is {@code (svmin, svmax)}. If {@link #minmax} found no candidate for one index,
     * the missing partner is chosen to maximize the second-order gain {@code Z^2 / curv} against
     * the other index. Weight then moves from the variable with the larger {@code O} to the one
     * with the smaller {@code O}. It is clipped so both stay in {@code [0, C]} with their sum
     * unchanged, after which {@code O}, the candidates and {@code rho} are refreshed.
     *
     * @param epsgr the tolerance on {@code omax - omin} used by the optimality test.
     * @return true if another step is needed ({@code omax - omin > epsgr} after this step);
     *         false if converged or no valid pair exists.
     */
    private boolean smo(double epsgr) {
        int v1 = svmin;
        int v2 = svmax;
        int n = x.length;

        // Neither candidate exists (e.g. all comparisons failed), so there is nothing to update.
        if (v1 < 0 && v2 < 0) {
            return false;
        }

        if (v2 < 0) {
            // Choose v2 among alpha > 0 with O > O[v1], maximizing Z^2 / curv.
            double o1 = O[v1];
            double[] row1 = K[v1];
            double k11 = row1[v1];
            double best = 0.0;
            for (int i = 0; i < n; i++) {
                double z = O[i] - o1;
                double curv = k11 + K[i][i] - 2.0 * row1[i];
                if (curv <= 0.0) {
                    curv = TAU;
                }
                double gain = -z * (z / curv);
                if (O[i] > o1 && alpha[i] > 0.0 && gain < best) {
                    best = gain;
                    v2 = i;
                }
            }
        }

        if (v1 < 0) {
            // Choose v1 among alpha < C with O < O[v2], maximizing Z^2 / curv.
            double o2 = O[v2];
            double[] row2 = K[v2];
            double k22 = row2[v2];
            double best = 0.0;
            for (int i = 0; i < n; i++) {
                double z = o2 - O[i];
                double curv = k22 + K[i][i] - 2.0 * row2[i];
                if (curv <= 0.0) {
                    curv = TAU;
                }
                double gain = -z * (z / curv);
                if (O[i] < o2 && alpha[i] < C && gain < best) {
                    best = gain;
                    v1 = i;
                }
            }
        }

        if (v1 < 0 || v2 < 0) {
            return false;
        }

        double[] k1 = K[v1];
        double[] k2 = K[v2];
        double oldAlpha1 = alpha[v1];
        double oldAlpha2 = alpha[v2];

        double curv12 = k1[v1] + k2[v2] - 2.0 * k1[v2];
        if (curv12 <= 0.0) {
            curv12 = TAU;
        }

        // Unconstrained step along alpha[v1] - alpha[v2] (delta <= 0 when O[v1] <= O[v2]).
        double delta = (O[v1] - O[v2]) / curv12;
        double sum = oldAlpha1 + oldAlpha2;
        alpha[v2] += delta;
        alpha[v1] -= delta;

        // Clip back into the box [0, C] while keeping alpha[v1] + alpha[v2] == sum.
        if (sum > C) {
            if (alpha[v1] > C) {
                alpha[v1] = C;
                alpha[v2] = sum - C;
            }
        } else if (alpha[v2] < 0.0) {
            alpha[v2] = 0.0;
            alpha[v1] = sum;
        }
        if (sum > C) {
            if (alpha[v2] > C) {
                alpha[v2] = C;
                alpha[v1] = sum - C;
            }
        } else if (alpha[v1] < 0.0) {
            alpha[v1] = 0.0;
            alpha[v2] = sum;
        }

        // Incrementally update the gradient cache (uses K[v][i] == K[i][v] for a Mercer kernel).
        double delta1 = alpha[v1] - oldAlpha1;
        double delta2 = alpha[v2] - oldAlpha2;
        for (int i = 0; i < n; i++) {
            O[i] += k1[i] * delta1 + k2[i] * delta2;
        }

        // Recompute rho from the *post-step* gap so the final model reflects the converged
        // state, and never from the sentinel values used when a candidate is missing.
        minmax();
        if (svmin >= 0 && svmax >= 0) {
            rho = (omax + omin) / 2.0;
        }
        return omax - omin > epsgr;
    }
}

