/*
 * Decompiled with CFR 0.152.
 */
package smile.vq;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;
import smile.graph.AdjacencyMatrix;
import smile.graph.Graph;
import smile.math.MathEx;
import smile.sort.QuickSort;
import smile.util.function.TimeFunction;
import smile.vq.VectorQuantizer;

/**
 * Neural Gas vector quantizer.
 *
 * <p>Each call to {@link #update(double[])} works in three steps:
 * <ul>
 *   <li>It ranks the neurons by their distance to the observation.</li>
 *   <li>It moves each neuron toward the observation by
 *       {@code alpha(t) * exp(-rank / theta(t))}. Neurons whose step does not
 *       exceed a small threshold are skipped.</li>
 *   <li>It connects the two closest neurons in a graph. The edge weight records
 *       the step {@code t} at which the connection was last refreshed.</li>
 * </ul>
 * {@link #network()} removes connections older than {@code lifetime(t)} steps.
 *
 * <p>{@link #update(double[])}, {@link #quantize(double[])} and
 * {@link #neurons()} all mutate internal state (a shared distance buffer and/or
 * the neuron order). An instance must not be used from several threads without
 * external synchronization.
 */
public class NeuralGas implements VectorQuantizer {
    private static final long serialVersionUID = 2L;
    /** The neurons. {@link #update(double[])} reorders them by distance to the last observation. */
    private final Neuron[] neurons;
    /** Connections between neurons; an edge's weight is the step at which it was last refreshed. */
    private final AdjacencyMatrix graph;
    /** Learning rate as a function of the step count. */
    private final TimeFunction alpha;
    /** Neighborhood width as a function of the step count. */
    private final TimeFunction theta;
    /** Maximum age of a connection as a function of the step count. */
    private final TimeFunction lifetime;
    /** Scratch buffer: distance from the current observation to each neuron, aligned with {@link #neurons}. */
    private final double[] dist;
    /** A neuron is only moved when its step size is greater than this threshold. */
    private final double eps = 1.0E-7;
    /** Number of observations processed so far. */
    private int t = 0;

    /**
     * Creates a Neural Gas network.
     *
     * @param neurons  the initial neuron weight vectors; each row is copied
     * @param alpha    the learning rate as a function of the step count
     * @param theta    the neighborhood width as a function of the step count
     * @param lifetime the maximum connection age as a function of the step count
     */
    public NeuralGas(double[][] neurons, TimeFunction alpha, TimeFunction theta, TimeFunction lifetime) {
        this.neurons = IntStream.range(0, neurons.length)
                .mapToObj(i -> new Neuron(i, neurons[i].clone()))
                .toArray(Neuron[]::new);
        this.alpha = alpha;
        this.theta = theta;
        this.lifetime = lifetime;
        this.graph = new AdjacencyMatrix(neurons.length);
        this.dist = new double[neurons.length];
    }

    /**
     * Returns the neuron weight vectors in construction order.
     * The rows are the live internal arrays, not copies.
     *
     * @return the neuron weight vectors
     */
    public double[][] neurons() {
        // update() reorders neurons by distance, so restore construction order first.
        Arrays.sort(neurons, Comparator.comparingInt(Neuron::i));
        return Arrays.stream(neurons).map(Neuron::w).toArray(double[][]::new);
    }

    /**
     * Returns the neuron connection graph after removing every edge whose age
     * ({@code t - weight}) exceeds {@code lifetime(t)}.
     *
     * @return the (internal, live) connection graph
     */
    public Graph network() {
        double maxAge = lifetime.apply(t);
        for (int i = 0; i < neurons.length; ++i) {
            for (Graph.Edge e : graph.getEdges(i)) {
                if (t - e.weight() > maxAge) {
                    graph.setWeight(e.u(), e.v(), 0.0);
                }
            }
        }
        return graph;
    }

    /**
     * Updates the network with a new observation.
     *
     * @param x the new observation
     */
    @Override
    public void update(double[] x) {
        int k = neurons.length;
        int d = x.length;
        computeDistances(x);
        // Sort both arrays together: neurons[r] becomes the r-th closest neuron and dist[r] its distance.
        QuickSort.sort(dist, neurons);

        double a = alpha.apply(t);
        double width = theta.apply(t);
        for (int rank = 0; rank < k; ++rank) {
            double delta = a * Math.exp(-rank / width);
            if (!(delta > eps)) {
                continue;
            }
            double[] w = neurons[rank].w;
            for (int j = 0; j < d; ++j) {
                w[j] += delta * (x[j] - w[j]);
            }
        }

        // Connecting the two winners needs at least two neurons; otherwise neurons[1] does not exist.
        // NOTE(review): at t == 0 this stores weight 0.0, which an adjacency matrix may treat as
        // "no edge" - confirm against AdjacencyMatrix semantics.
        if (k > 1) {
            graph.setWeight(neurons[0].i, neurons[1].i, (double) t);
        }
        ++t;
    }

    /**
     * Returns the weight vector of the neuron closest to {@code x}.
     * The returned array is the live internal weight vector, not a copy.
     *
     * @param x the observation to quantize
     * @return the nearest neuron's weight vector
     */
    @Override
    public double[] quantize(double[] x) {
        computeDistances(x);
        return neurons[MathEx.whichMin(dist)].w;
    }

    /** Fills {@link #dist} with the distance from {@code x} to each neuron, in the current neuron order. */
    private void computeDistances(double[] x) {
        IntStream.range(0, neurons.length).parallel().forEach(i -> dist[i] = MathEx.distance(neurons[i].w, x));
    }

    /**
     * A neuron: its index in construction order and its weight vector.
     */
    private record Neuron(int i, double[] w) implements Serializable {
    }
}

