/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.selection;

import java.util.Arrays;
import java.util.HashMap;
import java.util.stream.IntStream;
import smile.classification.ClassLabels;
import smile.data.DataFrame;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.transform.ColumnTransform;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.sort.QuickSort;
import smile.util.function.Function;

public record InformationValue(String feature, double iv, double[] woe, double[] breaks) implements Comparable<InformationValue>
{
    @Override
    public int compareTo(InformationValue other) {
        return Double.compare(this.iv, other.iv);
    }

    @Override
    public String toString() {
        return String.format("InformationValue(%s, %.4f)", this.feature, this.iv);
    }

    private static String predictivePower(double iv) {
        if (Double.isNaN(iv)) {
            return "";
        }
        if (iv < 0.02) {
            return "Not useful";
        }
        if (iv <= 0.1) {
            return "Weak";
        }
        if (iv <= 0.3) {
            return "Medium";
        }
        if (iv <= 0.5) {
            return "Strong";
        }
        return "Suspicious";
    }

    public static String toString(InformationValue[] ivs) {
        StringBuilder builder = new StringBuilder();
        builder.append("Feature                   Information Value    Predictive Power\n");
        for (InformationValue iv : ivs) {
            builder.append(String.format("%-25s %17.4f    %16s%n", iv.feature, iv.iv, InformationValue.predictivePower(iv.iv)));
        }
        return builder.toString();
    }

    public static ColumnTransform toTransform(InformationValue[] values) {
        HashMap<String, 1> transforms = new HashMap<String, 1>();
        for (final InformationValue iv : values) {
            Function transform = new Function(){

                public double f(double x) {
                    if (iv.breaks == null) {
                        int i = (int)x;
                        if (i < 0 || i >= iv.woe.length) {
                            throw new IllegalArgumentException("Invalid nominal value: " + i);
                        }
                        return iv.woe[i];
                    }
                    int i = Arrays.binarySearch(iv.breaks, x);
                    if (i < 0) {
                        i = -i - 1;
                    }
                    return iv.woe[i];
                }

                public String toString() {
                    return iv.feature + "_WoE";
                }
            };
            transforms.put(iv.feature, transform);
        }
        return new ColumnTransform("WoE", transforms);
    }

    public static InformationValue[] fit(DataFrame data, String clazz) {
        return InformationValue.fit(data, clazz, 10);
    }

    public static InformationValue[] fit(DataFrame data, String clazz, int nbins) {
        if (nbins < 2) {
            throw new IllegalArgumentException("Invalid number of bins: " + nbins);
        }
        ValueVector y = data.column(clazz);
        ClassLabels codec = ClassLabels.fit(y);
        if (codec.k != 2) {
            throw new UnsupportedOperationException("Information Value is applicable only to binary classification");
        }
        int n = data.size();
        StructType schema = data.schema();
        return (InformationValue[])IntStream.range(0, schema.length()).mapToObj(i -> {
            int j;
            int[] nonevents;
            int[] events;
            double[] breaks = null;
            StructField field = schema.field(i);
            Measure patt0$temp = field.measure();
            if (patt0$temp instanceof NominalScale) {
                NominalScale scale = (NominalScale)patt0$temp;
                int k = scale.size();
                events = new int[k];
                nonevents = new int[k];
                int[] xi = data.column(i).toIntArray();
                for (int j2 = 0; j2 < n; ++j2) {
                    if (codec.y[j2] == 1) {
                        int n2 = xi[j2];
                        events[n2] = events[n2] + 1;
                        continue;
                    }
                    int n3 = xi[j2];
                    nonevents[n3] = nonevents[n3] + 1;
                }
            } else if (field.isNumeric()) {
                events = new int[nbins];
                nonevents = new int[nbins];
                breaks = new double[nbins - 1];
                double[] xi = data.column(i).toDoubleArray();
                int[] order = QuickSort.sort((double[])xi);
                int begin = 0;
                for (j = 0; j < nbins; ++j) {
                    int end = (j + 1) * n / nbins;
                    if (j < nbins - 1) {
                        breaks[j] = xi[end];
                    }
                    for (int k = begin; k < end; ++k) {
                        if (codec.y[order[k]] == 1) {
                            int n4 = j;
                            events[n4] = events[n4] + 1;
                            continue;
                        }
                        int n5 = j;
                        nonevents[n5] = nonevents[n5] + 1;
                    }
                    begin = end;
                }
            } else {
                return null;
            }
            int k = events.length;
            double[] woe = new double[k];
            double iv = 0.0;
            for (j = 0; j < k; ++j) {
                double pnonevents = Math.max((double)nonevents[j], 0.5) / (double)codec.ni[0];
                double pevents = Math.max((double)events[j], 0.5) / (double)codec.ni[1];
                woe[j] = Math.log(pnonevents / pevents);
                iv += (pnonevents - pevents) * woe[j];
            }
            return new InformationValue(field.name(), iv, woe, breaks);
        }).filter(iv -> iv != null && !iv.feature.equals(clazz)).toArray(InformationValue[]::new);
    }
}

