/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.Arrays;
import smile.math.MathEx;
import smile.validation.ClusterMeasure;
import smile.validation.ContingencyTable;
import smile.validation.MutualInformation;

/**
 * Adjusted Mutual Information (AMI) between two labelings of the same data.
 * <p>
 * AMI corrects mutual information for agreement expected by chance:
 * <pre>
 *     AMI = (I - E[I]) / (norm(H1, H2) - E[I])
 * </pre>
 * where {@code I} is the mutual information of the two labelings, {@code H1}
 * and {@code H2} are their entropies, {@code E[I]} is the expected mutual
 * information of two random labelings with the same cluster sizes, and
 * {@code norm} is chosen by {@link Method}.
 * <p>
 * Reference: N. X. Vinh, J. Epps and J. Bailey, "Information Theoretic
 * Measures for Clusterings Comparison", JMLR 11, 2010.
 */
public class AdjustedMutualInformation
implements ClusterMeasure {
    /** Normalization applied to the entropy term of the denominator. */
    private final Method method;

    /**
     * Constructor.
     * @param method the normalization used in the AMI denominator.
     */
    public AdjustedMutualInformation(Method method) {
        this.method = method;
    }

    /**
     * Computes AMI with the normalization chosen at construction time.
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the adjusted mutual information.
     */
    @Override
    public double measure(int[] y1, int[] y2) {
        switch (method) {
            case MAX:
                return max(y1, y2);
            case MIN:
                return min(y1, y2);
            case SUM:
                return sum(y1, y2);
            case SQRT:
                return sqrt(y1, y2);
            default:
                throw new IllegalStateException("Unknown normalization method: " + method);
        }
    }

    /**
     * Computes AMI normalized by {@code max(H1, H2)}.
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the adjusted mutual information.
     */
    public static double max(int[] y1, int[] y2) {
        return adjusted(y1, y2, Method.MAX);
    }

    /**
     * Computes AMI normalized by the arithmetic mean {@code (H1 + H2) / 2}.
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the adjusted mutual information.
     */
    public static double sum(int[] y1, int[] y2) {
        return adjusted(y1, y2, Method.SUM);
    }

    /**
     * Computes AMI normalized by the geometric mean {@code sqrt(H1 * H2)}.
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the adjusted mutual information.
     */
    public static double sqrt(int[] y1, int[] y2) {
        return adjusted(y1, y2, Method.SQRT);
    }

    /**
     * Computes AMI normalized by {@code min(H1, H2)}.
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the adjusted mutual information.
     */
    public static double min(int[] y1, int[] y2) {
        return adjusted(y1, y2, Method.MIN);
    }

    /**
     * Shared AMI pipeline: contingency table, entropies, mutual information,
     * expected mutual information, then the normalized ratio.
     */
    private static double adjusted(int[] y1, int[] y2, Method method) {
        ContingencyTable contingency = new ContingencyTable(y1, y2);
        int[] a = contingency.a;
        int[] b = contingency.b;

        // Degenerate labelings have zero entropy, which makes the ratio below
        // 0/0. Resolve these cases explicitly (same convention as scikit-learn).
        long k1 = Arrays.stream(a).filter(c -> c > 0).count();
        long k2 = Arrays.stream(b).filter(c -> c > 0).count();
        if (k1 <= 1 && k2 <= 1) {
            // Neither labeling splits the data, so they agree perfectly.
            return 1.0;
        }
        if (k1 <= 1 || k2 <= 1) {
            // One labeling carries no information, so I = E[I] = 0.
            return 0.0;
        }

        double n = contingency.n;
        double[] p1 = Arrays.stream(a).mapToDouble(c -> c / n).toArray();
        double[] p2 = Arrays.stream(b).mapToDouble(c -> c / n).toArray();
        double h1 = MathEx.entropy(p1);
        double h2 = MathEx.entropy(p2);
        double mi = MutualInformation.of(contingency.n, p1, p2, contingency.table);
        double emi = expectedMutualInformation(contingency.n, a, b);
        return (mi - emi) / (normalizer(h1, h2, method) - emi);
    }

    /** Returns the entropy normalizer {@code norm(h1, h2)} for the given method. */
    private static double normalizer(double h1, double h2, Method method) {
        switch (method) {
            case MAX:
                return Math.max(h1, h2);
            case MIN:
                return Math.min(h1, h2);
            case SUM:
                return 0.5 * (h1 + h2);
            case SQRT:
                return Math.sqrt(h1 * h2);
            default:
                throw new IllegalStateException("Unknown normalization method: " + method);
        }
    }

    /**
     * Computes the expected mutual information of two labelings with the
     * given cluster sizes, where each cell count {@code nij} follows the
     * hypergeometric distribution implied by the marginals.
     * @param n the total number of samples.
     * @param a the cluster sizes of the first labeling.
     * @param b the cluster sizes of the second labeling.
     * @return the expected mutual information (natural log).
     */
    private static double expectedMutualInformation(int n, int[] a, int[] b) {
        double total = n;
        double logNFactorial = MathEx.lfactorial(n);
        double emi = 0.0;
        for (int ai : a) {
            double logA = MathEx.lfactorial(ai) + MathEx.lfactorial(n - ai);
            for (int bj : b) {
                // Part of log P(nij) that does not depend on nij: hoisted out of the inner loop.
                double logNumerator = logA + MathEx.lfactorial(bj) + MathEx.lfactorial(n - bj) - logNFactorial;
                // Widen before multiplying: ai * bj in int arithmetic overflows past 2^31 - 1.
                double aibj = (double) ai * bj;
                // nij = 0 contributes 0 (0 * log 0), so start at 1. Feasible counts also
                // satisfy ai + bj - n <= nij <= min(ai, bj).
                int begin = Math.max(1, ai + bj - n);
                int end = Math.min(ai, bj);
                for (int nij = begin; nij <= end; nij++) {
                    double logP = logNumerator - (MathEx.lfactorial(nij) + MathEx.lfactorial(ai - nij)
                            + MathEx.lfactorial(bj - nij) + MathEx.lfactorial(n - ai - bj + nij));
                    emi += nij / total * Math.log(nij * total / aibj) * Math.exp(logP);
                }
            }
        }
        return emi;
    }

    @Override
    public String toString() {
        return "Adjusted Mutual Information";
    }

    /** Normalization of the entropy term in the AMI denominator. */
    public enum Method {
        /** Maximum of the two entropies, {@code max(H1, H2)}. */
        MAX,
        /** Minimum of the two entropies, {@code min(H1, H2)}. */
        MIN,
        /** Arithmetic mean of the two entropies, {@code (H1 + H2) / 2}. */
        SUM,
        /** Geometric mean of the two entropies, {@code sqrt(H1 * H2)}. */
        SQRT
    }
}

