/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.classification.PlattScaling;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * One-vs-rest strategy for multi-class classification.
 * <p>
 * For {@code k} classes, {@code k} binary classifiers are trained. The i-th one
 * is trained to separate class {@code i} from all the other classes. When every
 * binary classifier supports a score function, each one is calibrated with
 * Platt scaling. Prediction then picks the class whose calibrated score is
 * largest.
 *
 * @param <T> the type of input object.
 */
public class OneVersusRest<T>
extends AbstractClassifier<T> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(OneVersusRest.class);
    /** The number of classes, which equals the number of binary classifiers. */
    private final int k;
    /** The binary classifiers; the i-th one models class i versus the rest. */
    private final Classifier<T>[] classifiers;
    /**
     * Platt scaling models, one per binary classifier. This is null when Platt
     * scaling is not available, in which case prediction is unsupported.
     */
    private final PlattScaling[] platt;

    /**
     * Constructor with the default class labels built by
     * {@code IntSet.of(classifiers.length)}.
     *
     * @param classifiers the binary classifiers, one per class.
     * @param platt the Platt scaling models, one per classifier, or null.
     * @throws IllegalArgumentException if {@code platt} is non-null and its
     *         length differs from {@code classifiers.length}.
     */
    public OneVersusRest(Classifier<T>[] classifiers, PlattScaling[] platt) {
        this(classifiers, platt, IntSet.of(classifiers.length));
    }

    /**
     * Constructor.
     *
     * @param classifiers the binary classifiers, one per class.
     * @param platt the Platt scaling models, one per classifier, or null.
     * @param labels the class labels; index i of the prediction is mapped
     *        through this set.
     * @throws IllegalArgumentException if {@code platt} is non-null and its
     *         length differs from {@code classifiers.length}.
     */
    public OneVersusRest(Classifier<T>[] classifiers, PlattScaling[] platt, IntSet labels) {
        super(labels);
        // A mismatch would otherwise surface later as an index error or a silently ignored classifier.
        if (platt != null && platt.length != classifiers.length) {
            throw new IllegalArgumentException(String.format("The number of Platt scaling models doesn't match the number of classifiers: %d != %d", platt.length, classifiers.length));
        }
        this.classifiers = classifiers;
        this.platt = platt;
        this.k = classifiers.length;
    }

    /**
     * Fits a multi-class model with binary classifiers. The positive class is
     * labeled {@code 1} and the rest are labeled {@code -1}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param trainer the lambda to train a binary classifier.
     * @param <T> the data type.
     * @return the model.
     * @throws IllegalArgumentException if the sizes of x and y differ, or
     *         there are fewer than 3 classes.
     */
    public static <T> OneVersusRest<T> fit(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        return fit(x, y, 1, -1, trainer);
    }

    /**
     * Fits a multi-class model with binary classifiers.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param pos the label given to samples of the class being modeled.
     * @param neg the label given to samples of all the other classes.
     * @param trainer the lambda to train a binary classifier.
     * @param <T> the data type.
     * @return the model. Platt scaling is attached only if it could be fitted
     *         for every binary classifier.
     * @throws IllegalArgumentException if the sizes of x and y differ,
     *         {@code pos == neg}, or there are fewer than 3 classes.
     */
    public static <T> OneVersusRest<T> fit(T[] x, int[] y, int pos, int neg, BiFunction<T[], int[], Classifier<T>> trainer) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (pos == neg) {
            throw new IllegalArgumentException(String.format("The positive and negative labels must differ: %d", pos));
        }
        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        if (k <= 2) {
            throw new IllegalArgumentException(String.format("Only %d classes", k));
        }
        int n = x.length;
        // Encoded labels; presumably class indices in [0, k) — compared against the loop index below.
        int[] labels = codec.y;

        // Java forbids generic array creation. The cast is safe because the array
        // only ever holds the Classifier<T> instances returned by trainer.
        @SuppressWarnings("unchecked")
        Classifier<T>[] classifiers = (Classifier<T>[]) new Classifier<?>[k];
        PlattScaling[] platts = new PlattScaling[k];
        // Each task writes only its own slot i of classifiers/platts.
        IntStream.range(0, k).parallel().forEach(i -> {
            int[] yi = new int[n];
            for (int j = 0; j < n; ++j) {
                yi[j] = labels[j] == i ? pos : neg;
            }
            classifiers[i] = trainer.apply(x, yi);
            try {
                platts[i] = PlattScaling.fit(classifiers[i], x, yi);
            } catch (UnsupportedOperationException ignored) {
                // The binary classifier has no score function; platts[i] stays null.
            }
        });

        // Prediction needs a calibration model for every class. A partially
        // filled array would lead to a NullPointerException in predict().
        boolean calibrated = Arrays.stream(platts).allMatch(p -> p != null);
        if (!calibrated) {
            logger.info("The classifier doesn't support score function. Don't fit Platt scaling.");
        }
        // Keep the original class labels so that predict() maps indices back to them.
        return new OneVersusRest<>(classifiers, calibrated ? platts : null, codec.classes);
    }

    /**
     * Fits a multi-class model with binary data frame classifiers.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param trainer the lambda to train a binary classifier.
     * @return the model.
     */
    public static DataFrameClassifier fit(final Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        final StructType schema = formula.x(data.get(0)).schema();
        Tuple[] x = data.stream().toArray(Tuple[]::new);
        int[] y = formula.y(data).toIntArray();
        // NOTE(review): the per-class binary labels (the lambda's `labels`) are
        // never passed to `trainer`. It only receives `formula` and a frame rebuilt
        // from the raw rows with the predictor schema. Whatever response `trainer`
        // sees is therefore not the one-vs-rest target. Confirm the intended
        // behavior, and splice `labels` in as the response column if needed.
        final OneVersusRest<Tuple> model = fit(x, y, 1, 0, (rows, labels) -> {
            DataFrame df = DataFrame.of(schema, Arrays.asList(rows));
            return trainer.apply(formula, df);
        });
        return new DataFrameClassifier() {

            @Override
            public int numClasses() {
                return model.numClasses();
            }

            @Override
            public int[] classes() {
                return model.classes();
            }

            @Override
            public int predict(Tuple x) {
                return model.predict(x);
            }

            @Override
            public Formula formula() {
                return formula;
            }

            @Override
            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Predicts the class with the largest Platt-scaled score. Ties keep the
     * lowest class index.
     *
     * @param x the instance to be classified.
     * @return the predicted class label.
     * @throws UnsupportedOperationException if Platt scaling is not available.
     */
    @Override
    public int predict(T x) {
        if (platt == null) {
            throw new UnsupportedOperationException("Platt scaling is not available. Please try OneVersusOne.");
        }
        int best = 0;
        double maxf = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; ++i) {
            double f = platt[i].scale(classifiers[i].score(x));
            if (f > maxf) {
                best = i;
                maxf = f;
            }
        }
        return classes.valueOf(best);
    }

    /**
     * Returns true if posteriori probabilities can be estimated, which
     * requires Platt scaling to be available.
     *
     * @return true if Platt scaling models are present.
     */
    @Override
    public boolean soft() {
        return platt != null;
    }

    /**
     * Predicts the class label and estimates the posteriori probabilities.
     * Each class's Platt-scaled score is written to {@code posteriori}. The
     * array is then rescaled with {@code MathEx.unitize1} (presumably L1
     * normalization).
     *
     * @param x the instance to be classified.
     * @param posteriori the output array of posteriori probabilities; it must
     *        have at least {@code k} elements.
     * @return the predicted class label.
     * @throws UnsupportedOperationException if Platt scaling is not available.
     * @throws IllegalArgumentException if {@code posteriori} is shorter than k.
     */
    @Override
    public int predict(T x, double[] posteriori) {
        if (platt == null) {
            throw new UnsupportedOperationException("Platt scaling is not available");
        }
        if (posteriori.length < k) {
            throw new IllegalArgumentException(String.format("The posteriori array is too small: %d < %d", posteriori.length, k));
        }
        for (int i = 0; i < k; ++i) {
            posteriori[i] = platt[i].scale(classifiers[i].score(x));
        }
        MathEx.unitize1(posteriori);
        return classes.valueOf(MathEx.whichMax(posteriori));
    }
}

