package smile.base.svm;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/* loaded from: input_file:smile/base/svm/SVR.class */
/**
 * Epsilon-insensitive support vector regression trained by sequential minimal
 * optimization (SMO).
 *
 * <p>Every training sample carries a pair of dual coefficients {@code alpha[0]} and
 * {@code alpha[1]}, which are clipped to the box {@code [0, C]}, plus the matching pair of
 * gradients. Each SMO step does the following:
 * <ul>
 *   <li>takes the maximal violating variable as the first working variable;</li>
 *   <li>picks its partner by the second-order gain {@code -(g_i - g_j)^2 / curvature};</li>
 *   <li>solves the two-variable subproblem analytically, clips it to the box, and updates
 *       all gradients.</li>
 * </ul>
 * Training stops once {@code gmax - gmin} no longer exceeds the convergence tolerance.
 *
 * <p>The training state (coefficients, gradients, cached kernel rows and intercept) lives
 * in instance fields. A single instance therefore should not run {@link #fit}
 * concurrently. Kernel rows are cached on demand during training, which can take up to
 * {@code n * n} doubles for {@code n} samples. The cache is released when training ends.
 *
 * @param <T> the type of the input objects accepted by the kernel.
 */
public class SVR<T> {
    private static final Logger logger = LoggerFactory.getLogger(SVR.class);
    /** Floor used in place of a non-positive curvature, so the step size stays finite. */
    private static final double TAU = 1.0E-12d;
    private final MercerKernel<T> kernel;
    /** Error threshold, which enters the initial gradients as {@code eps +/- y}. */
    private final double eps;
    /** Soft margin penalty, the upper bound of every dual coefficient. */
    private final double C;
    /** Tolerance of the convergence test on {@code gmax - gmin}. */
    private final double tol;
    /** One entry per training sample. Non-support vectors are set to {@code null} after SMO. */
    private List<SupportVector> vectors;
    /** Intercept of the trained model. */
    private double b = 0.0d;
    private SupportVector svmin = null;
    private SupportVector svmax = null;
    private double gmin = Double.MAX_VALUE;
    private double gmax = -Double.MAX_VALUE;
    /** Which coefficient of {@code svmin} (0 or 1) attains {@code gmin}. */
    private int gminindex;
    /** Which coefficient of {@code svmax} (0 or 1) attains {@code gmax}. */
    private int gmaxindex;
    /** Lazily computed kernel rows, indexed by sample index. Only used during SMO. */
    private double[][] K;

    /** Per-sample optimization state. */
    public class SupportVector {
        /** Index of the sample in the training set. */
        final int i;
        /** The sample itself. */
        final T x;
        /** The two dual coefficients. The model weight is {@code alpha[1] - alpha[0]}. */
        double[] alpha = new double[2];
        /** Gradients associated with {@code alpha[0]} and {@code alpha[1]}. */
        double[] g = new double[2];
        /** Cached self-similarity {@code kernel.k(x, x)}. */
        double k;

        SupportVector(int i, T x, double y) {
            this.i = i;
            this.x = x;
            // Gradients at the starting point alpha = 0.
            g[0] = eps + y;
            g[1] = eps - y;
            k = kernel.k(x, x);
        }
    }

    /**
     * Creates an SVR trainer.
     *
     * @param kernel the Mercer kernel.
     * @param eps the error threshold. It must be positive.
     * @param C the soft margin penalty, which is the upper bound of every dual coefficient.
     *     It must be non-negative.
     * @param tol the tolerance of the convergence test. It must be positive.
     * @throws IllegalArgumentException if {@code eps <= 0}, {@code C < 0} or {@code tol <= 0}.
     */
    public SVR(MercerKernel<T> kernel, double eps, double C, double tol) {
        if (eps <= 0.0d) {
            throw new IllegalArgumentException("Invalid error threshold: " + eps);
        }
        if (C < 0.0d) {
            throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
        }
        if (tol <= 0.0d) {
            throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
        }
        this.kernel = kernel;
        this.eps = eps;
        this.C = C;
        this.tol = tol;
    }

