/*
 * Decompiled with CFR 0.152.
 */
package smile.base.svm;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.svm.KernelMachine;
import smile.base.svm.SupportVector;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/**
 * Trainer that builds a two-class {@link KernelMachine} by feeding the
 * training samples one at a time into a working set of support vectors.
 *
 * <p>Each sample is first offered to {@code process}, which may insert it
 * into the working set; {@code reprocess} then performs one pairwise
 * coefficient update on the existing set and evicts vectors that are no
 * longer needed. Labels must be {@code +1} or {@code -1}.
 *
 * <p>The working set, the bias and the kernel cache are instance fields and
 * are not cleared by {@code fit}; state from an earlier {@code fit} call on
 * the same instance is carried over into the next one.
 *
 * @param <T> the type of the training samples.
 */
public class LASVM<T>
implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(LASVM.class);
    /** Substitute used whenever the curvature of a candidate pair is not positive. */
    private static final double TAU = 1.0E-12;
    /** The kernel function. */
    private MercerKernel<T> kernel;
    /** Constant handed to every new support vector together with {@link #Cn}. */
    private double Cp = 1.0;
    /** Constant handed to every new support vector together with {@link #Cp}. */
    private double Cn = 1.0;
    /** Gradient-gap threshold passed to the reprocess and finishing steps. */
    private double tol = 0.001;
    /** The current working set; new vectors are inserted at the head. */
    private LinkedList<SupportVector<T>> sv = new LinkedList<>();
    /** Bias term, refreshed after every successful pairwise update. */
    private double b = 0.0;
    /** True while gmin/gmax/svmin/svmax are considered up to date. */
    private boolean minmaxflag = false;
    /** Vector that produced {@link #gmin} in the last scan, or null if none qualified. */
    private SupportVector<T> svmin = null;
    /** Vector that produced {@link #gmax} in the last scan, or null if none qualified. */
    private SupportVector<T> svmax = null;
    /** Smallest gradient among vectors whose alpha is above its lower bound. */
    private double gmin = Double.MAX_VALUE;
    /** Largest gradient among vectors whose alpha is below its upper bound. */
    private double gmax = -Double.MAX_VALUE;
    /** The training samples of the current fit call. */
    private T[] x;
    /** Per-sample kernel rows; a row is null unless the sample is in the working set. */
    private double[][] K;

    /**
     * Constructor that uses the same constant for both classes.
     *
     * @param kernel the kernel function.
     * @param C the constant used as both {@code Cp} and {@code Cn}.
     * @param tol the gradient-gap tolerance.
     */
    public LASVM(MercerKernel<T> kernel, double C, double tol) {
        this(kernel, C, C, tol);
    }

    /**
     * Constructor.
     *
     * @param kernel the kernel function.
     * @param Cp the first constant passed to each support vector.
     * @param Cn the second constant passed to each support vector.
     * @param tol the gradient-gap tolerance.
     */
    public LASVM(MercerKernel<T> kernel, double Cp, double Cn, double tol) {
        this.kernel = kernel;
        this.Cp = Cp;
        this.Cn = Cn;
        this.tol = tol;
    }

    /**
     * Trains the model with two passes over the data.
     *
     * @param x the training samples.
     * @param y the labels, each {@code +1} or {@code -1}.
     * @return the trained kernel machine.
     */
    public KernelMachine<T> fit(T[] x, int[] y) {
        return fit(x, y, 2);
    }

    /**
     * Trains the model.
     *
     * @param x the training samples.
     * @param y the labels, each {@code +1} or {@code -1}.
     * @param epoch the number of passes over the (shuffled) training data.
     * @return the trained kernel machine.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, or if a label other than {@code +1}/{@code -1} is met.
     */
    public KernelMachine<T> fit(T[] x, int[] y, int epoch) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        this.x = x;
        this.K = new double[x.length][];
        init(x, y);

        // Progress is reported every `phase` processed samples.
        int phase = Math.min(x.length, 1000);
        int iter = 0;
        for (int e = 0; e < epoch; ++e) {
            for (int i : MathEx.permutate(x.length)) {
                process(i, x[i], y[i]);
                // Always reprocess at least once, then keep going while the
                // gradient gap is very large.
                do {
                    reprocess(tol);
                    minmax();
                } while (gmax - gmin > 1000.0);

                if (++iter % phase == 0) {
                    logger.info("{} iterations, {} support vectors", iter, sv.size());
                }
            }
        }
        finish();

        int n = sv.size();
        // Safe: the array is created with the runtime component type of x.
        @SuppressWarnings("unchecked")
        T[] vectors = (T[]) Array.newInstance(x.getClass().getComponentType(), n);
        double[] alpha = new double[n];
        // Walk the linked list once instead of calling get(i), which is O(i).
        int j = 0;
        for (SupportVector<T> v : sv) {
            vectors[j] = v.x;
            alpha[j] = v.alpha;
            ++j;
        }
        return new KernelMachine<>(kernel, vectors, alpha, b);
    }

    /**
     * Seeds the working set with up to five samples of each class, taken in
     * random order. Samples with labels other than +1/-1 are skipped here.
     */
    private void init(T[] x, int[] y) {
        final int few = 5;
        int cp = 0;
        int cn = 0;
        for (int i : MathEx.permutate(x.length)) {
            if (y[i] == 1) {
                if (cp < few && process(i, x[i], y[i])) {
                    ++cp;
                }
            } else if (y[i] == -1) {
                if (cn < few && process(i, x[i], y[i])) {
                    ++cn;
                }
            }
            if (cp >= few && cn >= few) {
                break;
            }
        }
    }

    /**
     * Recomputes gmin/gmax and the vectors that attain them, unless the
     * cached values are still flagged as valid.
     */
    private void minmax() {
        if (minmaxflag) {
            return;
        }

        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
        // Drop the results of the previous scan so that "no candidate" is
        // reported as null rather than as a stale vector.
        svmin = null;
        svmax = null;
        for (SupportVector<T> v : sv) {
            double gi = v.g;
            double ai = v.alpha;
            if (gi < gmin && ai > v.cmin) {
                svmin = v;
                gmin = gi;
            }
            if (gi > gmax && ai < v.cmax) {
                svmax = v;
                gmax = gi;
            }
        }
        minmaxflag = true;
    }

    /**
     * Returns the kernel value between samples i and j of the training set,
     * reading and filling the cached row of sample i when that row exists.
     */
    private double k(int i, int j) {
        double[] row = K[i];
        double value = (row != null) ? row[j] : Double.NaN;
        if (Double.isNaN(value)) {
            // Not cached (NaN marks an empty slot): evaluate the kernel.
            value = kernel.k(x[i], x[j]);
            if (row != null) {
                row[j] = value;
            }
        }
        return value;
    }

    /**
     * Performs one pairwise update: moves {@code step} from v1's alpha to
     * v2's alpha and updates every gradient accordingly. A null argument is
     * chosen from the working set by maximizing the estimated gain; when both
     * are null the starting vector is svmax or svmin.
     *
     * @param v1 the vector whose alpha is decreased, or null to search for it.
     * @param v2 the vector whose alpha is increased, or null to search for it.
     * @param epsgr the gradient-gap threshold.
     * @return true if an update was made and the gap gmax - gmin still
     *         exceeds {@code epsgr}; false otherwise.
     */
    private boolean smo(SupportVector<T> v1, SupportVector<T> v2, double epsgr) {
        if (v1 == null && v2 == null) {
            minmax();
            if (gmax > -gmin) {
                v2 = svmax;
            } else {
                v1 = svmin;
            }
            if (v1 == null && v2 == null) {
                // No vector can move in either direction; nothing to optimize.
                return false;
            }
        }

        double k12 = Double.NaN;
        if (v2 == null) {
            // v1 is fixed: look for the partner v2 with the largest gain.
            double km = v1.k;
            double gm = v1.g;
            double best = 0.0;
            for (SupportVector<T> v : sv) {
                double z = v.g - gm;
                double kij = k(v1.i, v.i);
                double c = km + v.k - 2.0 * kij;
                if (c <= 0.0) {
                    c = TAU;
                }
                double mu = z / c;
                boolean movable = (mu > 0.0 && v.alpha < v.cmax) || (mu < 0.0 && v.alpha > v.cmin);
                if (movable) {
                    double gain = z * mu;
                    if (gain > best) {
                        best = gain;
                        v2 = v;
                        k12 = kij;
                    }
                }
            }
        }

        if (v1 == null) {
            // v2 is fixed: look for the partner v1 with the largest gain.
            double km = v2.k;
            double gm = v2.g;
            double best = 0.0;
            for (SupportVector<T> v : sv) {
                double z = gm - v.g;
                double kij = k(v2.i, v.i);
                double c = km + v.k - 2.0 * kij;
                if (c <= 0.0) {
                    c = TAU;
                }
                double mu = z / c;
                boolean movable = (mu > 0.0 && v.alpha > v.cmin) || (mu < 0.0 && v.alpha < v.cmax);
                if (movable) {
                    double gain = z * mu;
                    if (gain > best) {
                        best = gain;
                        v1 = v;
                        k12 = kij;
                    }
                }
            }
        }

        if (v1 == null || v2 == null) {
            return false;
        }

        if (Double.isNaN(k12)) {
            k12 = kernel.k(v1.x, v2.x);
        }

        double curv = v1.k + v2.k - 2.0 * k12;
        if (curv <= 0.0) {
            curv = TAU;
        }

        // Unconstrained step, then clipped so both alphas stay within bounds.
        double step = (v2.g - v1.g) / curv;
        if (step >= 0.0) {
            double limit = v1.alpha - v1.cmin;
            if (limit < step) {
                step = limit;
            }
            limit = v2.cmax - v2.alpha;
            if (limit < step) {
                step = limit;
            }
        } else {
            double limit = v2.cmin - v2.alpha;
            if (limit > step) {
                step = limit;
            }
            limit = v1.alpha - v1.cmax;
            if (limit > step) {
                step = limit;
            }
        }

        v1.alpha -= step;
        v2.alpha += step;
        for (SupportVector<T> v : sv) {
            v.g -= step * (k(v2.i, v.i) - k(v1.i, v.i));
        }

        // Gradients changed, so the cached extremes must be rebuilt.
        minmaxflag = false;
        minmax();
        b = (gmax + gmin) / 2.0;
        return gmax - gmin > epsgr;
    }

    /**
     * Offers a sample to the working set.
     *
     * @param i the index of the sample in the training array.
     * @param x the sample.
     * @param y the label, {@code +1} or {@code -1}.
     * @return true if the sample is (now) in the working set, false if it
     *         was skipped.
     * @throws IllegalArgumentException if the label is neither +1 nor -1.
     */
    private boolean process(int i, T x, int y) {
        if (y != 1 && y != -1) {
            throw new IllegalArgumentException("Invalid label: " + y);
        }

        // Already in the working set? Identity comparison: same object only.
        for (SupportVector<T> v : sv) {
            if (v.x == x) {
                return true;
            }
        }

        // Kernel row of the new sample; NaN marks entries not yet computed.
        double[] cache = new double[K.length];
        Arrays.fill(cache, Double.NaN);
        // Each vector writes a distinct slot (its own index), so the parallel
        // writes into `cache` do not collide.
        double sum = sv.stream().parallel().mapToDouble(v -> {
            double kv = kernel.k(v.x, x);
            cache[v.i] = kv;
            return v.alpha * kv;
        }).sum();
        double g = y - sum;

        // Skip the sample when its gradient lies outside [gmin, gmax] on the
        // side that matches its label.
        minmax();
        if (gmin < gmax && ((y > 0 && g < gmin) || (y < 0 && g > gmax))) {
            return false;
        }

        SupportVector<T> candidate = new SupportVector<T>(i, x, y, 0.0, g, Cp, Cn, kernel.k(x, x));
        sv.addFirst(candidate);
        K[i] = cache;

        // Pair the new vector with the best partner from the working set.
        if (y > 0) {
            smo(null, candidate, 0.0);
        } else {
            smo(candidate, null, 0.0);
        }

        minmaxflag = false;
        return true;
    }

    /**
     * Runs one pairwise update on the working set and then evicts vectors.
     *
     * @param epsgr the gradient-gap threshold.
     * @return the result of the pairwise update.
     */
    private boolean reprocess(double epsgr) {
        boolean status = smo(null, null, epsgr);
        evict();
        return status;
    }

    /**
     * Runs the finishing pass with at most one update per current support
     * vector, then logs how many vectors sit exactly on one of their bounds.
     */
    private void finish() {
        finish(tol, sv.size());

        int bsv = 0;
        for (SupportVector<T> v : sv) {
            if (v.alpha == v.cmin || v.alpha == v.cmax) {
                ++bsv;
            }
        }
        logger.info("{} samples, {} support vectors, {} bounded", x.length, sv.size(), bsv);
    }

    /**
     * Repeats pairwise updates until one reports no further progress or
     * {@code maxIter} updates have been made, then evicts vectors.
     *
     * @param epsgr the gradient-gap threshold.
     * @param maxIter the maximum number of updates.
     */
    private void finish(double epsgr, int maxIter) {
        logger.info("Finalizing the training by reprocess.");
        for (int count = 1; count <= maxIter && smo(null, null, epsgr); ++count) {
            if (count % 1000 == 0) {
                logger.info("{} reprocess iterations.", count);
            }
        }
        evict();
    }

    /**
     * Removes vectors with a zero alpha whose gradient is at or beyond the
     * current extremes, and releases their kernel rows. The cached extremes
     * are not invalidated by the removal.
     */
    private void evict() {
        minmax();

        Iterator<SupportVector<T>> iter = sv.iterator();
        while (iter.hasNext()) {
            SupportVector<T> v = iter.next();
            if (v.alpha != 0.0) {
                continue;
            }
            boolean aboveMax = v.g >= gmax && 0.0 >= v.cmax;
            boolean belowMin = v.g <= gmin && 0.0 <= v.cmin;
            if (aboveMax || belowMin) {
                K[v.i] = null;
                iter.remove();
            }
        }
    }
}

