/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.distance.Metric;
import smile.neighbor.KNNSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;
import smile.sort.DoubleHeapSelect;
import smile.util.DoubleArrayList;

/**
 * A cover tree over a list of keys with associated data values, supporting k-nearest-neighbor
 * ({@link KNNSearch}) and range ({@link RNNSearch}) queries under a user-supplied {@link Metric}.
 *
 * <p>The tree is built in one batch at construction time. The cover radius at scale {@code s}
 * is {@code base^s}, so {@code base} controls how quickly radii shrink from level to level.
 *
 * <p>The key and data lists are stored by reference, not copied. They should not be modified
 * after construction.
 *
 * <p>Both search methods exclude a data point whose key is the very same object (reference
 * identity, {@code ==}) as the query.
 *
 * @param <K> the type of keys, on which the metric is defined.
 * @param <V> the type of associated data values.
 */
public class CoverTree<K, V> implements KNNSearch<K, V>, RNNSearch<K, V>, Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(CoverTree.class);

    /** Default ratio between the cover radii of consecutive scales. */
    private static final double DEFAULT_BASE = 1.3;

    /**
     * Scale assigned to leaves and to zero-distance groups. Presumably meant to exceed any real
     * level scale; it is not read by the search methods.
     */
    private static final int LEAF_SCALE = 100;

    /** The keys on which distances are computed. */
    private final List<K> keys;
    /** The data values, parallel to {@link #keys}. */
    private final List<V> data;
    /** The distance metric on keys. */
    private final Metric<K> distance;
    /** The root of the tree. */
    private Node root;
    /** Ratio between the cover radii of consecutive scales. */
    private final double base;
    /** Cached {@code 1 / log(base)}, used by {@link #getScale}. */
    private final double invLogBase;

    /**
     * Builds a cover tree with the default base 1.3.
     *
     * @param keys the keys.
     * @param data the data values, parallel to {@code keys}.
     * @param distance the metric on keys.
     * @throws IllegalArgumentException if the arrays differ in length or are empty.
     */
    public CoverTree(K[] keys, V[] data, Metric<K> distance) {
        this(keys, data, distance, DEFAULT_BASE);
    }

    /**
     * Builds a cover tree with the default base 1.3.
     *
     * @param keys the keys.
     * @param data the data values, parallel to {@code keys}.
     * @param distance the metric on keys.
     * @throws IllegalArgumentException if the lists differ in size or are empty.
     */
    public CoverTree(List<K> keys, List<V> data, Metric<K> distance) {
        this(keys, data, distance, DEFAULT_BASE);
    }

    /**
     * Builds a cover tree.
     *
     * @param keys the keys.
     * @param data the data values, parallel to {@code keys}.
     * @param distance the metric on keys.
     * @param base the ratio between cover radii of consecutive scales; must be finite and &gt; 1.
     * @throws IllegalArgumentException if the arrays differ in length, are empty, or
     *     {@code base} is invalid.
     */
    public CoverTree(K[] keys, V[] data, Metric<K> distance, double base) {
        this(Arrays.asList(keys), Arrays.asList(data), distance, base);
    }

    /**
     * Builds a cover tree.
     *
     * @param keys the keys.
     * @param data the data values, parallel to {@code keys}.
     * @param distance the metric on keys.
     * @param base the ratio between cover radii of consecutive scales; must be finite and &gt; 1.
     * @throws IllegalArgumentException if the lists differ in size, are empty, or
     *     {@code base} is invalid.
     * @throws NullPointerException if {@code distance} is null.
     */
    public CoverTree(List<K> keys, List<V> data, Metric<K> distance, double base) {
        if (keys.size() != data.size()) {
            throw new IllegalArgumentException("Different size of keys and data objects");
        }
        if (keys.isEmpty()) {
            throw new IllegalArgumentException("Empty dataset");
        }
        // base <= 1 (or NaN/infinite) makes 1/log(base) non-positive, infinite or zero,
        // which breaks the scale arithmetic used while building the tree.
        if (!(base > 1.0) || Double.isInfinite(base)) {
            throw new IllegalArgumentException("Invalid base: " + base);
        }
        this.keys = keys;
        this.data = data;
        this.distance = Objects.requireNonNull(distance, "distance");
        this.base = base;
        this.invLogBase = 1.0 / Math.log(base);
        buildCoverTree();
    }

    /**
     * Builds a cover tree whose keys are also the data values.
     *
     * @param data the data objects.
     * @param distance the metric.
     * @param <T> the type of data objects.
     * @return the cover tree.
     */
    public static <T> CoverTree<T, T> of(T[] data, Metric<T> distance) {
        return new CoverTree<>(data, data, distance);
    }

    /**
     * Builds a cover tree whose keys are also the data values.
     *
     * @param data the data objects.
     * @param distance the metric.
     * @param base the ratio between cover radii of consecutive scales.
     * @param <T> the type of data objects.
     * @return the cover tree.
     */
    public static <T> CoverTree<T, T> of(T[] data, Metric<T> distance, double base) {
        return new CoverTree<>(data, data, distance, base);
    }

    /**
     * Builds a cover tree whose keys are also the data values.
     *
     * @param data the data objects.
     * @param distance the metric.
     * @param <T> the type of data objects.
     * @return the cover tree.
     */
    public static <T> CoverTree<T, T> of(List<T> data, Metric<T> distance) {
        return new CoverTree<>(data, data, distance);
    }

    /**
     * Builds a cover tree whose keys are also the data values.
     *
     * @param data the data objects.
     * @param distance the metric.
     * @param base the ratio between cover radii of consecutive scales.
     * @param <T> the type of data objects.
     * @return the cover tree.
     */
    public static <T> CoverTree<T, T> of(List<T> data, Metric<T> distance, double base) {
        return new CoverTree<>(data, data, distance, base);
    }

    @Override
    public String toString() {
        return String.format("Cover Tree (%s)", distance);
    }

    /** Builds the tree with point 0 as the root and every other point inserted below it. */
    private void buildCoverTree() {
        ArrayList<DistanceSet> pointSet = new ArrayList<>();
        ArrayList<DistanceSet> consumedSet = new ArrayList<>();

        K point = keys.get(0);
        int n = keys.size();
        double maxDist = -1.0;
        for (int i = 1; i < n; i++) {
            DistanceSet set = new DistanceSet(i);
            double dist = distance.d(point, keys.get(i));
            set.dist.add(dist);
            pointSet.add(set);
            if (dist > maxDist) {
                maxDist = dist;
            }
        }

        // ceil(log_base(maxDist)) can be one too small after floating-point rounding. Points
        // beyond the top cover radius would then be handed back in pointSet by batchInsert, and
        // nothing here would insert them. Bump the scale until its radius covers maxDist.
        int scale = getScale(maxDist);
        while (maxDist > getCoverRadius(scale)) {
            scale++;
        }
        root = batchInsert(0, scale, scale, pointSet, consumedSet);
    }

    /**
     * Recursively builds the subtree rooted at point {@code p}.
     *
     * <p>Each entry of {@code pointSet} carries a stack of distances whose top is its distance to
     * {@code p}. On return, the points inserted by this call have been moved into
     * {@code consumedSet}. Points farther than the cover radius of {@code maxScale} are left in
     * {@code pointSet} for the caller to place.
     *
     * @param p index of the point at the root of this subtree.
     * @param maxScale the current scale.
     * @param topScale the scale of the tree root; node scales are stored relative to it.
     * @param pointSet candidate points to insert below {@code p}.
     * @param consumedSet receives the points inserted by this call.
     * @return the subtree root.
     */
    private Node batchInsert(int p, int maxScale, int topScale,
                             List<DistanceSet> pointSet, List<DistanceSet> consumedSet) {
        if (pointSet.isEmpty()) {
            return newLeaf(p);
        }

        double maxDist = max(pointSet);
        int nextScale = Math.min(maxScale - 1, getScale(maxDist));

        // getScale(0) == Integer.MIN_VALUE: no remaining point has a positive distance to p,
        // so they all become leaf siblings of p under a single node.
        if (nextScale == Integer.MIN_VALUE) {
            ArrayList<Node> children = new ArrayList<>();
            children.add(newLeaf(p));
            while (!pointSet.isEmpty()) {
                DistanceSet set = pointSet.removeLast();
                children.add(newLeaf(set.idx));
                consumedSet.add(set);
            }
            return new Node(p, 0.0, 0.0, children, LEAF_SCALE);
        }

        ArrayList<DistanceSet> far = new ArrayList<>();
        split(pointSet, far, maxScale);

        // The first child is p itself one scale down. It absorbs the points it can cover.
        Node child = batchInsert(p, nextScale, topScale, pointSet, consumedSet);
        if (pointSet.isEmpty()) {
            // The self-child took everything near p, so no node is created at this scale.
            pointSet.addAll(far);
            return child;
        }

        ArrayList<Node> children = new ArrayList<>();
        children.add(child);
        ArrayList<DistanceSet> newPointSet = new ArrayList<>();
        ArrayList<DistanceSet> newConsumedSet = new ArrayList<>();
        double fmax = getCoverRadius(maxScale);

        while (!pointSet.isEmpty()) {
            DistanceSet set = pointSet.removeLast();
            double newDist = set.lastDist();
            consumedSet.add(set);

            // Gather the points within the cover radius of the new child (from both the near and
            // far sets), pushing their distance to it onto their distance stacks.
            distSplit(pointSet, newPointSet, set.getKey(), maxScale);
            distSplit(far, newPointSet, set.getKey(), maxScale);

            Node newChild = batchInsert(set.idx, nextScale, topScale, newPointSet, newConsumedSet);
            newChild.parentDist = newDist;
            children.add(newChild);

            // Points the new child did not take: pop their distance to it, then route them back
            // to pointSet or far according to their distance to p.
            for (DistanceSet ds : newPointSet) {
                ds.dist.remove(ds.dist.size() - 1);
                if (ds.lastDist() <= fmax) {
                    pointSet.add(ds);
                } else {
                    far.add(ds);
                }
            }

            // Points the new child consumed: pop their distance to it and record them here.
            for (DistanceSet ds : newConsumedSet) {
                ds.dist.remove(ds.dist.size() - 1);
                consumedSet.add(ds);
            }

            newPointSet.clear();
            newConsumedSet.clear();
        }

        pointSet.addAll(far);

        Node node = new Node(p);
        node.scale = topScale - maxScale;
        // Used by the searches as this node's pruning radius.
        node.maxDist = max(consumedSet);
        node.children = children;
        return node;
    }

    /** Returns the cover radius {@code base^s} of scale {@code s}. */
    private double getCoverRadius(int s) {
        return Math.pow(base, s);
    }

    /**
     * Returns {@code ceil(log_base(d))}. Yields {@code Integer.MIN_VALUE} for {@code d == 0},
     * because log(0) is negative infinity.
     */
    private int getScale(double d) {
        return (int) Math.ceil(invLogBase * Math.log(d));
    }

    /** Creates a childless node for point {@code idx}. */
    private Node newLeaf(int idx) {
        return new Node(idx, 0.0, 0.0, null, LEAF_SCALE);
    }

    /** Returns the largest top-of-stack distance in {@code v}, or 0 when none is positive. */
    private double max(List<DistanceSet> v) {
        double max = 0.0;
        for (DistanceSet ds : v) {
            double d = ds.lastDist();
            if (max < d) {
                max = d;
            }
        }
        return max;
    }

    /**
     * Moves every point whose top-of-stack distance exceeds the cover radius of
     * {@code maxScale} from {@code pointSet} to {@code farSet}, preserving relative order.
     */
    private void split(List<DistanceSet> pointSet, List<DistanceSet> farSet, int maxScale) {
        double fmax = getCoverRadius(maxScale);
        ArrayList<DistanceSet> near = new ArrayList<>();
        for (DistanceSet ds : pointSet) {
            if (ds.lastDist() <= fmax) {
                near.add(ds);
            } else {
                farSet.add(ds);
            }
        }
        pointSet.clear();
        pointSet.addAll(near);
    }

    /**
     * Moves every point of {@code pointSet} within the cover radius of {@code maxScale} from
     * {@code newPoint} into {@code newPointSet}, pushing that distance onto its stack. The
     * remaining points stay in {@code pointSet}, in order.
     */
    private void distSplit(List<DistanceSet> pointSet, List<DistanceSet> newPointSet,
                           K newPoint, int maxScale) {
        double fmax = getCoverRadius(maxScale);
        ArrayList<DistanceSet> remaining = new ArrayList<>();
        for (DistanceSet ds : pointSet) {
            double newDist = distance.d(newPoint, ds.getKey());
            if (newDist <= fmax) {
                ds.dist.add(newDist);
                newPointSet.add(ds);
            } else {
                remaining.add(ds);
            }
        }
        pointSet.clear();
        pointSet.addAll(remaining);
    }

    /**
     * Returns up to {@code k} nearest neighbors of {@code q}.
     *
     * <p>A point whose key is the same object as {@code q} is excluded. The exception is a
     * single-point tree, where the root is returned regardless. If fewer than {@code k}
     * neighbors are found, a warning is logged and a shorter (possibly empty) array is returned.
     * Results are sorted by {@link Neighbor}'s natural order, truncated to {@code k}, and then
     * reversed. This is presumably farthest-first if Neighbor orders by distance; confirm.
     *
     * @param q the query key.
     * @param k the number of neighbors.
     * @return the neighbors.
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k} exceeds the dataset size.
     */
    @Override
    public Neighbor<K, V>[] search(K q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > data.size()) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }

        K e = root.getKey();
        double d = distance.d(e, q);

        Neighbor<K, V> n1 = new Neighbor<>(e, root.getValue(), root.idx, d);
        @SuppressWarnings("unchecked")
        Neighbor<K, V>[] a1 = (Neighbor<K, V>[]) Array.newInstance(n1.getClass(), 1);

        if (root.isLeaf()) {
            a1[0] = n1;
            return a1;
        }

        ArrayList<DistanceNode> currentCoverSet = new ArrayList<>();
        ArrayList<DistanceNode> zeroSet = new ArrayList<>();
        currentCoverSet.add(new DistanceNode(d, root));

        // The MAX_VALUE sentinel keeps peek() unbounded until k real distances have been added.
        DoubleHeapSelect heap = new DoubleHeapSelect(k);
        heap.add(Double.MAX_VALUE);

        // NOTE: emptyHeap is never reset. When the query is the root object, traversal runs
        // without pruning; results are still filtered by the final bound below.
        boolean emptyHeap = true;
        if (root.getKey() != q) {
            heap.add(d);
            emptyHeap = false;
        }

        while (!currentCoverSet.isEmpty()) {
            ArrayList<DistanceNode> nextCoverSet = new ArrayList<>();
            for (DistanceNode par : currentCoverSet) {
                Node parent = par.node;
                for (int c = 0; c < parent.children.size(); c++) {
                    Node child = parent.children.get(c);
                    // The first child holds the parent's own point, so its distance is known.
                    double dist = c == 0 ? par.dist : distance.d(child.getKey(), q);
                    double upperBound = emptyHeap ? Double.MAX_VALUE : heap.peek();
                    if (!(dist <= upperBound + child.maxDist)) {
                        continue;
                    }
                    if (c > 0 && dist < upperBound && child.getKey() != q) {
                        heap.add(dist);
                    }
                    if (!child.isLeaf()) {
                        nextCoverSet.add(new DistanceNode(dist, child));
                    } else if (dist <= upperBound) {
                        zeroSet.add(new DistanceNode(dist, child));
                    }
                }
            }
            currentCoverSet = nextCoverSet;
        }

        List<Neighbor<K, V>> list = new ArrayList<>();
        double bound = heap.peek();
        for (DistanceNode ds : zeroSet) {
            if (ds.dist <= bound && ds.node.getKey() != q) {
                list.add(new Neighbor<>(ds.node.getKey(), ds.node.getValue(), ds.node.idx, ds.dist));
            }
        }

        // Use a zero-length template. With a1 (length 1), an empty list would come back as a
        // one-element array holding null (List.toArray null-terminates into a larger array).
        Neighbor<K, V>[] neighbors = list.toArray(Arrays.copyOf(a1, 0));
        if (neighbors.length < k) {
            logger.warn("CoverTree.knn({}) returns only {} neighbors", k, neighbors.length);
        }

        Arrays.sort(neighbors);
        if (neighbors.length > k) {
            neighbors = Arrays.copyOf(neighbors, k);
        }
        MathEx.reverse(neighbors);
        return neighbors;
    }

    /**
     * Appends to {@code neighbors} every point within distance {@code radius} of {@code q},
     * excluding a point whose key is the same object as {@code q}.
     *
     * @param q the query key.
     * @param radius the search radius; must be positive.
     * @param neighbors the list receiving the results.
     * @throws IllegalArgumentException if {@code radius <= 0}.
     */
    @Override
    public void search(K q, double radius, List<Neighbor<K, V>> neighbors) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        double d = distance.d(root.getKey(), q);

        // A single-point tree has a childless root. Check it directly instead of
        // dereferencing its null child list.
        if (root.isLeaf()) {
            if (d <= radius && root.getKey() != q) {
                neighbors.add(new Neighbor<>(root.getKey(), root.getValue(), root.idx, d));
            }
            return;
        }

        ArrayList<DistanceNode> currentCoverSet = new ArrayList<>();
        ArrayList<DistanceNode> zeroSet = new ArrayList<>();
        currentCoverSet.add(new DistanceNode(d, root));

        while (!currentCoverSet.isEmpty()) {
            ArrayList<DistanceNode> nextCoverSet = new ArrayList<>();
            for (DistanceNode par : currentCoverSet) {
                Node parent = par.node;
                for (int c = 0; c < parent.children.size(); c++) {
                    Node child = parent.children.get(c);
                    // The first child holds the parent's own point, so its distance is known.
                    double dist = c == 0 ? par.dist : distance.d(child.getKey(), q);
                    if (!(dist <= radius + child.maxDist)) {
                        continue;
                    }
                    if (!child.isLeaf()) {
                        nextCoverSet.add(new DistanceNode(dist, child));
                    } else if (dist <= radius) {
                        zeroSet.add(new DistanceNode(dist, child));
                    }
                }
            }
            currentCoverSet = nextCoverSet;
        }

        for (DistanceNode ds : zeroSet) {
            if (ds.node.getKey() != q) {
                neighbors.add(new Neighbor<>(ds.node.getKey(), ds.node.getValue(), ds.node.idx, ds.dist));
            }
        }
    }

    /**
     * A point awaiting insertion, with a stack of its distances to the points it has been
     * compared against during construction. The top of the stack is the most recent one.
     */
    class DistanceSet {
        /** Index of the point in keys/data. */
        final int idx;
        /** Distance stack; the last element is the top. */
        final DoubleArrayList dist;

        DistanceSet(int idx) {
            this.idx = idx;
            this.dist = new DoubleArrayList();
        }

        K getKey() {
            return keys.get(idx);
        }

        V getValue() {
            return data.get(idx);
        }

        /** Returns the top of the distance stack. */
        double lastDist() {
            return dist.get(dist.size() - 1);
        }
    }

    /** A tree node referring to a data point by index. */
    class Node implements Serializable {
        /** Index of the point in keys/data. */
        final int idx;
        /** Pruning radius used by the searches. */
        double maxDist;
        /** Distance to the parent's point. */
        double parentDist;
        /** Child nodes, or null for a leaf. The first child holds this node's own point. */
        ArrayList<Node> children;
        /** Scale of this node relative to the tree's top scale. */
        int scale;

        Node(int idx) {
            this.idx = idx;
        }

        Node(int idx, double maxDist, double parentDist, ArrayList<Node> children, int scale) {
            this.idx = idx;
            this.maxDist = maxDist;
            this.parentDist = parentDist;
            this.children = children;
            this.scale = scale;
        }

        K getKey() {
            return keys.get(idx);
        }

        V getValue() {
            return data.get(idx);
        }

        boolean isLeaf() {
            return children == null;
        }
    }

    /** A node paired with its distance to the current query. */
    class DistanceNode implements Comparable<DistanceNode> {
        final double dist;
        final Node node;

        DistanceNode(double dist, Node node) {
            this.dist = dist;
            this.node = node;
        }

        @Override
        public int compareTo(DistanceNode o) {
            return Double.compare(dist, o.dist);
        }
    }
}