    /**
     * Trains the regression model.
     *
     * @param x the training samples.
     * @param y the response values, one per sample.
     * @return a kernel machine over the support vectors (the samples whose two dual
     *     coefficients differ). Its weights are {@code alpha[1] - alpha[0]} and its
     *     intercept is {@code b}.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length, or if
     *     they are empty.
     */
    public smile.regression.KernelMachine<T> fit(T[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int n = x.length;
        if (n == 0) {
            throw new IllegalArgumentException("Empty training data");
        }

        // Reset per-fit state so that a repeated fit() does not reuse the intercept or the
        // working-set pointers of a previous run.
        b = 0.0d;
        svmin = null;
        svmax = null;
        K = new double[n][];

        vectors = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            vectors.add(new SupportVector(i, x[i], y[i]));
        }

        minmax();
        // Log progress once every min(n, 1000) iterations. The period is positive because n > 0.
        int phase = Math.min(n, 1000);
        for (int count = 1; smo(tol); count++) {
            if (count % phase == 0) {
                logger.info("{} SMO iterations", count);
            }
        }
        // The kernel cache is only needed by SMO. Drop it so the trainer does not keep up to
        // n * n doubles alive after training.
        K = null;

        // A sample with equal coefficients has zero weight and is not a support vector.
        int nsv = 0;
        int nbsv = 0;
        for (int i = 0; i < n; i++) {
            SupportVector v = vectors.get(i);
            if (v.alpha[0] == v.alpha[1]) {
                vectors.set(i, null);
            } else {
                nsv++;
                if (v.alpha[0] == C || v.alpha[1] == C) {
                    nbsv++;
                }
            }
        }

        double[] w = new double[nsv];
        // Allocate with the runtime component type of x so the model keeps the caller's
        // concrete element type.
        @SuppressWarnings("unchecked")
        T[] sv = (T[]) Array.newInstance(x.getClass().getComponentType(), nsv);
        int pos = 0;
        for (SupportVector v : vectors) {
            if (v != null) {
                sv[pos] = v.x;
                w[pos] = v.alpha[1] - v.alpha[0];
                pos++;
            }
        }

        logger.info("{} samples, {} support vectors, {} bounded", n, nsv, nbsv);
        return new smile.regression.KernelMachine<>(kernel, sv, w, b);
    }

    /**
     * Recomputes {@code gmin}/{@code svmin}/{@code gminindex} and
     * {@code gmax}/{@code svmax}/{@code gmaxindex} over both coefficients of every sample.
     *
     * <p>The signed gradient is {@code -g[0]} for {@code alpha[0]} and {@code g[1]} for
     * {@code alpha[1]}. A coefficient is a candidate only if it is not at the bound that
     * blocks the corresponding move:
     * <ul>
     *   <li>for {@code gmin}: {@code alpha[0] > 0} or {@code alpha[1] < C};</li>
     *   <li>for {@code gmax}: {@code alpha[0] < C} or {@code alpha[1] > 0}.</li>
     * </ul>
     * If a side has no candidate, its previous {@code svmin}/{@code svmax} pointer is left
     * untouched.
     */
    private void minmax() {
        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
        for (SupportVector v : vectors) {
            double grad = -v.g[0];
            double a = v.alpha[0];
            if (grad < gmin && a > 0.0d) {
                svmin = v;
                gmin = grad;
                gminindex = 0;
            }
            if (grad > gmax && a < C) {
                svmax = v;
                gmax = grad;
                gmaxindex = 0;
            }

            grad = v.g[1];
            a = v.alpha[1];
            if (grad < gmin && a < C) {
                svmin = v;
                gmin = grad;
                gminindex = 1;
            }
            if (grad > gmax && a > 0.0d) {
                svmax = v;
                gmax = grad;
                gmaxindex = 1;
            }
        }
    }

    /**
     * Returns the kernel row {@code k(v.x, u.x)} for every sample {@code u}, indexed by
     * {@code u.i}. The row is computed and cached on first use.
     */
    private double[] gram(SupportVector v) {
        double[] row = K[v.i];
        if (row == null) {
            double[] ki = new double[vectors.size()];
            // Each task writes a distinct slot (u.i), so the parallel fill needs no locking.
            // NOTE(review): this assumes kernel.k is safe to call concurrently.
            vectors.parallelStream().forEach(u -> ki[u.i] = kernel.k(v.x, u.x));
            K[v.i] = ki;
            row = ki;
        }
        return row;
    }

