/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import smile.classification.ClassLabels;
import smile.classification.SoftClassifier;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.NearestNeighborSearch;
import smile.neighbor.Neighbor;
import smile.util.IntSet;

/**
 * K-nearest-neighbor classifier.
 *
 * <p>A query is classified by a majority vote among the labels of its
 * {@code k} nearest training samples, as returned by the wrapped
 * {@link KNNSearch} structure.
 *
 * @param <T> the type of the samples
 */
public class KNN<T>
implements SoftClassifier<T> {
    private static final long serialVersionUID = 2L;
    /** Search structure used to retrieve the nearest training samples. */
    private final KNNSearch<T, T> knn;
    /** Class label of each training sample, indexed by the neighbor index. */
    private final int[] y;
    /** Number of neighbors taking part in the vote. */
    private final int k;
    /** The set of distinct class labels found in {@code y}. */
    private final IntSet labels;

    /**
     * Creates a classifier on top of an existing nearest-neighbor search structure.
     *
     * @param knn the search structure holding the training samples
     * @param y the class labels of the training samples
     * @param k the number of neighbors used for the vote
     */
    public KNN(KNNSearch<T, T> knn, int[] y, int k) {
        this.knn = knn;
        this.k = k;
        this.y = y;
        this.labels = ClassLabels.fit(y).labels;
    }

    /**
     * Fits a 1-nearest-neighbor classifier.
     *
     * @param <T> the type of the samples
     * @param x the training samples
     * @param y the class labels of the training samples
     * @param distance the distance function
     * @return the classifier
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     */
    public static <T> KNN<T> fit(T[] x, int[] y, Distance<T> distance) {
        return KNN.fit(x, y, distance, 1);
    }

    /**
     * Fits a k-nearest-neighbor classifier. A cover tree is built when
     * {@code distance} is a {@link Metric}; otherwise a linear search is used.
     *
     * @param <T> the type of the samples
     * @param x the training samples
     * @param y the class labels of the training samples
     * @param distance the distance function
     * @param k the number of neighbors used for the vote
     * @return the classifier
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, or if {@code k < 1}
     */
    public static <T> KNN<T> fit(T[] x, int[] y, Distance<T> distance, int k) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        // Declare the local with the type the constructor needs, so no raw
        // types or casts through Object are required.
        KNNSearch<T, T> search;
        if (distance instanceof Metric) {
            search = new CoverTree<T>(x, (Metric<T>) distance);
        } else {
            search = new LinearSearch<T>(x, distance);
        }
        return new KNN<T>(search, y, k);
    }

    /**
     * Fits a 1-nearest-neighbor classifier with the Euclidean distance.
     *
     * @param x the training samples
     * @param y the class labels of the training samples
     * @return the classifier
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     */
    public static KNN<double[]> fit(double[][] x, int[] y) {
        return KNN.fit(x, y, 1);
    }

    /**
     * Fits a k-nearest-neighbor classifier with the Euclidean distance.
     * A KD-tree is built when the samples have fewer than 10 dimensions;
     * otherwise a cover tree is used.
     *
     * @param x the training samples
     * @param y the class labels of the training samples
     * @param k the number of neighbors used for the vote
     * @return the classifier
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, or if {@code k < 1}
     */
    public static KNN<double[]> fit(double[][] x, int[] y, int k) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        // The dimensionality is read from the first sample.
        KNNSearch<double[], double[]> search;
        if (x[0].length < 10) {
            search = new KDTree<double[]>(x, x);
        } else {
            search = new CoverTree<double[]>(x, new EuclideanDistance());
        }
        return new KNN<double[]>(search, y, k);
    }

    /**
     * Predicts the class label of a sample by majority vote of its
     * {@code k} nearest neighbors.
     *
     * @param x the sample to classify
     * @return the predicted class label
     */
    @Override
    public int predict(T x) {
        Neighbor<T, T>[] neighbors = this.knn.knn(x, this.k);
        if (this.k == 1) {
            // A single neighbor decides alone; no vote counting is needed.
            return this.y[neighbors[0].index];
        }
        int[] count = this.countVotes(neighbors);
        return this.labels.valueOf(MathEx.whichMax(count));
    }

    /**
     * Predicts the class label of a sample and reports the vote shares.
     *
     * <p>For every class index {@code i} in {@code [0, labels.size())},
     * {@code posteriori[i]} is set to the fraction of the {@code k} neighbors
     * carrying that class. This also holds for {@code k == 1}, where the
     * class of the single neighbor receives 1.0 and all others 0.0.
     *
     * @param x the sample to classify
     * @param posteriori output array receiving the vote share of each class;
     *        it needs at least as many entries as there are classes
     * @return the predicted class label
     */
    @Override
    public int predict(T x, double[] posteriori) {
        Neighbor<T, T>[] neighbors = this.knn.knn(x, this.k);
        // No k == 1 shortcut here: the general path yields the same label and,
        // unlike an early return, always fills in the output array.
        int[] count = this.countVotes(neighbors);
        for (int i = 0; i < count.length; ++i) {
            posteriori[i] = (double) count[i] / (double) this.k;
        }
        return this.labels.valueOf(MathEx.whichMax(count));
    }

    /** Counts, per class index, how many of the first {@code k} neighbors carry that class. */
    private int[] countVotes(Neighbor<T, T>[] neighbors) {
        int[] count = new int[this.labels.size()];
        for (int i = 0; i < this.k; ++i) {
            int n = this.labels.indexOf(this.y[neighbors[i].index]);
            count[n]++;
        }
        return count;
    }
}

