/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;
import smile.data.Dataset;
import smile.data.Instance;
import smile.math.MathEx;

/**
 * A classifier that maps an input object of type {@code T} to an integer class label.
 * <p>
 * Implementations must provide {@link #numClasses()}, {@link #classes()} and
 * {@link #predict(Object)}. Soft classification (posterior probabilities) and online
 * learning (incremental updates) are optional capabilities. Their default
 * implementations throw {@link UnsupportedOperationException}, and {@link #soft()} /
 * {@link #online()} report whether an implementation overrides them.
 *
 * @param <T> the type of input objects.
 */
public interface Classifier<T> extends ToIntFunction<T>, ToDoubleFunction<T>, Serializable {
    /**
     * Returns the number of classes.
     *
     * @return the number of classes.
     */
    int numClasses();

    /**
     * Returns the class labels.
     *
     * @return the class labels.
     */
    int[] classes();

    /**
     * Predicts the class label of an instance.
     *
     * @param x the instance to be classified.
     * @return the predicted class label.
     */
    int predict(T x);

    /**
     * Returns a score for an instance. The default implementation is unsupported.
     *
     * @param x the instance.
     * @return the score.
     * @throws UnsupportedOperationException unless overridden.
     */
    default double score(T x) {
        throw new UnsupportedOperationException();
    }

    /**
     * Adapts this classifier to {@link ToIntFunction} by delegating to {@link #predict(Object)}.
     *
     * @param x the instance.
     * @return the predicted class label.
     */
    @Override
    default int applyAsInt(T x) {
        return predict(x);
    }

    /**
     * Adapts this classifier to {@link ToDoubleFunction} by delegating to {@link #score(Object)}.
     *
     * @param x the instance.
     * @return the score.
     */
    @Override
    default double applyAsDouble(T x) {
        return score(x);
    }

    /**
     * Predicts the class labels of an array of instances.
     *
     * @param x the instances.
     * @return the predicted labels, in input order.
     */
    default int[] predict(T[] x) {
        return Arrays.stream(x).mapToInt(this::predict).toArray();
    }

    /**
     * Predicts the class labels of a list of instances.
     *
     * @param x the instances.
     * @return the predicted labels, in input order.
     */
    default int[] predict(List<T> x) {
        return x.stream().mapToInt(this::predict).toArray();
    }

    /**
     * Predicts the class labels of a data set.
     *
     * @param x the data set.
     * @return the predicted labels, in the data set's stream order.
     */
    default int[] predict(Dataset<T> x) {
        return x.stream().mapToInt(this::predict).toArray();
    }

    /**
     * Returns true if this classifier supports soft classification, i.e. overrides
     * {@link #predict(Object, double[])}.
     * <p>
     * The check probes {@code predict(null, new double[numClasses()])}. The call is
     * treated as unsupported only if it throws an {@link UnsupportedOperationException}
     * carrying the default implementation's message. Any other exception, or normal
     * completion, means the method has been overridden.
     *
     * @return true if soft classification is supported.
     */
    default boolean soft() {
        try {
            this.predict(null, new double[numClasses()]);
        } catch (UnsupportedOperationException e) {
            // Null-safe comparison: an override may throw UOE without a message.
            return !"soft classification with a hard classifier".equals(e.getMessage());
        } catch (Exception e) {
            // Typically the override rejecting the null probe input: it exists.
            return true;
        }
        // The default implementation always throws, so reaching here proves an override.
        return true;
    }

    /**
     * Predicts the class label of an instance and fills in its posterior probabilities.
     * The default implementation is unsupported (hard classifier).
     *
     * @param x the instance to be classified.
     * @param posteriori the output array of posterior probabilities.
     * @return the predicted class label.
     * @throws UnsupportedOperationException unless overridden.
     */
    default int predict(T x, double[] posteriori) {
        throw new UnsupportedOperationException("soft classification with a hard classifier");
    }

    /**
     * Predicts class labels and posterior probabilities of an array of instances.
     * Predictions are computed on a parallel stream; results keep input order.
     *
     * @param x the instances.
     * @param posteriori the output posterior arrays; row {@code i} is filled for {@code x[i]}.
     * @return the predicted labels, in input order.
     * @throws IllegalArgumentException if {@code posteriori} has fewer rows than {@code x}.
     */
    default int[] predict(T[] x, double[][] posteriori) {
        if (posteriori.length < x.length) {
            throw new IllegalArgumentException(String.format(
                "Posteriori array of size %d smaller than input size %d", posteriori.length, x.length));
        }
        return IntStream.range(0, x.length)
                .parallel()
                .map(i -> this.predict(x[i], posteriori[i]))
                .toArray();
    }

    /**
     * Predicts class labels and posterior probabilities of a list of instances.
     * Allocates one {@code numClasses()}-length array per instance and appends them to
     * {@code posteriori} in input order before predicting on a parallel stream.
     *
     * @param x the instances.
     * @param posteriori the output list that receives the posterior arrays.
     * @return the predicted labels, in input order.
     */
    default int[] predict(List<T> x, List<double[]> posteriori) {
        int n = x.size();
        double[][] prob = new double[n][numClasses()];
        Collections.addAll(posteriori, prob);
        return IntStream.range(0, n)
                .parallel()
                .map(i -> this.predict(x.get(i), prob[i]))
                .toArray();
    }

