/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import java.io.Serializable;
import java.util.Arrays;
import java.util.function.ToIntFunction;
import smile.math.MathEx;
import smile.tensor.DenseMatrix;

/**
 * A first-order hidden Markov model with a finite set of hidden states and
 * discrete, integer-coded emission symbols.
 * <p>
 * The model is parameterized by:
 * <ul>
 * <li>the initial state probabilities {@code pi} (one entry per state);</li>
 * <li>the state transition matrix {@code a}, where {@code a(i, j)} is the
 *     probability of moving from state {@code i} to state {@code j};</li>
 * <li>the symbol emission matrix {@code b}, where {@code b(i, k)} is the
 *     probability of emitting symbol {@code k} in state {@code i}.</li>
 * </ul>
 * States and symbols are zero-based integer indices.
 */
public class HMM implements Serializable {
    private static final long serialVersionUID = 2L;
    /** Initial state probabilities, length N. */
    private final double[] pi;
    /** N x N state transition probabilities (row = current state, column = next state). */
    private final DenseMatrix a;
    /** N x M symbol emission probabilities (row = state, column = symbol). */
    private final DenseMatrix b;

    /**
     * Constructs a model from its parameters. The arguments are stored by
     * reference, not copied.
     *
     * @param pi initial state probabilities, one per state.
     * @param a  state transition probability matrix; must be N x N where N = pi.length.
     * @param b  symbol emission probability matrix; must have N rows.
     * @throws IllegalArgumentException if {@code pi} is empty, {@code a} is not
     *         N x N, or {@code b} does not have N rows.
     */
    public HMM(double[] pi, DenseMatrix a, DenseMatrix b) {
        if (pi.length == 0) {
            throw new IllegalArgumentException("Invalid initial state probabilities.");
        }
        // A transition matrix must be square; previously only the row count was checked.
        if (pi.length != a.nrow() || a.nrow() != a.ncol()) {
            throw new IllegalArgumentException("Invalid state transition probability matrix.");
        }
        if (a.nrow() != b.nrow()) {
            throw new IllegalArgumentException("Invalid symbol emission probability matrix.");
        }
        this.pi = pi;
        this.a = a;
        this.b = b;
    }

    /**
     * Returns the initial state probabilities (the internal array, not a copy).
     *
     * @return the initial state probabilities.
     */
    public double[] getInitialStateProbabilities() {
        return pi;
    }

    /**
     * Returns the state transition probability matrix (the internal matrix, not a copy).
     *
     * @return the state transition probability matrix.
     */
    public DenseMatrix getStateTransitionProbabilities() {
        return a;
    }

    /**
     * Returns the symbol emission probability matrix (the internal matrix, not a copy).
     *
     * @return the symbol emission probability matrix.
     */
    public DenseMatrix getSymbolEmissionProbabilities() {
        return b;
    }

    /**
     * Returns the joint probability of an observation sequence and a state sequence.
     *
     * @param o the observation sequence.
     * @param s the state sequence; must have the same length as {@code o}.
     * @return {@code exp(logp(o, s))}.
     */
    public double p(int[] o, int[] s) {
        return Math.exp(logp(o, s));
    }

    /**
     * Returns the joint log probability of an observation sequence and a state
     * sequence: log pi(s0) + log b(s0, o0) + sum over t of
     * [log a(s(t-1), s(t)) + log b(s(t), o(t))].
     *
     * @param o the observation sequence.
     * @param s the state sequence; must have the same length as {@code o}.
     * @return the joint log probability.
     * @throws IllegalArgumentException if the sequences differ in length.
     */
    public double logp(int[] o, int[] s) {
        if (o.length != s.length) {
            throw new IllegalArgumentException("The observation sequence and state sequence are not the same length.");
        }
        double p = MathEx.log(pi[s[0]]) + MathEx.log(b.get(s[0], o[0]));
        for (int t = 1; t < s.length; t++) {
            p += MathEx.log(a.get(s[t - 1], s[t])) + MathEx.log(b.get(s[t], o[t]));
        }
        return p;
    }

    /**
     * Returns the probability of an observation sequence, summed over all state sequences.
     *
     * @param o the observation sequence.
     * @return {@code exp(logp(o))}; 0 when the sequence is impossible under the model.
     */
    public double p(int[] o) {
        return Math.exp(logp(o));
    }

