/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor.lsh;

import smile.neighbor.lsh.MultiProbeHash;
import smile.neighbor.lsh.MultiProbeSample;
import smile.stat.distribution.GaussianDistribution;

/**
 * Gaussian-kernel (Parzen window) estimator of the mean and standard deviation of the
 * hash values of a query's neighbors, conditioned on the query's own hash value.
 *
 * <p>For every training sample with more than one neighbor, the constructor records the
 * sample query's hash value and the mean/variance of its neighbors' hash values, for each of
 * the {@code hash.k} hash functions. {@link #estimate(int, double)} then averages those
 * per-sample statistics, weighting each sample by a Gaussian kernel of the distance between
 * its query hash value and the given hash value.
 *
 * <p>{@code estimate} writes its result into instance fields read by {@link #mean()} and
 * {@link #sd()}. An instance therefore should not be shared between threads that call
 * {@code estimate} concurrently.
 */
public class HashValueParzenModel {
    /** Parzen kernel: zero-mean Gaussian with standard deviation {@code sigma}. */
    private final GaussianDistribution gaussian;
    /** Per-sample statistics; only samples with more than one neighbor are kept. */
    private final NeighborHashValueModel[] neighborHashValueModels;
    /** Result of the last {@link #estimate(int, double)} call. */
    private double mean;
    /** Result of the last {@link #estimate(int, double)} call. */
    private double sd;

    /**
     * Builds the model from training samples.
     *
     * @param hash the hash whose {@code k} hash functions are modeled.
     * @param samples training samples; samples with at most one neighbor are ignored.
     * @param sigma the standard deviation of the Gaussian Parzen kernel.
     */
    public HashValueParzenModel(MultiProbeHash hash, MultiProbeSample[] samples, double sigma) {
        int k = hash.k;
        this.gaussian = new GaussianDistribution(0.0, sigma);

        // First pass: size the array to the number of usable samples.
        int n = 0;
        for (MultiProbeSample sample : samples) {
            if (sample.neighbors().size() > 1) {
                ++n;
            }
        }

        this.neighborHashValueModels = new NeighborHashValueModel[n];
        int l = 0;
        for (MultiProbeSample sample : samples) {
            if (sample.neighbors().size() > 1) {
                this.neighborHashValueModels[l++] = fit(hash, k, sample);
            }
        }
    }

    /**
     * Computes, for each hash function, the query's hash value and the mean and (population)
     * variance of the neighbors' hash values.
     *
     * <p>The variance uses the two-pass formula (mean first, then squared deviations) rather
     * than {@code E[h^2] - E[h]^2}. The latter can cancel catastrophically and come out
     * slightly negative, which later turns {@code Math.sqrt} into NaN.
     */
    private static NeighborHashValueModel fit(MultiProbeHash hash, int k, MultiProbeSample sample) {
        int size = sample.neighbors().size();
        double[] H = new double[k];
        double[] mu = new double[k];
        double[] var = new double[k];
        double[] values = new double[size];
        for (int i = 0; i < k; ++i) {
            H[i] = hash.hash(sample.query(), i);

            int j = 0;
            double sum = 0.0;
            for (double[] v : sample.neighbors()) {
                double h = hash.hash(v, i);
                values[j++] = h;
                sum += h;
            }
            double m = sum / (double) size;

            double sumsq = 0.0;
            for (int t = 0; t < j; ++t) {
                double d = values[t] - m;
                sumsq += d * d;
            }

            mu[i] = m;
            var[i] = sumsq / (double) size;
        }
        return new NeighborHashValueModel(H, mu, var);
    }

    /**
     * Estimates the mean and standard deviation of neighbors' hash values for hash function
     * {@code m}, given a query hash value {@code h}. Read the results via {@link #mean()} and
     * {@link #sd()}.
     *
     * <p>If the total kernel weight is negligible ({@code <= 1e-7}), the mean falls back to
     * {@code h}. If the resulting standard deviation is below {@code 1e-5}, it falls back to
     * the root of the unweighted average variance over all training samples. That fallback is
     * 0 when there are no usable training samples.
     *
     * @param m the index of the hash function, in {@code [0, hash.k)}.
     * @param h the query's hash value for hash function {@code m}.
     */
    public void estimate(int m, double h) {
        double mm = 0.0;
        double vv = 0.0;
        double ss = 0.0;
        for (NeighborHashValueModel model : this.neighborHashValueModels) {
            double w = this.gaussian.p(model.H[m] - h);
            mm += w * model.mean[m];
            vv += w * model.var[m];
            ss += w;
        }
        if (ss > 1.0E-7) {
            this.mean = mm / ss;
            this.sd = Math.sqrt(vv / ss);
        } else {
            this.mean = h;
            this.sd = 0.0;
        }
        if (this.sd < 1.0E-5) {
            this.sd = pooledSd(m);
        }
    }

    /**
     * Returns the square root of the unweighted mean of the per-sample variances for hash
     * function {@code m}, or 0 if there are no samples (avoids {@code 0.0 / 0} = NaN).
     */
    private double pooledSd(int m) {
        int count = this.neighborHashValueModels.length;
        if (count == 0) {
            return 0.0;
        }
        double total = 0.0;
        for (NeighborHashValueModel model : this.neighborHashValueModels) {
            total += model.var[m];
        }
        return Math.sqrt(total / (double) count);
    }

    /** Returns the mean computed by the last {@link #estimate(int, double)} call. */
    public double mean() {
        return this.mean;
    }

    /** Returns the standard deviation computed by the last {@link #estimate(int, double)} call. */
    public double sd() {
        return this.sd;
    }

    /**
     * Statistics of one training sample: the query's hash values {@code H}, and the mean and
     * variance of its neighbors' hash values, each indexed by hash function.
     */
    private record NeighborHashValueModel(double[] H, double[] mean, double[] var) {
    }
}

