/*
 * Decompiled with CFR 0.152.
 */
package smile.validation.metric;

import smile.math.MathEx;
import smile.validation.metric.Averaging;
import smile.validation.metric.ClassificationMetric;

/**
 * Precision metric for classification: the ratio {@code TP / (TP + FP)} of
 * true positives to all predicted positives.
 *
 * <p>With no averaging strategy the metric is binary and treats label
 * {@code 1} as the positive class. With an {@link Averaging} strategy the
 * counts are accumulated as follows:
 * <ul>
 *   <li>{@code Micro} - a single global pair of counters; every sample whose
 *       prediction equals the truth is a true positive, every other sample
 *       is a false positive.</li>
 *   <li>{@code Macro} - per-class precision, averaged with equal weight.</li>
 *   <li>{@code Weighted} - per-class precision, weighted by the number of
 *       samples of each class in {@code truth}.</li>
 * </ul>
 */
public class Precision
implements ClassificationMetric {
    private static final long serialVersionUID = 2L;

    /** Shared instance without an averaging strategy (binary precision). */
    public static final Precision instance = new Precision();

    /** Averaging strategy; {@code null} selects binary precision. */
    private final Averaging strategy;

    /** Creates a binary precision metric (no averaging strategy). */
    public Precision() {
        this(null);
    }

    /**
     * Creates a precision metric with the given averaging strategy.
     *
     * @param strategy the averaging strategy, or {@code null} for binary
     *                 classification.
     */
    public Precision(Averaging strategy) {
        this.strategy = strategy;
    }

    /**
     * Computes the precision using this instance's averaging strategy.
     *
     * @param truth      the ground-truth labels.
     * @param prediction the predicted labels.
     * @return the precision.
     */
    @Override
    public double score(int[] truth, int[] prediction) {
        return Precision.of(truth, prediction, this.strategy);
    }

    /**
     * Returns {@code "Precision"} when no strategy is set, otherwise the
     * strategy followed by {@code "-Precision"}.
     */
    @Override
    public String toString() {
        return this.strategy == null ? "Precision" : this.strategy + "-Precision";
    }

    /**
     * Computes the precision of a binary classification.
     *
     * @param truth      the ground-truth labels; every value must be 0 or 1.
     * @param prediction the predicted labels; every value must be 0 or 1.
     * @return the precision.
     * @throws IllegalArgumentException if any label is neither 0 nor 1, or
     *         if the two arrays differ in length.
     */
    public static double of(int[] truth, int[] prediction) {
        for (int t : truth) {
            if (t != 0 && t != 1) {
                throw new IllegalArgumentException("Precision can only be applied to binary classification: " + t);
            }
        }
        for (int p : prediction) {
            if (p != 0 && p != 1) {
                throw new IllegalArgumentException("Precision can only be applied to binary classification: " + p);
            }
        }
        return Precision.of(truth, prediction, null);
    }

    /**
     * Computes the precision, optionally averaged over multiple classes.
     *
     * <p>The number of classes is taken as the largest label found in either
     * array plus one. If no prediction is counted for a class (its true
     * positive and false positive counts are both zero), the per-class ratio
     * is {@code 0.0 / 0.0}, i.e. {@code NaN}, and propagates to the result.
     *
     * @param truth      the ground-truth labels.
     * @param prediction the predicted labels.
     * @param strategy   the averaging strategy; may be {@code null} only when
     *                   there are at most two classes.
     * @return the precision.
     * @throws IllegalArgumentException if the arrays differ in length, or if
     *         more than two classes are present and {@code strategy} is
     *         {@code null}.
     */
    public static double of(int[] truth, int[] prediction, Averaging strategy) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException(String.format("The vector sizes don't match: %d != %d.", truth.length, prediction.length));
        }

        int numClasses = Math.max(MathEx.max(truth), MathEx.max(prediction)) + 1;
        if (numClasses > 2 && strategy == null) {
            throw new IllegalArgumentException("Averaging strategy is null for multi-class");
        }

        // Per-class counters are only needed for Macro/Weighted averaging;
        // binary and Micro use a single global pair of counters.
        boolean perClass = strategy == Averaging.Macro || strategy == Averaging.Weighted;
        int length = perClass ? numClasses : 1;
        int[] tp = new int[length];
        int[] fp = new int[length];

        // Number of samples of each class in the ground truth (Weighted weights).
        int[] size = new int[numClasses];
        int n = truth.length;
        for (int target : truth) {
            size[target]++;
        }

        if (strategy == null) {
            // Binary: only samples predicted as the positive class (1) count.
            for (int i = 0; i < n; i++) {
                if (prediction[i] == 1) {
                    if (truth[i] == 1) {
                        tp[0]++;
                    } else {
                        fp[0]++;
                    }
                }
            }
        } else if (strategy == Averaging.Micro) {
            // Micro: every sample is either a global hit or a global miss.
            for (int i = 0; i < n; i++) {
                if (truth[i] == prediction[i]) {
                    tp[0]++;
                } else {
                    fp[0]++;
                }
            }
        } else {
            // Macro/Weighted: a hit is credited to the true class, a miss is
            // charged to the class that was (wrongly) predicted.
            for (int i = 0; i < n; i++) {
                if (truth[i] == prediction[i]) {
                    tp[truth[i]]++;
                } else {
                    fp[prediction[i]]++;
                }
            }
        }

        double[] precision = new double[tp.length];
        for (int i = 0; i < tp.length; i++) {
            precision[i] = (double) tp[i] / (double) (tp[i] + fp[i]);
        }

        if (strategy == Averaging.Macro) {
            return MathEx.mean(precision);
        }
        if (strategy == Averaging.Weighted) {
            double weighted = 0.0;
            for (int i = 0; i < numClasses; i++) {
                weighted += precision[i] * (double) size[i];
            }
            return weighted / (double) n;
        }
        return precision[0];
    }
}