    /**
     * Returns the log probability of an observation sequence, summed over all
     * state sequences. It is computed with the scaled forward algorithm: the
     * log-likelihood is the sum of the logs of the per-step scaling factors.
     *
     * @param o the observation sequence.
     * @return the log probability, or {@code Double.NEGATIVE_INFINITY} if the
     *         sequence is impossible under the model.
     */
    public double logp(int[] o) {
        double[][] alpha = new double[o.length][a.nrow()];
        double[] scaling = new double[o.length];
        forward(o, alpha, scaling);
        double p = 0.0;
        for (double c : scaling) {
            p += Math.log(c);
        }
        return p;
    }

    /**
     * Normalizes row {@code t} of {@code alpha} in place to sum to one, and
     * records the pre-normalization sum in {@code scaling[t]}.
     */
    private static void scale(double[] scaling, double[][] alpha, int t) {
        double[] row = alpha[t];
        double sum = 0.0;
        for (double x : row) {
            sum += x;
        }
        scaling[t] = sum;
        // An impossible observation leaves an all-zero row. Dividing it would give
        // 0/0 = NaN and poison every later step, so keep the zeros; logp() then
        // naturally evaluates to -Infinity via log(0).
        if (sum != 0.0) {
            for (int i = 0; i < row.length; i++) {
                row[i] /= sum;
            }
        }
    }

    /**
     * Scaled forward pass. On return, {@code alpha[t]} holds the normalized
     * forward variables at step {@code t}, and {@code scaling[t]} holds the
     * normalizer that was divided out.
     */
    private void forward(int[] o, double[][] alpha, double[] scaling) {
        int N = a.nrow();
        for (int k = 0; k < N; k++) {
            alpha[0][k] = pi[k] * b.get(k, o[0]);
        }
        scale(scaling, alpha, 0);

        for (int t = 1; t < o.length; t++) {
            for (int k = 0; k < N; k++) {
                double sum = 0.0;
                for (int i = 0; i < N; i++) {
                    sum += alpha[t - 1][i] * a.get(i, k);
                }
                alpha[t][k] = sum * b.get(k, o[t]);
            }
            scale(scaling, alpha, t);
        }
    }

    /**
     * Scaled backward pass. It uses the scaling factors produced by
     * {@link #forward}, so that {@code alpha * beta} products need no further
     * normalization.
     */
    private void backward(int[] o, double[][] beta, double[] scaling) {
        int N = a.nrow();
        int last = o.length - 1;
        for (int i = 0; i < N; i++) {
            beta[last][i] = 1.0 / scaling[last];
        }

        for (int t = last - 1; t >= 0; t--) {
            for (int i = 0; i < N; i++) {
                double sum = 0.0;
                for (int j = 0; j < N; j++) {
                    sum += beta[t + 1][j] * a.get(i, j) * b.get(j, o[t + 1]);
                }
                beta[t][i] = sum / scaling[t];
            }
        }
    }

    /**
     * Returns the most likely state sequence for an observation sequence,
     * computed with the Viterbi algorithm in log space. Ties are resolved in
     * favor of the lowest state index.
     *
     * @param o the observation sequence.
     * @return the most likely state sequence, of the same length as {@code o}.
     */
    public int[] predict(int[] o) {
        int N = a.nrow();
        double[][] trellis = new double[o.length][N];
        // psy[t][j] is the best predecessor of state j at step t (back-pointer).
        int[][] psy = new int[o.length][N];
        int[] s = new int[o.length];

        for (int i = 0; i < N; i++) {
            trellis[0][i] = MathEx.log(pi[i]) + MathEx.log(b.get(i, o[0]));
        }

        for (int t = 1; t < o.length; t++) {
            for (int j = 0; j < N; j++) {
                double maxDelta = Double.NEGATIVE_INFINITY;
                int maxPsy = 0;
                for (int i = 0; i < N; i++) {
                    double delta = trellis[t - 1][i] + MathEx.log(a.get(i, j));
                    if (delta > maxDelta) {
                        maxDelta = delta;
                        maxPsy = i;
                    }
                }
                trellis[t][j] = maxDelta + MathEx.log(b.get(j, o[t]));
                psy[t][j] = maxPsy;
            }
        }

        int last = o.length - 1;
        double maxDelta = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < N; i++) {
            if (trellis[last][i] > maxDelta) {
                maxDelta = trellis[last][i];
                s[last] = i;
            }
        }

