package smile.neighbor;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import smile.math.MathEx;
import smile.neighbor.lsh.Bucket;
import smile.neighbor.lsh.Hash;
import smile.neighbor.lsh.MultiProbeHash;
import smile.neighbor.lsh.MultiProbeSample;
import smile.neighbor.lsh.PosterioriModel;
import smile.sort.HeapSelect;
import smile.util.IntArrayList;

/* loaded from: input_file:smile/neighbor/MPLSH.class */
/**
 * Multi-probe variant of {@link LSH}.
 *
 * <p>Every hash table of this index is a {@link MultiProbeHash}. Until
 * {@link #fit(RNNSearch, double[][], double, int, double)} has been called the
 * query methods inherited from {@link LSH} simply delegate to the superclass.
 * Once a model has been fitted, queries ask one {@link PosterioriModel} per
 * hash table for a probe sequence and collect candidates from every probed
 * bucket.
 *
 * <p>In all query methods the query array is compared to stored keys by
 * reference ({@code !=}), so a key is only skipped when it is the very same
 * array object as the query.
 *
 * @param <E> the type of the values associated with the {@code double[]} keys.
 */
public class MPLSH<E> extends LSH<E> {
    private static final long serialVersionUID = 2L;

    /** Recall target used by the inherited query signatures once a model is fitted. */
    private static final double DEFAULT_RECALL = 0.95;

    /** Value handed to the probe-sequence generator by the inherited query signatures. */
    private static final int DEFAULT_PROBES = 100;

    /** One posteriori model per hash table, in table order; {@code null} until {@code fit} runs. */
    private List<PosterioriModel> model;

    /**
     * Creates an index, passing {@code 1017881} as the fifth argument of
     * {@link #MPLSH(int, int, int, double, int)}.
     *
     * @param i  forwarded to each {@link MultiProbeHash}
     *           (presumably the key dimensionality; confirm against LSH).
     * @param i2 the number of hash tables.
     * @param i3 forwarded to each {@link MultiProbeHash}
     *           (presumably the number of projections per table; confirm).
     * @param d  forwarded to each {@link MultiProbeHash}
     *           (presumably the projection width; confirm).
     */
    public MPLSH(int i, int i2, int i3, double d) {
        this(i, i2, i3, d, 1017881);
    }

    /**
     * Creates an index by delegating all arguments to the {@link LSH} constructor.
     *
     * @param i  forwarded to each {@link MultiProbeHash}
     *           (presumably the key dimensionality; confirm against LSH).
     * @param i2 the number of hash tables.
     * @param i3 forwarded to each {@link MultiProbeHash}
     *           (presumably the number of projections per table; confirm).
     * @param d  forwarded to each {@link MultiProbeHash}
     *           (presumably the projection width; confirm).
     * @param i4 forwarded to each {@link MultiProbeHash}
     *           (presumably the universal hash table size; confirm).
     */
    public MPLSH(int i, int i2, int i3, double d, int i4) {
        super(i, i2, i3, d, i4);
    }

    /**
     * Populates {@code hash} with {@code i2} freshly constructed
     * {@link MultiProbeHash} tables.
     */
    @Override
    protected void initHashTable(int i, int i2, int i3, double d, int i4) {
        this.hash = new ArrayList<>(i2);
        for (int table = 0; table < i2; table++) {
            this.hash.add(new MultiProbeHash(i, i3, d, i4));
        }
    }

    /** Returns the superclass description prefixed with {@code "Multi-Probe "}. */
    @Override
    public String toString() {
        return "Multi-Probe " + super.toString();
    }

    /**
     * Fits the posteriori models, passing {@code 2500} as {@code i} of
     * {@link #fit(RNNSearch, double[][], double, int)}.
     *
     * @param rNNSearch range search used to find the neighbors of each sample.
     * @param dArr      the training samples.
     * @param d         the search radius handed to {@code rNNSearch}.
     */
    public void fit(RNNSearch<double[], double[]> rNNSearch, double[][] dArr, double d) {
        fit(rNNSearch, dArr, d, 2500);
    }

    /**
     * Fits the posteriori models, passing {@code 0.2} as {@code d2} of
     * {@link #fit(RNNSearch, double[][], double, int, double)}.
     *
     * @param rNNSearch range search used to find the neighbors of each sample.
     * @param dArr      the training samples.
     * @param d         the search radius handed to {@code rNNSearch}.
     * @param i         forwarded to each {@link PosterioriModel}
     *                  (presumably the quantization table size; confirm).
     */
    public void fit(RNNSearch<double[], double[]> rNNSearch, double[][] dArr, double d, int i) {
        fit(rNNSearch, dArr, d, i, 0.2d);
    }

