package smile.classification;

import java.util.Arrays;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;
import smile.util.IntSet;

/* loaded from: input_file:smile/classification/KNN.class */
/**
 * K-nearest neighbor classifier. A query is assigned the label that occurs
 * most often among its {@code k} nearest training samples, as located by the
 * supplied {@link KNNSearch} structure. With {@code k == 1} the label of the
 * single nearest sample is returned directly.
 *
 * @param <T> the type of input samples.
 */
public class KNN<T> implements SoftClassifier<T> {
    private static final long serialVersionUID = 2;

    /** The nearest-neighbor search structure over the training samples. */
    private final KNNSearch<T, T> knn;
    /** The class labels of the training samples, indexed by {@code Neighbor.index}. */
    private final int[] y;
    /** The number of neighbors that vote on a prediction. */
    private final int k;
    /** The distinct class labels; maps a label to a dense index into the vote counts. */
    private final IntSet labels;

    /**
     * Constructor.
     *
     * @param knn the nearest-neighbor search structure built on the training samples.
     * @param y the class labels of the training samples; the {@code index} of each
     *          neighbor returned by {@code knn} is used to look up its label here.
     * @param k the number of neighbors used for classification.
     */
    public KNN(KNNSearch<T, T> knn, int[] y, int k) {
        this.knn = knn;
        this.k = k;
        this.y = y;
        this.labels = ClassLabels.fit(y).labels;
    }

    /**
     * Fits a 1-NN classifier.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param distance the distance function between samples.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> KNN<T> fit(T[] x, int[] y, Distance<T> distance) {
        return fit(x, y, 1, distance);
    }

    /**
     * Fits a k-NN classifier. A cover tree is used when {@code distance} is a
     * {@link Metric}; otherwise a brute-force linear search is used.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param k the number of neighbors.
     * @param distance the distance function between samples.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     *         or {@code k < 1}.
     */
    public static <T> KNN<T> fit(T[] x, int[] y, int k, Distance<T> distance) {
        checkArguments(x.length, y.length, k);

        KNNSearch<T, T> search;
        if (distance instanceof Metric) {
            // Safe: the runtime check above guarantees a Metric, and Metric<T>
            // measures the same element type T as the Distance<T> we were given.
            @SuppressWarnings("unchecked")
            Metric<T> metric = (Metric<T>) distance;
            search = new CoverTree<>(x, metric);
        } else {
            search = new LinearSearch<>(x, distance);
        }
        return new KNN<>(search, y, k);
    }

    /**
     * Fits a 1-NN classifier with Euclidean distance.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     *         or {@code x} is empty.
     */
    public static KNN<double[]> fit(double[][] x, int[] y) {
        return fit(x, y, 1);
    }

    /**
     * Fits a k-NN classifier with Euclidean distance. A k-d tree is used for
     * samples with fewer than 10 features and a cover tree otherwise.
     *
     * @param x the training samples; the first row determines the dimensionality.
     * @param y the training labels.
     * @param k the number of neighbors.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length,
     *         {@code k < 1}, or {@code x} is empty.
     */
    public static KNN<double[]> fit(double[][] x, int[] y, int k) {
        checkArguments(x.length, y.length, k);
        // x[0] is read below to pick the search structure.
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }

        KNNSearch<double[], double[]> search;
        if (x[0].length < 10) {
            // k-d trees are generally effective only in low dimensions.
            search = new KDTree<>(x, x);
        } else {
            search = new CoverTree<>(x, new EuclideanDistance());
        }
        return new KNN<>(search, y, k);
    }

    /**
     * Validates the arguments shared by all {@code fit} overloads.
     *
     * @param n the number of training samples.
     * @param m the number of training labels.
     * @param k the number of neighbors.
     * @throws IllegalArgumentException if {@code n != m} or {@code k < 1}.
     */
    private static void checkArguments(int n, int m, int k) {
        if (n != m) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", n, m));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
    }

    /**
     * Predicts the class label of a sample by majority vote of its neighbors.
     *
     * @param x the sample.
     * @return the predicted class label.
     * @throws IllegalStateException if the search returns no neighbor.
     */
    @Override
    public int predict(T x) {
        Neighbor<T, T>[] neighbors = knn.knn(x, k);
        if (k == 1) {
            return y[nearest(neighbors).index];
        }

        int[] count = new int[labels.size()];
        if (vote(neighbors, count) == 0) {
            throw new IllegalStateException("No neighbor found.");
        }
        return labels.valueOf(MathEx.whichMax(count));
    }

    /**
     * Predicts the class label of a sample and fills in the estimated class
     * probabilities, i.e. the fraction of neighbor votes for each class.
     *
     * @param x the sample.
     * @param posteriori the output array of class probabilities, indexed like
     *                   the distinct labels; must hold at least as many entries
     *                   as there are distinct labels.
     * @return the predicted class label.
     * @throws IllegalStateException if the search returns no neighbor.
     */
    @Override
    public int predict(T x, double[] posteriori) {
        Neighbor<T, T>[] neighbors = knn.knn(x, k);
        if (k == 1) {
            int label = y[nearest(neighbors).index];
            Arrays.fill(posteriori, 0.0);
            posteriori[labels.indexOf(label)] = 1.0;
            return label;
        }

        int[] count = new int[labels.size()];
        int votes = vote(neighbors, count);
        if (votes == 0) {
            throw new IllegalStateException("No neighbor found.");
        }
        // Cast before dividing: int / int would truncate each probability to 0 or 1.
        // Normalizing by the votes actually cast keeps the probabilities summing
        // to 1 even if the search found fewer than k neighbors.
        for (int i = 0; i < count.length; i++) {
            posteriori[i] = (double) count[i] / votes;
        }
        return labels.valueOf(MathEx.whichMax(count));
    }

    /**
     * Returns the first (nearest) neighbor.
     *
     * @param neighbors the neighbors returned by the search.
     * @return the first neighbor.
     * @throws IllegalStateException if there is no first neighbor.
     */
    private Neighbor<T, T> nearest(Neighbor<T, T>[] neighbors) {
        if (neighbors.length == 0 || neighbors[0] == null) {
            throw new IllegalStateException("No neighbor found.");
        }
        return neighbors[0];
    }

    /**
     * Tallies the labels of the non-null neighbors into per-class vote counts.
     *
     * @param neighbors the neighbors returned by the search; null entries are skipped.
     * @param count the output vote counts indexed by class index; must be zero on entry.
     * @return the number of neighbors that voted.
     */
    private int vote(Neighbor<T, T>[] neighbors, int[] count) {
        int votes = 0;
        for (Neighbor<T, T> neighbor : neighbors) {
            if (neighbor != null) {
                count[labels.indexOf(y[neighbor.index])]++;
                votes++;
            }
        }
        return votes;
    }
}