    /**
     * Predicts class labels and posterior probabilities of a data set.
     * Allocates one {@code numClasses()}-length array per instance and appends them to
     * {@code posteriori} in index order before predicting on a parallel stream.
     *
     * @param x the data set.
     * @param posteriori the output list that receives the posterior arrays.
     * @return the predicted labels, in index order.
     */
    default int[] predict(Dataset<T> x, List<double[]> posteriori) {
        int n = x.size();
        double[][] prob = new double[n][numClasses()];
        Collections.addAll(posteriori, prob);
        return IntStream.range(0, n)
                .parallel()
                .map(i -> this.predict(x.get(i), prob[i]))
                .toArray();
    }

    /**
     * Returns true if this classifier supports online learning, i.e. overrides
     * {@link #update(Object, int)}.
     * <p>
     * The check probes {@code update(null, 0)}. The call is treated as unsupported only
     * if it throws an {@link UnsupportedOperationException} carrying the default
     * implementation's message. Any other exception, or normal completion, means the
     * method has been overridden. NOTE(review): an implementation that accepts the null
     * probe will actually be updated by it.
     *
     * @return true if online learning is supported.
     */
    default boolean online() {
        try {
            this.update(null, 0);
        } catch (UnsupportedOperationException e) {
            // Null-safe comparison: an override may throw UOE without a message.
            return !"update a batch learner".equals(e.getMessage());
        } catch (Exception e) {
            // Typically the override rejecting the null probe input: it exists.
            return true;
        }
        // The default implementation always throws, so reaching here proves an override.
        return true;
    }

    /**
     * Updates the model with a single labeled instance. The default implementation is
     * unsupported (batch learner).
     *
     * @param x the instance.
     * @param y the class label.
     * @throws UnsupportedOperationException unless overridden.
     */
    default void update(T x, int y) {
        throw new UnsupportedOperationException("update a batch learner");
    }

    /**
     * Updates the model with a batch of labeled instances, one at a time in order.
     *
     * @param x the instances.
     * @param y the class labels.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    default void update(T[] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("Input vector x of size %d not equal to length %d of y", x.length, y.length));
        }
        for (int i = 0; i < x.length; i++) {
            update(x[i], y[i]);
        }
    }

    /**
     * Updates the model with every instance of a data set, using each instance's
     * {@code x()} and {@code label()}.
     *
     * @param batch the labeled instances.
     */
    default void update(Dataset<Instance<T>> batch) {
        batch.stream().forEach(sample -> update(sample.x(), sample.label()));
    }

    /**
     * Returns an ensemble of classifiers.
     * <p>
     * Hard predictions are the {@code MathEx.mode} of the member predictions. Soft
     * predictions average the members' posteriors and return {@code MathEx.whichMax} of
     * the average. {@code numClasses()} and {@code classes()} are taken from the first
     * model. The ensemble is soft/online only if every member is; this is probed once,
     * at construction time.
     *
     * @param models the base models.
     * @param <T> the type of input objects.
     * @return the ensemble model.
     * @throws IllegalArgumentException if no model is given.
     */
    @SafeVarargs
    static <T> Classifier<T> ensemble(final Classifier<T>... models) {
        if (models.length == 0) {
            throw new IllegalArgumentException("Empty ensemble of classifiers");
        }

        return new Classifier<T>() {
            private final boolean soft = Arrays.stream(models).allMatch(Classifier::soft);
            private final boolean online = Arrays.stream(models).allMatch(Classifier::online);

            @Override
            public boolean soft() {
                return soft;
            }

            @Override
            public boolean online() {
                return online;
            }

            @Override
            public int numClasses() {
                return models[0].numClasses();
            }

            @Override
            public int[] classes() {
                return models[0].classes();
            }

            @Override
            public int predict(T x) {
                int[] votes = new int[models.length];
                for (int i = 0; i < models.length; i++) {
                    votes[i] = models[i].predict(x);
                }
                return MathEx.mode(votes);
            }

            @Override
            public int predict(T x, double[] posteriori) {
                Arrays.fill(posteriori, 0.0);
                double[] prob = new double[posteriori.length];
                for (Classifier<T> model : models) {
                    // Clear the scratch buffer so a member that leaves entries untouched
                    // cannot leak the previous member's probabilities into the sum.
                    Arrays.fill(prob, 0.0);
                    model.predict(x, prob);
                    for (int i = 0; i < prob.length; i++) {
                        posteriori[i] += prob[i];
                    }
                }

                for (int i = 0; i < posteriori.length; i++) {
                    posteriori[i] /= models.length;
                }
                return MathEx.whichMax(posteriori);
            }

            @Override
            public void update(T x, int y) {
                for (Classifier<T> model : models) {
                    model.update(x, y);
                }
            }
        };
    }

    /**
     * A classifier trainer.
     *
     * @param <T> the type of input objects.
     * @param <M> the type of model produced.
     */
    interface Trainer<T, M extends Classifier<T>> {
        /**
         * Fits a classification model with an empty set of hyper-parameters.
         *
         * @param x the training samples.
         * @param y the training labels.
         * @return the model.
         */
        default M fit(T[] x, int[] y) {
            return fit(x, y, new Properties());
        }

        /**
         * Fits a classification model.
         *
         * @param x the training samples.
         * @param y the training labels.
         * @param params the hyper-parameters.
         * @return the model.
         */
        M fit(T[] x, int[] y, Properties params);
    }
}