    /**
     * Builds one {@link PosterioriModel} per hash table from training samples.
     *
     * <p>For every sample, {@code rNNSearch} is queried with radius {@code d};
     * the {@code index} of each returned neighbor is resolved against this
     * index's own {@code keys}, so those indices must be valid positions in
     * {@code keys}.
     *
     * @param rNNSearch range search used to find the neighbors of each sample.
     * @param dArr      the training samples.
     * @param d         the search radius handed to {@code rNNSearch}.
     * @param i         forwarded to each {@link PosterioriModel}
     *                  (presumably the quantization table size; confirm).
     * @param d2        forwarded to each {@link PosterioriModel}
     *                  (presumably a kernel bandwidth; confirm).
     */
    public void fit(RNNSearch<double[], double[]> rNNSearch, double[][] dArr, double d, int i, double d2) {
        MultiProbeSample[] training = new MultiProbeSample[dArr.length];
        for (int s = 0; s < dArr.length; s++) {
            training[s] = new MultiProbeSample(dArr[s], new LinkedList<double[]>());

            List<Neighbor<double[], double[]>> found = new ArrayList<>();
            rNNSearch.search(dArr[s], d, found);
            for (Neighbor<double[], double[]> neighbor : found) {
                training[s].neighbors.add(this.keys.get(neighbor.index));
            }
        }

        this.model = new ArrayList<>(this.hash.size());
        for (Hash table : this.hash) {
            this.model.add(new PosterioriModel((MultiProbeHash) table, training, i, d2));
        }
    }

    /**
     * Returns the nearest neighbor of the query. Delegates to {@link LSH}
     * when no model has been fitted; otherwise uses a recall of 0.95 and a
     * probe argument of 100.
     */
    @Override
    public Neighbor<double[], E> nearest(double[] dArr) {
        if (this.model == null) {
            return super.nearest(dArr);
        }
        return nearest(dArr, DEFAULT_RECALL, DEFAULT_PROBES);
    }

    /**
     * Returns the candidate with the smallest {@code MathEx.distance} to the query.
     *
     * <p>Requires a fitted model: the candidate lookup reads {@code model},
     * which is {@code null} before {@code fit} has been called.
     *
     * @param dArr the query.
     * @param d    the target recall, in {@code [0, 1]}.
     * @param i    forwarded to the probe-sequence generator
     *             (presumably the maximum number of probes; confirm).
     * @return the closest candidate, or {@code null} if there is none.
     * @throws IllegalArgumentException if {@code d} is NaN or outside {@code [0, 1]}.
     */
    public Neighbor<double[], E> nearest(double[] dArr, double d, int i) {
        checkRecall(d);

        int bestIndex = -1;
        double[] bestKey = null;
        double bestDistance = Double.MAX_VALUE;
        for (int index : getCandidates(dArr, d, i)) {
            double[] key = this.keys.get(index);
            if (key == dArr) {
                continue; // the query object itself is never its own neighbor
            }
            double distance = MathEx.distance(dArr, key);
            if (distance < bestDistance) {
                bestIndex = index;
                bestKey = key;
                bestDistance = distance;
            }
        }

        if (bestIndex < 0) {
            return null;
        }
        return new Neighbor<>(bestKey, this.data.get(bestIndex), bestIndex, bestDistance);
    }

    /**
     * Returns up to {@code i} nearest neighbors of the query. Delegates to
     * {@link LSH} when no model has been fitted; otherwise uses a recall of
     * 0.95 and a probe argument of 100.
     */
    @Override
    public Neighbor<double[], E>[] search(double[] dArr, int i) {
        if (this.model == null) {
            return super.search(dArr, i);
        }
        return search(dArr, i, DEFAULT_RECALL, DEFAULT_PROBES);
    }