    /**
     * Performs one SMO step on the current working set.
     *
     * @param epsgr the convergence tolerance.
     * @return {@code true} if {@code gmax - gmin} still exceeds {@code epsgr} after the
     *     step, meaning another step is needed. Returns {@code false} if the problem has
     *     converged or no working set exists.
     */
    private boolean smo(double epsgr) {
        SupportVector v1 = svmax;
        SupportVector v2 = svmin;
        if (v1 == null || v2 == null) {
            // No coefficient is free to move (for example when C == 0), so the current
            // all-zero solution is final.
            return false;
        }
        int i = gmaxindex;
        int j = gminindex;
        double oldAlphaI = v1.alpha[i];
        double oldAlphaJ = v2.alpha[j];
        double[] k1 = gram(v1);

        // Second-order working set selection of the partner variable. svmin is the
        // fallback when no candidate improves on a zero gain.
        double best = 0.0d;
        double gi = i == 0 ? -v1.g[0] : v1.g[1];
        for (SupportVector v : vectors) {
            double curv = v1.k + v.k - 2.0d * k1[v.i];
            if (curv <= 0.0d) {
                curv = TAU;
            }

            double gj = -v.g[0];
            if (v.alpha[0] > 0.0d && gj < gi) {
                double gain = -MathEx.pow2(gi - gj) / curv;
                if (gain < best) {
                    best = gain;
                    v2 = v;
                    j = 0;
                    oldAlphaJ = v.alpha[0];
                }
            }

            gj = v.g[1];
            if (v.alpha[1] < C && gj < gi) {
                double gain = -MathEx.pow2(gi - gj) / curv;
                if (gain < best) {
                    best = gain;
                    v2 = v;
                    j = 1;
                    oldAlphaJ = v.alpha[1];
                }
            }
        }

        double[] k2 = gram(v2);
        double curv = v1.k + v2.k - 2.0d * k1[v2.i];
        if (curv <= 0.0d) {
            curv = TAU;
        }

        if (i != j) {
            // The two variables come from different coefficient slots, and their difference
            // is preserved.
            double delta = (-v1.g[i] - v2.g[j]) / curv;
            double diff = v1.alpha[i] - v2.alpha[j];
            v1.alpha[i] += delta;
            v2.alpha[j] += delta;

            // Clip to the lower bound.
            if (diff > 0.0d) {
                if (v2.alpha[j] < 0.0d) {
                    v2.alpha[j] = 0.0d;
                    v1.alpha[i] = diff;
                }
            } else if (v1.alpha[i] < 0.0d) {
                v1.alpha[i] = 0.0d;
                v2.alpha[j] = -diff;
            }

            // Clip to the upper bound C.
            if (diff > 0.0d) {
                if (v1.alpha[i] > C) {
                    v1.alpha[i] = C;
                    v2.alpha[j] = C - diff;
                }
            } else if (v2.alpha[j] > C) {
                v2.alpha[j] = C;
                v1.alpha[i] = C + diff;
            }
        } else {
            // Same coefficient slot: their sum is preserved.
            double delta = (v1.g[i] - v2.g[j]) / curv;
            double sum = v1.alpha[i] + v2.alpha[j];
            v1.alpha[i] -= delta;
            v2.alpha[j] += delta;

            if (sum > C) {
                if (v1.alpha[i] > C) {
                    v1.alpha[i] = C;
                    v2.alpha[j] = sum - C;
                }
            } else if (v2.alpha[j] < 0.0d) {
                v2.alpha[j] = 0.0d;
                v1.alpha[i] = sum;
            }

            if (sum > C) {
                if (v2.alpha[j] > C) {
                    v2.alpha[j] = C;
                    v1.alpha[i] = sum - C;
                }
            } else if (v1.alpha[i] < 0.0d) {
                v1.alpha[i] = 0.0d;
                v2.alpha[j] = sum;
            }
        }

        // Propagate the coefficient changes to every gradient. The sign is -1 for
        // alpha[0] and +1 for alpha[1].
        double dai = v1.alpha[i] - oldAlphaI;
        double daj = v2.alpha[j] - oldAlphaJ;
        int si = 2 * i - 1;
        int sj = 2 * j - 1;
        for (SupportVector v : vectors) {
            double t1 = si * k1[v.i] * dai;
            double t2 = sj * k2[v.i] * daj;
            v.g[0] -= t1 + t2;
            v.g[1] = v.g[1] + t1 + t2;
        }

        minmax();
        b = -(gmax + gmin) / 2.0d;
        return gmax - gmin > epsgr;
    }
}