        // Follow the back-pointers from the best final state.
        for (int t = last - 1; t >= 0; t--) {
            s[t] = psy[t + 1][s[t + 1]];
        }
        return s;
    }

    /**
     * Fits a model by relative-frequency counting on labeled sequences. The
     * number of states is 1 + the largest label, and the number of symbols is
     * 1 + the largest observation. A state with no observed outgoing
     * transitions (or no emissions) keeps an all-zero row, rather than the NaN
     * row a 0/0 normalization would produce.
     *
     * @param observations the observation sequences.
     * @param labels       the state sequences, parallel to {@code observations}.
     * @return the fitted model.
     * @throws IllegalArgumentException if the two arrays, or any pair of
     *         corresponding sequences, differ in length.
     */
    public static HMM fit(int[][] observations, int[][] labels) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }

        int N = 0;
        int M = 0;
        for (int i = 0; i < observations.length; i++) {
            if (observations[i].length != labels[i].length) {
                throw new IllegalArgumentException(String.format("The length of observation sequence %d and that of corresponding label sequence are different.", i));
            }
            N = Math.max(N, MathEx.max(labels[i]) + 1);
            M = Math.max(M, MathEx.max(observations[i]) + 1);
        }

        double[] pi = new double[N];
        double[][] a = new double[N][N];
        double[][] b = new double[N][M];

        for (int i = 0; i < observations.length; i++) {
            int[] o = observations[i];
            int[] s = labels[i];
            pi[s[0]] += 1.0;
            b[s[0]][o[0]] += 1.0;
            for (int t = 1; t < o.length; t++) {
                a[s[t - 1]][s[t]] += 1.0;
                b[s[t]][o[t]] += 1.0;
            }
        }

        normalize(pi);
        for (int i = 0; i < N; i++) {
            normalize(a[i]);
            normalize(b[i]);
        }

        return new HMM(pi, DenseMatrix.of(a), DenseMatrix.of(b));
    }

    /** Scales non-negative counts in place to sum to one; leaves an all-zero vector untouched. */
    private static void normalize(double[] x) {
        double sum = 0.0;
        for (double v : x) {
            sum += v;
        }
        if (sum == 0.0) {
            return;
        }
        for (int i = 0; i < x.length; i++) {
            x[i] /= sum;
        }
    }

    /**
     * Fits a model on labeled sequences of arbitrary symbol objects, which
     * {@code ordinal} maps to integer symbol indices.
     *
     * @param observations the observation sequences.
     * @param labels       the state sequences, parallel to {@code observations}.
     * @param ordinal      maps a symbol to its zero-based index.
     * @param <T>          the symbol type.
     * @return the fitted model.
     * @throws IllegalArgumentException if the two arrays differ in length.
     */
    public static <T> HMM fit(T[][] observations, int[][] labels, ToIntFunction<T> ordinal) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }
        return fit(toOrdinals(observations, ordinal), labels);
    }

    /**
     * Runs {@code iterations} Baum-Welch re-estimation steps on sequences of
     * arbitrary symbol objects, which {@code ordinal} maps to integer indices.
     *
     * @param observations the training sequences.
     * @param iterations   the number of re-estimation steps.
     * @param ordinal      maps a symbol to its zero-based index.
     * @param <T>          the symbol type.
     */
    public <T> void update(T[][] observations, int iterations, ToIntFunction<T> ordinal) {
        update(toOrdinals(observations, ordinal), iterations);
    }

    /** Maps every symbol of every sequence to its integer index. */
    private static <T> int[][] toOrdinals(T[][] sequences, ToIntFunction<T> ordinal) {
        return Arrays.stream(sequences)
                .map(sequence -> Arrays.stream(sequence).mapToInt(ordinal).toArray())
                .toArray(int[][]::new);
    }

    /**
     * Runs {@code iterations} Baum-Welch re-estimation steps, updating the
     * model parameters in place.
     *
     * @param observations the training sequences; each must have length greater than 2.
     * @param iterations   the number of re-estimation steps.
     * @throws IllegalArgumentException if a training sequence is too short.
     */
    public void update(int[][] observations, int iterations) {
        for (int iter = 0; iter < iterations; iter++) {
            iterate(observations);
        }
    }

    /** Performs one Baum-Welch step: E-step (forward/backward, xi, gamma), then M-step. */
    private void iterate(int[][] sequences) {
        int N = a.nrow();
        int M = b.ncol();
        double[][][] gamma = new double[sequences.length][][];
        double[][] aijNum = new double[N][N];
        double[] aijDen = new double[N];

        // E-step: per-sequence posteriors, accumulating the expected transition counts.
        for (int k = 0; k < sequences.length; k++) {
            if (sequences[k].length <= 2) {
                throw new IllegalArgumentException(String.format("Training sequence %d is too short.", k));
            }
            int[] o = sequences[k];
            double[][] alpha = new double[o.length][N];
            double[][] beta = new double[o.length][N];
            double[] scaling = new double[o.length];
            forward(o, alpha, scaling);
            backward(o, beta, scaling);

            double[][][] xi = estimateXi(o, alpha, beta);
            double[][] g = estimateGamma(xi);
            gamma[k] = g;

            for (int t = 0; t < xi.length; t++) {
                for (int i = 0; i < N; i++) {
                    aijDen[i] += g[t][i];
                    for (int j = 0; j < N; j++) {
                        aijNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        // M-step, transitions: keep the old row for a state that was never occupied.
        for (int i = 0; i < N; i++) {
            if (aijDen[i] == 0.0) {
                continue;
            }
            for (int j = 0; j < N; j++) {
                a.set(i, j, aijNum[i][j] / aijDen[i]);
            }
        }

        // M-step, initial probabilities: the average first-step posterior.
        Arrays.fill(pi, 0.0);
        for (double[][] g : gamma) {
            for (int i = 0; i < N; i++) {
                pi[i] += g[0][i];
            }
        }
        for (int i = 0; i < N; i++) {
            pi[i] /= sequences.length;
        }

        // M-step, emissions: expected symbol counts per state. As with transitions,
        // a state with zero posterior mass keeps its old row instead of becoming 0/0 = NaN.
        for (int i = 0; i < N; i++) {
            double[] row = new double[M];
            double sum = 0.0;
            for (int k = 0; k < sequences.length; k++) {
                int[] o = sequences[k];
                for (int t = 0; t < o.length; t++) {
                    row[o[t]] += gamma[k][t][i];
                    sum += gamma[k][t][i];
                }
            }
            if (sum == 0.0) {
                continue;
            }
            for (int j = 0; j < M; j++) {
                b.set(i, j, row[j] / sum);
            }
        }
    }

    /**
     * Computes xi[t][i][j], the posterior probability of being in state i at
     * step t and in state j at step t+1, from the scaled forward/backward variables.
     *
     * @throws IllegalArgumentException if {@code o} has fewer than two elements.
     */
    private double[][][] estimateXi(int[] o, double[][] alpha, double[][] beta) {
        if (o.length <= 1) {
            throw new IllegalArgumentException("Observation sequence is too short.");
        }
        int N = a.nrow();
        int n = o.length - 1;
        double[][][] xi = new double[n][N][N];
        for (int t = 0; t < n; t++) {
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    xi[t][i][j] = alpha[t][i] * a.get(i, j) * b.get(j, o[t + 1]) * beta[t + 1][j];
                }
            }
        }
        return xi;
    }

    /**
     * Computes the state posteriors gamma[t][i] from xi. For each step but
     * the last, this is the sum over j of xi[t][i][j]. The last step sums the
     * final xi over its first index instead.
     */
    private double[][] estimateGamma(double[][][] xi) {
        int N = a.nrow();
        int steps = xi.length;
        double[][] gamma = new double[steps + 1][N];
        for (int t = 0; t < steps; t++) {
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        double[][] lastXi = xi[steps - 1];
        for (int j = 0; j < N; j++) {
            for (int i = 0; i < N; i++) {
                gamma[steps][j] += lastXi[i][j];
            }
        }
        return gamma;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("HMM (%d states, %d emission symbols)%n", a.nrow(), b.ncol()));
        sb.append("Initial state probability: ");
        sb.append(Arrays.toString(pi));
        sb.append("\nState transition probability:\n");
        sb.append(a);
        sb.append("Symbol emission probability:\n");
        sb.append(b);
        return sb.toString();
    }
}