    /**
     * Selects up to {@code i} candidates closest to the query.
     *
     * <p>Requires a fitted model: the candidate lookup reads {@code model},
     * which is {@code null} before {@code fit} has been called.
     *
     * @param dArr the query.
     * @param i    the number of neighbors wanted (k); must be at least 1.
     * @param d    the target recall, in {@code [0, 1]}.
     * @param i2   forwarded to the probe-sequence generator
     *             (presumably the maximum number of probes; confirm).
     * @return the selected neighbors after {@code HeapSelect.sort()}; the array
     *         is shorter than {@code i} when fewer eligible candidates exist.
     * @throws IllegalArgumentException if {@code d} is NaN or outside
     *         {@code [0, 1]}, or if {@code i < 1}.
     */
    public Neighbor<double[], E>[] search(double[] dArr, int i, double d, int i2) {
        checkRecall(d);
        if (i < 1) {
            throw new IllegalArgumentException("Invalid k: " + i);
        }

        Set<Integer> candidates = getCandidates(dArr, d, i2);
        List<Neighbor<double[], E>> eligible = new ArrayList<>(candidates.size());
        for (int index : candidates) {
            double[] key = this.keys.get(index);
            if (key != dArr) {
                eligible.add(new Neighbor<>(key, this.data.get(index), index, MathEx.distance(dArr, key)));
            }
        }

        // Size the heap from the neighbors that are actually eligible, not from
        // the raw candidate count: the query object may itself be a candidate
        // and is skipped above, which would otherwise leave a heap slot that is
        // never filled whenever candidates.size() <= k.
        int size = Math.min(i, eligible.size());
        // Generic arrays cannot be created directly; the raw array only ever
        // holds Neighbor<double[], E> instances added below.
        @SuppressWarnings("unchecked")
        HeapSelect<Neighbor<double[], E>> heap = new HeapSelect<Neighbor<double[], E>>(new Neighbor[size]);
        for (Neighbor<double[], E> neighbor : eligible) {
            heap.add(neighbor);
        }
        heap.sort();
        return heap.toArray();
    }

    /**
     * Appends to {@code list} every neighbor within distance {@code d} of the
     * query. Delegates to {@link LSH} when no model has been fitted; otherwise
     * uses a recall of 0.95 and a probe argument of 100.
     */
    @Override
    public void search(double[] dArr, double d, List<Neighbor<double[], E>> list) {
        if (this.model == null) {
            super.search(dArr, d, list);
            return;
        }
        search(dArr, d, list, DEFAULT_RECALL, DEFAULT_PROBES);
    }

    /**
     * Appends to {@code list} every candidate whose {@code MathEx.distance}
     * to the query is at most {@code d}. Results are appended in candidate
     * order and are not sorted by this method.
     *
     * <p>Requires a fitted model: the candidate lookup reads {@code model},
     * which is {@code null} before {@code fit} has been called.
     *
     * @param dArr the query.
     * @param d    the search radius; must be positive.
     * @param list the output list that receives the neighbors.
     * @param d2   the target recall, in {@code [0, 1]}.
     * @param i    forwarded to the probe-sequence generator
     *             (presumably the maximum number of probes; confirm).
     * @throws IllegalArgumentException if {@code d <= 0}, or if {@code d2} is
     *         NaN or outside {@code [0, 1]}.
     */
    public void search(double[] dArr, double d, List<Neighbor<double[], E>> list, double d2, int i) {
        if (d <= 0.0d) {
            throw new IllegalArgumentException("Invalid radius: " + d);
        }
        checkRecall(d2);

        for (int index : getCandidates(dArr, d2, i)) {
            double[] key = this.keys.get(index);
            if (key == dArr) {
                continue; // the query object itself is never its own neighbor
            }
            double distance = MathEx.distance(dArr, key);
            if (distance <= d) {
                list.add(new Neighbor<>(key, this.data.get(index), index, distance));
            }
        }
    }

    /**
     * Rejects a recall outside {@code [0, 1]}.
     *
     * @throws IllegalArgumentException if {@code recall} is NaN or out of range.
     */
    private static void checkRecall(double recall) {
        // Negated range test: NaN fails every comparison, so it is rejected
        // too instead of leaking into the per-table probability computation.
        if (!(recall >= 0.0d && recall <= 1.0d)) {
            throw new IllegalArgumentException("Invalid recall: " + recall);
        }
    }

    /**
     * Gathers the indices stored in every bucket named by each table's probe
     * sequence.
     *
     * @param dArr the query.
     * @param d    the overall target recall.
     * @param i    forwarded to {@code PosterioriModel.getProbeSequence}.
     * @return the distinct candidate indices in first-seen order.
     */
    private Set<Integer> getCandidates(double[] dArr, double d, int i) {
        int tables = this.hash.size();
        // Per-table target p chosen so that 1 - (1 - p)^tables equals the overall recall d.
        double perTableRecall = 1.0d - Math.pow(1.0d - d, 1.0d / tables);

        Set<Integer> candidates = new LinkedHashSet<>();
        for (int t = 0; t < tables; t++) {
            IntArrayList probes = this.model.get(t).getProbeSequence(dArr, perTableRecall, i);
            for (int p = 0; p < probes.size(); p++) {
                Bucket bucket = this.hash.get(t).get(probes.get(p));
                if (bucket == null) {
                    continue; // nothing has been stored under this probe
                }
                IntArrayList points = bucket.points();
                for (int j = 0; j < points.size(); j++) {
                    candidates.add(points.get(j));
                }
            }
        }
        return candidates;
    }
}
