/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import smile.math.MathEx;
import smile.neighbor.KNNSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RandomProjectionForest;

/**
 * A random projection tree for approximate k-nearest neighbor search.
 * <p>
 * The tree recursively partitions the samples with hyperplanes derived from
 * pairs of sample points. A node stops splitting once it holds at most
 * {@code leafSize} samples, or when no hyperplane leaves at least two samples
 * on each side; in the latter case a leaf may be larger than {@code leafSize}.
 * A query descends to exactly one leaf and the neighbors are found by an
 * exhaustive scan of that leaf, so results are approximate.
 * <p>
 * With {@code angular == true} split hyperplanes pass through the origin and
 * distances are computed with {@link MathEx#angular}; otherwise hyperplanes
 * are perpendicular bisectors in data space and distances are computed with
 * {@link MathEx#distance}.
 */
public class RandomProjectionTree
implements KNNSearch<double[], double[]> {
    /** Tolerance below which a norm or a hyperplane margin is treated as zero. */
    private static final float EPS = 1.0E-8f;
    /** Number of random hyperplanes tried before a node is left unsplit. */
    private static final int MAX_SPLIT_ATTEMPTS = 5;
    /** The indexed data; leaves store row indices into this array. */
    private final double[][] data;
    /** The root of the tree. */
    private final Node root;
    /** The target maximum number of samples per leaf. */
    private final int leafSize;
    /** True for angular distance / origin-centered splits, false for Euclidean. */
    private final boolean angular;

    private RandomProjectionTree(double[][] data, Node root, int leafSize, boolean angular) {
        this.data = data;
        this.root = root;
        this.leafSize = leafSize;
        this.angular = angular;
    }

    /**
     * Returns up to {@code k} approximate nearest neighbors of {@code q},
     * searched only within the leaf that {@code q} falls into.
     *
     * @param q the query point.
     * @param k the number of neighbors to return.
     * @return the neighbors sorted in ascending order; fewer than {@code k}
     *         when the leaf holds fewer samples.
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k > leafSize}.
     */
    @Override
    public Neighbor<double[], double[]>[] search(double[] q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > leafSize) {
            throw new IllegalArgumentException("k must be <= leafSize");
        }

        int[] samples = root.search(q).samples();
        // Generic arrays cannot be created directly; the unchecked cast is safe
        // because the array only ever holds Neighbor<double[], double[]> values.
        @SuppressWarnings("unchecked")
        Neighbor<double[], double[]>[] neighbors =
                (Neighbor<double[], double[]>[]) Array.newInstance(Neighbor.class, samples.length);
        for (int i = 0; i < samples.length; i++) {
            int index = samples[i];
            double[] x = data[index];
            double dist = angular ? MathEx.angular(q, x) : MathEx.distance(q, x);
            neighbors[i] = Neighbor.of(x, index, dist);
        }
        Arrays.sort(neighbors);
        return samples.length <= k ? neighbors : Arrays.copyOf(neighbors, k);
    }

    /** Returns the total number of nodes (internal and leaf) in the tree. */
    public int numNodes() {
        return root.numNodes();
    }

    /** Returns the number of leaf nodes in the tree. */
    public int numLeaves() {
        return root.numLeaves();
    }

    /**
     * Returns the sample indices of every leaf, in left-to-right order.
     * The returned arrays are the tree's own arrays, not copies.
     */
    public List<int[]> leafSamples() {
        List<int[]> samples = new ArrayList<>();
        root.recursiveLeafSamples(samples);
        return samples;
    }

    /**
     * Converts the linked tree into the array-based form used by
     * {@link RandomProjectionForest}. Nodes are numbered in pre-order.
     */
    RandomProjectionForest.FlatTree flatten() {
        int numNodes = root.numNodes();
        int numLeaves = root.numLeaves();
        double[][] hyperplanes = new double[numNodes][];
        double[] offsets = new double[numNodes];
        int[][] children = new int[numNodes][];
        int[][] indices = new int[numLeaves][];
        root.recursiveFlatten(hyperplanes, offsets, children, indices, 0, 0);
        return new RandomProjectionForest.FlatTree(hyperplanes, offsets, children, indices);
    }

    /**
     * Returns a new vector equal to {@code x} divided by its norm. Vectors
     * whose norm is below {@link #EPS} are copied unscaled.
     */
    private static double[] normalize(double[] x) {
        double norm = MathEx.norm(x);
        if (Math.abs(norm) < EPS) {
            norm = 1.0;
        }
        int n = x.length;
        double[] y = new double[n];
        for (int i = 0; i < n; i++) {
            y[i] = x[i] / norm;
        }
        return y;
    }

    /**
     * Tests which side of a hyperplane a point lies on.
     *
     * @return true when {@code offset + hyperplane . point} is negative. When
     *         the margin is within {@link #EPS} of zero the side is chosen at
     *         random with {@link MathEx#random()}.
     */
    static boolean isRightSide(double[] point, double[] hyperplane, double offset) {
        double margin = offset;
        for (int i = 0; i < point.length; i++) {
            margin += hyperplane[i] * point[i];
        }
        if (Math.abs(margin) < EPS) {
            return MathEx.random() < 0.5;
        }
        return margin < 0.0;
    }

    /**
     * Picks a random sample and the sample farthest from it.
     *
     * @param samples the candidate row indices; expected to contain at least two entries.
     * @return two freshly allocated vectors {@code {left, right}} that callers
     *         may mutate: unit-normalized copies when {@code angular}, plain
     *         copies of the data rows otherwise.
     */
    private static double[][] randomPoints(double[][] data, int[] samples, boolean angular) {
        int pos = MathEx.randomInt(samples.length);
        int i = samples[pos];
        double[] xi = data[i];

        int other = -1;
        double farthest = Double.NEGATIVE_INFINITY;
        for (int j : samples) {
            if (j == i) {
                continue;
            }
            double dist = angular ? MathEx.angular(xi, data[j]) : MathEx.distance(xi, data[j]);
            if (dist > farthest) {
                other = j;
                farthest = dist;
            }
        }

        if (other < 0) {
            // No comparable distance was found (e.g. every distance was NaN);
            // fall back to a neighboring sample instead of indexing data[-1].
            other = samples[(pos + 1) % samples.length];
        }

        if (angular) {
            return new double[][]{normalize(data[i]), normalize(data[other])};
        }
        // Euclidean splits need the points in data space; copy so callers cannot corrupt data.
        return new double[][]{data[i].clone(), data[other].clone()};
    }

    /**
     * Searches for an origin-centered hyperplane that splits the samples.
     * The normal is the normalized difference of two unit-normalized sample points.
     *
     * @return the split, or null if none of the attempts left at least two
     *         samples on each side.
     */
    private static Split angularSplit(double[][] data, int[] samples) {
        for (int iter = 0; iter < MAX_SPLIT_ATTEMPTS; iter++) {
            double[][] points = randomPoints(data, samples, true);
            double[] left = points[0];
            double[] right = points[1];
            // randomPoints returns fresh arrays, so computing the difference in place is safe.
            for (int d = 0; d < left.length; d++) {
                left[d] -= right[d];
            }
            Split split = split(data, samples, normalize(left), 0.0);
            if (split != null) {
                return split;
            }
        }
        return null;
    }

    /**
     * Searches for a hyperplane that splits the samples in Euclidean space.
     * The hyperplane is the perpendicular bisector of two sample points:
     * normal {@code n = left - right} and offset {@code -n . (left + right) / 2},
     * so the midpoint of the two points has zero margin.
     *
     * @return the split, or null if none of the attempts left at least two
     *         samples on each side.
     */
    private static Split euclideanSplit(double[][] data, int[] samples) {
        for (int iter = 0; iter < MAX_SPLIT_ATTEMPTS; iter++) {
            double[][] points = randomPoints(data, samples, false);
            double[] left = points[0];
            double[] right = points[1];
            int dim = left.length;

            double offset = 0.0;
            double[] hyperplane = new double[dim];
            for (int d = 0; d < dim; d++) {
                double delta = left[d] - right[d];
                hyperplane[d] = delta;
                offset -= delta * (left[d] + right[d]);
            }
            offset /= 2.0;

            Split split = split(data, samples, hyperplane, offset);
            if (split != null) {
                return split;
            }
        }
        return null;
    }

    /**
     * Partitions the samples by the given hyperplane using {@link #isRightSide}.
     *
     * @return the partition, or null if either side would hold fewer than two samples.
     */
    private static Split split(double[][] data, int[] samples, double[] hyperplane, double offset) {
        boolean[] rightSide = new boolean[samples.length];
        int numRight = 0;
        for (int i = 0; i < samples.length; i++) {
            rightSide[i] = isRightSide(data[samples[i]], hyperplane, offset);
            if (rightSide[i]) {
                numRight++;
            }
        }
        int numLeft = samples.length - numRight;
        if (numLeft < 2 || numRight < 2) {
            return null;
        }

        int[] leftSamples = new int[numLeft];
        int[] rightSamples = new int[numRight];
        int l = 0;
        int r = 0;
        for (int i = 0; i < samples.length; i++) {
            if (rightSide[i]) {
                rightSamples[r++] = samples[i];
            } else {
                leftSamples[l++] = samples[i];
            }
        }
        return new Split(leftSamples, rightSamples, hyperplane, offset);
    }

    /**
     * Recursively builds a (sub)tree over the given samples.
     *
     * @param angular selects {@link #angularSplit} or {@link #euclideanSplit}.
     */
    private static Node makeTree(double[][] data, int[] samples, int leafSize, boolean angular) {
        if (samples.length <= leafSize) {
            return new Node(samples);
        }
        Split split = angular ? angularSplit(data, samples) : euclideanSplit(data, samples);
        if (split == null) {
            // No hyperplane produced a usable partition; keep an oversized leaf.
            return new Node(samples);
        }
        Node leftNode = makeTree(data, split.leftSamples(), leafSize, angular);
        Node rightNode = makeTree(data, split.rightSamples(), leafSize, angular);
        return new Node(split.hyperplane(), split.offset(), leftNode, rightNode);
    }

    /**
     * Builds a random projection tree.
     *
     * @param data the data points; the tree keeps a reference to this array.
     * @param leafSize the target maximum number of samples per leaf.
     * @param angular true to use angular distance, false for Euclidean distance.
     * @return the tree.
     * @throws IllegalArgumentException if {@code leafSize < 3}.
     */
    public static RandomProjectionTree of(double[][] data, int leafSize, boolean angular) {
        if (leafSize < 3) {
            throw new IllegalArgumentException("leafSize must be at least 3");
        }
        int[] samples = IntStream.range(0, data.length).toArray();
        Node root = makeTree(data, samples, leafSize, angular);
        return new RandomProjectionTree(data, root, leafSize, angular);
    }

    /**
     * A tree node. A leaf has {@code samples} set and no children; an internal
     * node has a hyperplane, offset and both children, and null {@code samples}.
     */
    record Node(int[] samples, double[] hyperplane, double offset, Node leftChild, Node rightChild) {
        /** Creates a leaf holding the given sample indices. */
        Node(int[] samples) {
            this(samples, null, 0.0, null, null);
        }

        /** Creates an internal node split by the given hyperplane. */
        Node(double[] hyperplane, double offset, Node leftChild, Node rightChild) {
            this(null, hyperplane, offset, leftChild, rightChild);
        }

        boolean isLeaf() {
            return leftChild == null && rightChild == null;
        }

        /** Returns the number of nodes in this subtree, including this one. */
        int numNodes() {
            return 1 + (leftChild != null ? leftChild.numNodes() : 0)
                     + (rightChild != null ? rightChild.numNodes() : 0);
        }

        /** Returns the number of leaves in this subtree. */
        int numLeaves() {
            if (isLeaf()) {
                return 1;
            }
            return (leftChild != null ? leftChild.numLeaves() : 0)
                 + (rightChild != null ? rightChild.numLeaves() : 0);
        }

        /** Descends to the leaf that {@code point} falls into. */
        Node search(double[] point) {
            if (isLeaf()) {
                return this;
            }
            boolean rightSide = RandomProjectionTree.isRightSide(point, hyperplane, offset);
            return rightSide ? rightChild.search(point) : leftChild.search(point);
        }

        /**
         * Writes this subtree into the flat arrays in pre-order, starting at
         * node index {@code nodeNum} and leaf index {@code leafNum}.
         * Internal nodes store {@code {leftChildIndex, rightChildIndex}} in
         * {@code children}; leaves store {@code {-leafNum, -1}} and their
         * samples in {@code indices[leafNum]}.
         * NOTE(review): leaf 0 is encoded as {@code {0, -1}}, so consumers
         * presumably detect internal nodes with {@code children[n][0] > 0} —
         * confirm in RandomProjectionForest before changing this encoding.
         *
         * @return {@code {last node index used, next free leaf index}}.
         */
        int[] recursiveFlatten(double[][] hyperplanes, double[] offsets, int[][] children, int[][] indices, int nodeNum, int leafNum) {
            if (isLeaf()) {
                children[nodeNum] = new int[]{-leafNum, -1};
                indices[leafNum] = samples;
                return new int[]{nodeNum, leafNum + 1};
            }
            hyperplanes[nodeNum] = hyperplane;
            offsets[nodeNum] = offset;
            int[] flattenInfo = leftChild.recursiveFlatten(hyperplanes, offsets, children, indices, nodeNum + 1, leafNum);
            // The right subtree starts right after the last node used by the left subtree.
            children[nodeNum] = new int[]{nodeNum + 1, flattenInfo[0] + 1};
            return rightChild.recursiveFlatten(hyperplanes, offsets, children, indices, flattenInfo[0] + 1, flattenInfo[1]);
        }

        /** Appends the sample arrays of this subtree's leaves, left to right. */
        void recursiveLeafSamples(List<int[]> sampleList) {
            if (isLeaf()) {
                sampleList.add(samples);
            } else {
                leftChild.recursiveLeafSamples(sampleList);
                rightChild.recursiveLeafSamples(sampleList);
            }
        }
    }

    /** The result of partitioning a node's samples by a hyperplane. */
    record Split(int[] leftSamples, int[] rightSamples, double[] hyperplane, double offset) {
    }
}

