/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import java.util.Arrays;
import java.util.function.ToIntFunction;
import smile.sequence.HMM;
import smile.sequence.SequenceLabeler;

/**
 * Adapter that lets an integer-symbol {@link HMM} work on sequences of
 * arbitrary observation objects.
 *
 * <p>Every observation is converted to an {@code int} with the supplied
 * {@code ordinal} function before being handed to the wrapped model; all
 * scoring and prediction methods simply delegate to the corresponding
 * {@link HMM} method on the converted sequence.
 *
 * @param <T> the type of a single observation
 */
public class HMMLabeler<T> implements SequenceLabeler<T> {
    private static final long serialVersionUID = 2L;

    /** The wrapped model, which operates on integer-encoded observations. */
    private final HMM model;

    /** Maps an observation object to the integer symbol passed to the model. */
    private final ToIntFunction<T> ordinal;

    /**
     * Wraps an existing model.
     *
     * @param model   the model to delegate to
     * @param ordinal function converting an observation to its integer symbol
     */
    public HMMLabeler(HMM model, ToIntFunction<T> ordinal) {
        this.model = model;
        this.ordinal = ordinal;
    }

    /**
     * Builds a labeler by encoding the observation sequences with
     * {@code ordinal} and passing them, together with {@code labels}, to
     * {@link HMM#fit(int[][], int[][])}.
     *
     * @param observations the observation sequences
     * @param labels       the label sequences, one per observation sequence
     * @param ordinal      function converting an observation to its integer symbol
     * @param <T>          the type of a single observation
     * @return a labeler wrapping the fitted model
     * @throws IllegalArgumentException if {@code observations} and
     *         {@code labels} contain a different number of sequences
     */
    public static <T> HMMLabeler<T> fit(T[][] observations, int[][] labels, ToIntFunction<T> ordinal) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }
        HMM model = HMM.fit(encodeAll(observations, ordinal), labels);
        return new HMMLabeler<>(model, ordinal);
    }

    /**
     * Encodes the observation sequences and forwards them to
     * {@link HMM#update(int[][], int)} on the wrapped model.
     *
     * @param observations the observation sequences
     * @param iterations   passed through unchanged to the model
     *                     (presumably the number of update iterations)
     */
    public void update(T[][] observations, int iterations) {
        this.model.update(encodeAll(observations, this.ordinal), iterations);
    }

    /** Returns the string representation of the wrapped model. */
    @Override
    public String toString() {
        return this.model.toString();
    }

    /**
     * Converts one observation sequence to integer symbols.
     *
     * @param sequence the observations to convert
     * @param ordinal  function converting an observation to its integer symbol
     * @return an array of the same length holding the symbol of each observation
     */
    private static <E> int[] encode(E[] sequence, ToIntFunction<E> ordinal) {
        return Arrays.stream(sequence).mapToInt(ordinal).toArray();
    }

    /**
     * Converts every observation sequence to integer symbols.
     *
     * @param sequences the observation sequences to convert
     * @param ordinal   function converting an observation to its integer symbol
     * @return one encoded array per input sequence, in the same order
     */
    private static <E> int[][] encodeAll(E[][] sequences, ToIntFunction<E> ordinal) {
        return Arrays.stream(sequences)
                .map(sequence -> encode(sequence, ordinal))
                .toArray(int[][]::new);
    }

    /** Encodes {@code o} with this labeler's own {@code ordinal} function. */
    private int[] translate(T[] o) {
        return encode(o, this.ordinal);
    }

    /**
     * Delegates to {@link HMM#p(int[], int[])} with the encoded observations.
     *
     * @param o the observation sequence
     * @param s an integer sequence passed through to the model
     *          (presumably the state/label sequence)
     * @return the value returned by the model
     */
    public double p(T[] o, int[] s) {
        return this.model.p(this.translate(o), s);
    }

    /**
     * Delegates to {@link HMM#logp(int[], int[])} with the encoded observations.
     *
     * @param o the observation sequence
     * @param s an integer sequence passed through to the model
     *          (presumably the state/label sequence)
     * @return the value returned by the model
     */
    public double logp(T[] o, int[] s) {
        return this.model.logp(this.translate(o), s);
    }

    /**
     * Delegates to {@link HMM#p(int[])} with the encoded observations.
     *
     * @param o the observation sequence
     * @return the value returned by the model
     */
    public double p(T[] o) {
        return this.model.p(this.translate(o));
    }

    /**
     * Delegates to {@link HMM#logp(int[])} with the encoded observations.
     *
     * @param o the observation sequence
     * @return the value returned by the model
     */
    public double logp(T[] o) {
        return this.model.logp(this.translate(o));
    }

    /**
     * Delegates to {@link HMM#predict(int[])} with the encoded observations.
     *
     * @param o the observation sequence
     * @return the integer array returned by the model
     */
    @Override
    public int[] predict(T[] o) {
        return this.model.predict(this.translate(o));
    }
}

