/*
 * Decompiled with CFR 0.152.
 */
package de.crysandt.hmm;

import de.crysandt.hmm.Alpha;
import de.crysandt.hmm.Betha;
import de.crysandt.hmm.GaussianDistribution;
import de.crysandt.hmm.GaussianDistributionDiagonal;
import de.crysandt.hmm.GaussianDistributionFull;
import de.crysandt.hmm.SortByColumn;
import de.crysandt.math.Function;
import de.crysandt.math.LinAlg;
import de.crysandt.util.Debug;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.TreeSet;

/**
 * Hidden Markov model with {@link #N} states, one {@link GaussianDistribution}
 * per state, an {@code N x N} transition matrix and an optional vector of
 * initial state probabilities. Observations are {@code float} vectors of
 * length {@link #SIZE}; a sequence is a {@code float[][]} (one row per time
 * step) and training data is a raw {@link Collection} of such sequences.
 *
 * <p>The arrays handed to the public constructor are stored by reference, not
 * copied. The class performs no synchronization.
 */
public class HMM {
    /**
     * {@code init_type}: each new initial probability is the state's share of
     * gamma summed over all sequences and all time steps.
     */
    public static final int INIT_RELATIVE_TIME = 0;
    /**
     * {@code init_type}: each new initial probability is the state's gamma at
     * the first time step, averaged over the sequences.
     */
    public static final int INIT_BEGIN_OF_SEQUENCE = 1;
    /** {@code init_type}: the re-estimated model carries no initial probabilities ({@code null}). */
    public static final int INIT_NONE = 2;
    /**
     * Small positive floor added to transition values and xi terms before they
     * are summed or passed to {@code Math.log} (about 1024 * Float.MIN_VALUE).
     */
    static final float FLOAT_MIN_VALUE = 1.435E-42f;
    /** Number of states. */
    public final int N;
    /** Length of one observation vector. */
    public final int SIZE;
    /** Initial state probabilities; {@code null} when the model has none. */
    private final float[] init;
    private final GaussianDistribution[] dist;
    private final float[][] trans;
    /** Lazily built element-wise logarithm of {@link #trans}; see {@link #getTransLog()}. */
    private transient double[][] trans_log = null;

    /**
     * Creates a model from the given parts. The arrays are kept by reference.
     *
     * @param N     number of states
     * @param size  length of one observation vector
     * @param init  initial state probabilities, or {@code null}
     * @param dist  one distribution per state
     * @param trans transition matrix, indexed {@code [from][to]}
     */
    public HMM(int N, int size, float[] init, GaussianDistribution[] dist, float[][] trans) {
        this.N = N;
        this.SIZE = size;
        this.init = init;
        this.dist = dist;
        this.trans = trans;
    }

    /** Creates a model with freshly allocated, still empty parameter arrays. */
    HMM(int N, int size) {
        this(N, size, new float[N], new GaussianDistribution[N], new float[N][N]);
    }

    /**
     * Builds the {@code N} starting distributions for training.
     *
     * <p>All vectors of all sequences are collected in a set ordered by
     * {@code SortByColumn}; the per-column mean and variance are computed over
     * that set, the set is re-ordered by the column with the largest variance,
     * and {@code N} vectors picked at regular intervals from that ordering
     * become the centers. Every distribution gets the same diagonal inverse
     * covariance {@code 1 / var[k] / SIZE}.
     *
     * @param SIZE      length of one observation vector
     * @param N         number of distributions to create
     * @param sequences collection of {@code float[][]} sequences
     * @return {@code N} newly created distributions
     * @throws java.util.NoSuchElementException if the ordered set runs out of
     *         vectors before {@code N} centers were picked
     */
    static GaussianDistribution[] initGaussianDist(int SIZE, int N, Collection sequences) {
        double[] mean = new double[SIZE];
        double[] var = new double[SIZE];
        TreeSet<float[]> vectors = new TreeSet<float[]>(new SortByColumn(0));
        for (Object element : sequences) {
            float[][] sequence = (float[][]) element;
            for (float[] vector : sequence) {
                assert (vector.length == SIZE);
                vectors.add(vector);
            }
        }

        // Accumulate sum and sum of squares per column.
        for (float[] vector : vectors) {
            for (int c = 0; c < SIZE; ++c) {
                mean[c] += vector[c];
                var[c] += vector[c] * vector[c];
            }
        }
        for (int c = 0; c < SIZE; ++c) {
            mean[c] = mean[c] / (double) vectors.size();
            var[c] = var[c] / (double) vectors.size() - mean[c] * mean[c];
        }

        // Column with the largest variance decides the ordering used for picking centers.
        int index = 0;
        for (int c = 1; c < var.length; ++c) {
            if (var[c] > var[index]) {
                index = c;
            }
        }
        if (index > 0) {
            TreeSet<float[]> vectors_new = new TreeSet<float[]>(new SortByColumn(index));
            vectors_new.addAll(vectors);
            vectors = vectors_new;
        }

        GaussianDistribution[] dist = new GaussianDistribution[N];
        // Number of vectors skipped before each pick.
        int GAP = (vectors.size() - 1) / (N + 1) - 1;
        Iterator<float[]> i = vectors.iterator();
        for (int n = 0; n < N; ++n) {
            for (int k = 0; k < GAP && i.hasNext(); ++k) {
                i.next();
            }
            float[] vector = i.next();
            float[] center = new float[SIZE];
            float[][] cov_inv = new float[SIZE][SIZE];
            for (int k = 0; k < SIZE; ++k) {
                center[k] = vector[k];
                cov_inv[k][k] = (float) (1.0 / var[k] / (double) SIZE);
            }
            dist[n] = new GaussianDistributionFull(center, cov_inv);
        }
        return dist;
    }

    /**
     * Trains a model on a single sequence; see
     * {@link #createModel(int, int, Collection, int)}.
     */
    public static HMM createModel(int N, int size, float[][] sequence, int init_type) {
        ArrayList<float[][]> list = new ArrayList<float[][]>(1);
        list.add(sequence);
        return HMM.createModel(N, size, list, init_type);
    }

    /**
     * Trains a model on the given sequences.
     *
     * <p>Starts from uniform initial and transition probabilities and the
     * distributions of {@link #initGaussianDist}, then repeatedly calls
     * {@link #optimizeModel} until the relative change of the summed best-path
     * log probability drops below {@code 1.0E-5}.
     *
     * <p>Any exception raised during the iteration ends the training and the
     * most recent successfully built model is returned; the stack trace is
     * printed only when assertions are enabled.
     *
     * @param N         number of states
     * @param size      length of one observation vector
     * @param sequences collection of {@code float[][]} sequences
     * @param init_type one of the {@code INIT_*} constants
     * @return the trained model
     */
    public static HMM createModel(int N, int size, Collection sequences, int init_type) {
        HMM hmm = new HMM(N, size);
        Arrays.fill(hmm.init, 1.0f / (float) N);
        System.arraycopy(HMM.initGaussianDist(size, N, sequences), 0, hmm.dist, 0, N);
        for (int n = 0; n < N; ++n) {
            Arrays.fill(hmm.trans[n], 1.0f / (float) N);
        }
        try {
            double opt_path = hmm.getLogProbBestPath(sequences);
            int num_iter = 0;
            // NOTE(review): there is no iteration cap; termination relies on the
            // relative-change test below eventually succeeding.
            while (true) {
                HMM hmm_new = hmm.optimizeModel(sequences, init_type);
                double opt_path_new = hmm_new.getLogProbBestPath(sequences);
                assert (Debug.println(System.err, num_iter + ", " + opt_path_new + " (" + (opt_path_new - opt_path) + ")"));
                boolean converged = Math.abs(1.0 - opt_path_new / opt_path) < 1.0E-5;
                hmm = hmm_new;
                opt_path = opt_path_new;
                if (converged) {
                    break;
                }
                ++num_iter;
            }
        }
        catch (Exception e) {
            // Deliberate best effort: keep the last good model, report only under -ea.
            assert (HMM.printStackTrace(e));
        }
        return hmm;
    }

    /** Prints the stack trace and returns {@code true} so the call can sit inside an {@code assert}. */
    private static boolean printStackTrace(Exception e) {
        e.printStackTrace();
        return true;
    }

    /**
     * Returns the element-wise logarithm of the transition matrix (each entry
     * offset by {@link #FLOAT_MIN_VALUE}), building it on first use. The table
     * is assigned to the field only after it has been filled completely, so a
     * failed assertion never leaves a partially filled cache behind.
     */
    private double[][] getTransLog() {
        double[][] table = this.trans_log;
        if (table == null) {
            table = new double[this.N][this.N];
            for (int i = 0; i < this.N; ++i) {
                for (int j = 0; j < this.N; ++j) {
                    table[i][j] = Math.log(FLOAT_MIN_VALUE + this.trans[i][j]);
                    assert (table[i][j] < 0.001);
                }
            }
            this.trans_log = table;
        }
        return table;
    }

    /**
     * Computes the log probability of the most probable state path for one
     * sequence.
     *
     * <p>The recursion runs backwards: it starts with the emission log
     * probabilities of the last vector and, for every earlier vector, adds the
     * best {@code log(transition) + continuation} term per state. If the model
     * has initial probabilities their logarithm is added at the end.
     *
     * @param sequence observation vectors, one row per time step
     * @return the maximum over all states of the accumulated log probability
     */
    public double getLogProbBestPath(float[][] sequence) {
        double[][] log_trans = this.getTransLog();
        double[] max_log_prob = new double[this.N];
        int sequence_last = sequence.length - 1;
        for (int i = 0; i < this.N; ++i) {
            max_log_prob[i] = this.dist[i].getLogProb(sequence[sequence_last]);
        }
        for (int r = sequence.length - 2; r >= 0; --r) {
            double[] max_log_prob_new = new double[this.N];
            float[] vector = sequence[r];
            for (int i = 0; i < this.N; ++i) {
                double best_path = Double.NEGATIVE_INFINITY;
                double[] log_trans_i = log_trans[i];
                for (int j = 0; j < this.N; ++j) {
                    best_path = Math.max(best_path, log_trans_i[j] + max_log_prob[j]);
                }
                assert (!Double.isNaN(best_path));
                assert (!Double.isInfinite(best_path));
                max_log_prob_new[i] = this.dist[i].getLogProb(vector) + best_path;
            }
            max_log_prob = max_log_prob_new;
        }
        if (this.init != null) {
            for (int i = 0; i < this.N; ++i) {
                max_log_prob[i] += Math.log(this.init[i]);
            }
        }
        return Function.max(max_log_prob);
    }

    /** Sums {@link #getLogProbBestPath(float[][])} over all sequences of the collection. */
    private double getLogProbBestPath(Collection sequences) {
        double sum_log_prob = 0.0;
        for (Object element : sequences) {
            sum_log_prob += this.getLogProbBestPath((float[][]) element);
        }
        return sum_log_prob;
    }

    /**
     * Performs one re-estimation step and returns the result as a new model;
     * this model is not modified.
     *
     * <p>For every sequence the per-state observation probabilities are
     * computed (floored by {@code Float.MIN_VALUE}, capped at
     * {@code Float.MAX_VALUE}) and passed to the {@code Alpha} and
     * {@code Betha} helpers, whose results yield gamma and xi. Distributions
     * and transitions are then recomputed from gamma and xi, and the initial
     * probabilities according to {@code init_type}.
     *
     * @param sequences collection of {@code float[][]} sequences
     * @param init_type one of {@link #INIT_RELATIVE_TIME},
     *                  {@link #INIT_BEGIN_OF_SEQUENCE}, {@link #INIT_NONE}
     * @return the re-estimated model
     */
    public HMM optimizeModel(Collection sequences, int init_type) {
        double[][][] gamma = new double[sequences.size()][][];
        double[][][][] xi = new double[sequences.size()][][][];
        int e = 0;
        for (Object element : sequences) {
            float[][] sequence = (float[][]) element;
            double[][] dist_prob = new double[sequence.length][this.N];
            for (int t = 0; t < dist_prob.length; ++t) {
                double[] dist_prob_t = dist_prob[t];
                float[] vector = sequence[t];
                for (int i = 0; i < this.N; ++i) {
                    dist_prob_t[i] = Math.min((double) Float.MAX_VALUE, this.dist[i].getProb(vector) + (double) Float.MIN_VALUE);
                }
            }
            Alpha a = new Alpha(dist_prob, this.init, this.trans);
            double[][] alpha = a.getAlpha();
            Betha b = new Betha(dist_prob, this.trans);
            double[][] betha = b.getBetha();
            double[] betha_scal = b.getBethaScal();
            gamma[e] = HMM.calcGammaE(alpha, betha);
            xi[e] = HMM.calcXiE(dist_prob, this.trans, betha, betha_scal, gamma[e]);
            ++e;
        }
        double[] sum_sum_gamma_et = HMM.calcSumSumGammaET(this.N, gamma);
        GaussianDistribution[] dist_new = HMM.calcGaussianDistributions(this.N, this.SIZE, this.dist, gamma, sum_sum_gamma_et, sequences);
        // N is passed explicitly so that sequences of length 1 (which have no
        // xi entries) cannot break the size lookup.
        float[][] trans_new = HMM.calcTransitions(this.N, xi);
        float[] init_new = new float[this.N];
        switch (init_type) {
            case INIT_BEGIN_OF_SEQUENCE: {
                for (int i = 0; i < this.N; ++i) {
                    double sum = 0.0;
                    for (int s = 0; s < gamma.length; ++s) {
                        sum += gamma[s][0][i];
                    }
                    init_new[i] = (float) (sum / (double) sequences.size());
                }
                break;
            }
            case INIT_RELATIVE_TIME: {
                double sum_sum_gamma_eti = 0.0;
                for (int i = 0; i < this.N; ++i) {
                    sum_sum_gamma_eti += sum_sum_gamma_et[i];
                }
                for (int i = 0; i < this.N; ++i) {
                    init_new[i] = (float) (sum_sum_gamma_et[i] / sum_sum_gamma_eti);
                }
                break;
            }
            case INIT_NONE: {
                init_new = null;
                break;
            }
            default: {
                // Unknown type: with assertions disabled the all-zero vector is kept.
                assert (false);
                break;
            }
        }
        return new HMM(this.N, this.SIZE, init_new, dist_new, trans_new);
    }

    /**
     * Sums gamma over all sequences and all time steps.
     *
     * @param N     number of states
     * @param gamma indexed {@code [sequence][time][state]}
     * @return array of length {@code N} holding the per-state totals
     */
    static double[] calcSumSumGammaET(int N, double[][][] gamma) {
        double[] sum_sum_gamma_et = new double[N];
        for (double[][] gamma_e : gamma) {
            for (double[] gamma_et : gamma_e) {
                for (int i = 0; i < N; ++i) {
                    assert (!Double.isNaN(gamma_et[i]));
                    sum_sum_gamma_et[i] += gamma_et[i];
                }
            }
        }
        return sum_sum_gamma_et;
    }

    /**
     * Sanity check for one row of observation probabilities: asserts that every
     * entry is a finite, non-negative number and reports whether at least one
     * entry exceeds {@code Double.MIN_VALUE}.
     */
    private static boolean testDistProbT(double[] dist_prob_t) {
        for (double p : dist_prob_t) {
            assert (!Double.isNaN(p));
            assert (!Double.isInfinite(p));
            assert (p >= 0.0);
        }
        for (double p : dist_prob_t) {
            if (p > Double.MIN_VALUE) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes gamma for one sequence: the element-wise product of alpha and
     * betha, normalized per time step. The normalizing sum starts at
     * {@code Float.MIN_VALUE} so that it is never zero.
     *
     * @param alpha indexed {@code [time][state]}; must have at least one row
     * @param betha indexed {@code [time][state]}, same shape as {@code alpha}
     * @return gamma, indexed {@code [time][state]}
     */
    static double[][] calcGammaE(double[][] alpha, double[][] betha) {
        int T = alpha.length;
        int N = alpha[0].length;
        double[][] gamma_e = new double[T][N];
        for (int t = 0; t < T; ++t) {
            double[] alpha_t = alpha[t];
            double[] betha_t = betha[t];
            double[] gamma_et = gamma_e[t];
            double sum = Float.MIN_VALUE;
            for (int i = 0; i < N; ++i) {
                gamma_et[i] = alpha_t[i] * betha_t[i];
                sum += gamma_et[i];
                assert (!Double.isNaN(gamma_et[i]));
                assert (!Double.isInfinite(gamma_et[i]));
            }
            assert (!Double.isNaN(sum));
            assert (!Double.isInfinite(sum));
            for (int i = 0; i < N; ++i) {
                gamma_et[i] /= sum;
            }
        }
        assert (HMM.testGammaE(gamma_e));
        return gamma_e;
    }

    /** Asserts that every gamma entry is a finite, strictly positive number; always returns {@code true}. */
    private static boolean testGammaE(double[][] gamma_e) {
        for (double[] gamma_et : gamma_e) {
            for (double g : gamma_et) {
                assert (!Double.isNaN(g));
                assert (!Double.isInfinite(g));
                assert (g > 0.0);
            }
        }
        return true;
    }

    /**
     * Computes xi for one sequence. For {@code T} time steps the result has
     * {@code T - 1} entries, so a sequence of length 1 yields an empty array.
     * Every entry is offset by {@link #FLOAT_MIN_VALUE}.
     *
     * @param dist_prob  observation probabilities, {@code [time][state]}
     * @param trans      transition matrix, {@code [from][to]}
     * @param betha      {@code [time][state]}; must have at least one row
     * @param betha_scal per-time-step scaling values belonging to {@code betha}
     * @param gamma_e    gamma of the same sequence, {@code [time][state]}
     * @return xi, indexed {@code [time][from][to]}
     */
    static double[][][] calcXiE(double[][] dist_prob, float[][] trans, double[][] betha, double[] betha_scal, double[][] gamma_e) {
        int T = betha.length;
        int N = betha[0].length;
        double[][][] xi_e = new double[T - 1][N][N];
        for (int t = 0; t < xi_e.length; ++t) {
            double[][] xi_et = xi_e[t];
            double[] gamma_et = gamma_e[t];
            double[] betha_t = betha[t];
            double[] betha_t1 = betha[t + 1];
            double betha_scal_t = betha_scal[t];
            // Observation probabilities of the following time step.
            double[] dist_prob_t1 = dist_prob[t + 1];
            for (int i = 0; i < N; ++i) {
                double[] xi_eti = xi_et[i];
                float[] trans_i = trans[i];
                double factor = gamma_et[i] / betha_t[i] / betha_scal_t;
                for (int j = 0; j < N; ++j) {
                    xi_eti[j] = FLOAT_MIN_VALUE + factor * trans_i[j] * dist_prob_t1[j] * betha_t1[j];
                }
            }
        }
        return xi_e;
    }

    /**
     * Re-estimates the distribution of every state.
     *
     * <p>The new mean is the gamma-weighted average of all vectors. The new
     * covariance is the gamma-weighted average of the outer products of
     * {@code vector - center}, where {@code center} is taken from the current
     * distribution (not from the new mean). If the covariance cannot be used
     * (determinant of its inverse is NaN, negative or above
     * {@code Float.MAX_VALUE}, or an {@code IllegalArgumentException} is raised
     * while inverting), the covariance data of the current distribution is
     * kept and only the mean is replaced.
     *
     * @param N                number of states
     * @param SIZE             length of one observation vector
     * @param dist             current distributions
     * @param gamma            indexed {@code [sequence][time][state]}
     * @param sum_sum_gamma_et per-state gamma totals
     * @param sequences        collection of {@code float[][]} sequences, in the
     *                         same iteration order as {@code gamma}
     * @return the new distributions
     */
    static GaussianDistribution[] calcGaussianDistributions(int N, int SIZE, GaussianDistribution[] dist, double[][][] gamma, double[] sum_sum_gamma_et, Collection sequences) {
        GaussianDistribution[] dist_new = new GaussianDistribution[N];
        for (int i = 0; i < N; ++i) {
            double sum_sum_gamma_eti = sum_sum_gamma_et[i];

            // Weighted mean.
            double[] mean_new_sum = new double[SIZE];
            int e = 0;
            for (Object element : sequences) {
                float[][] sequence = (float[][]) element;
                for (int t = 0; t < sequence.length; ++t) {
                    double gamma_eti = gamma[e][t][i];
                    float[] vector = sequence[t];
                    for (int l = 0; l < mean_new_sum.length; ++l) {
                        mean_new_sum[l] += gamma_eti * (double) vector[l];
                    }
                }
                ++e;
            }
            float[] mean_new = new float[SIZE];
            for (int l = 0; l < mean_new.length; ++l) {
                mean_new[l] = (float) (mean_new_sum[l] / sum_sum_gamma_eti);
            }

            // Weighted covariance around the current center; only the lower
            // triangle (including the diagonal) is accumulated.
            double[][] cov_new_sum = new double[SIZE][SIZE];
            float[] center_i = dist[i].getCenter();
            double[] diff = new double[SIZE];
            e = 0;
            for (Object element : sequences) {
                float[][] sequence = (float[][]) element;
                for (int t = 0; t < sequence.length; ++t) {
                    float[] vector = sequence[t];
                    double gamma_eti = gamma[e][t][i];
                    for (int k = 0; k < SIZE; ++k) {
                        diff[k] = vector[k] - center_i[k];
                    }
                    for (int m = 0; m < SIZE; ++m) {
                        double diff_m = diff[m];
                        double[] cov_new_sum_m = cov_new_sum[m];
                        cov_new_sum_m[m] += gamma_eti * diff_m * diff_m;
                        for (int n = 0; n < m; ++n) {
                            cov_new_sum_m[n] += gamma_eti * diff_m * diff[n];
                            assert (!Double.isNaN(cov_new_sum_m[n]));
                            assert (!Double.isInfinite(cov_new_sum_m[n]));
                        }
                    }
                }
                ++e;
            }

            // Normalize and mirror the lower triangle into a symmetric matrix.
            float[][] cov_new = new float[SIZE][SIZE];
            for (int m = 0; m < SIZE; ++m) {
                for (int n = 0; n <= m; ++n) {
                    float f = (float) (cov_new_sum[m][n] / sum_sum_gamma_eti);
                    cov_new[n][m] = f;
                    cov_new[m][n] = f;
                }
            }

            try {
                double det = 1.0 / LinAlg.det(cov_new);
                if (Double.isNaN(det)) {
                    throw new IllegalArgumentException("determinant of inverse covariance matrix is NaN");
                }
                if (det > (double) Float.MAX_VALUE) {
                    throw new IllegalArgumentException("determinant of inverse covariance matrix is too large");
                }
                if (det < 0.0) {
                    throw new IllegalArgumentException("determinant of inverse covariance matrix must be positive");
                }
                float[][] cov_new_inv = LinAlg.inv(cov_new);
                dist_new[i] = new GaussianDistributionFull(mean_new, cov_new_inv, (float) det);
            }
            catch (IllegalArgumentException ex) {
                assert (Debug.println(System.err, "Oops: " + ex.getMessage() + ". " + "Taking old covariance matrix instead."));
                if (dist[i] instanceof GaussianDistributionDiagonal) {
                    GaussianDistributionDiagonal gdg = (GaussianDistributionDiagonal) dist[i];
                    dist_new[i] = new GaussianDistributionDiagonal(mean_new, gdg.getVarianceInverse());
                } else {
                    dist_new[i] = new GaussianDistributionFull(mean_new, dist[i].getCovarianceInverse(), dist[i].getDeterminant());
                }
            }
        }
        return dist_new;
    }

    /**
     * Re-estimates the transition matrix from xi, taking the number of states
     * from the first sequence that has at least one xi entry.
     *
     * @param xi indexed {@code [sequence][time][from][to]}
     * @return row-normalized transition matrix
     * @throws ArrayIndexOutOfBoundsException if no sequence has any xi entry,
     *         so the number of states cannot be determined
     */
    static float[][] calcTransitions(double[][][][] xi) {
        for (double[][][] xi_e : xi) {
            if (xi_e.length > 0) {
                return HMM.calcTransitions(xi_e[0].length, xi);
            }
        }
        // Same exception type the direct xi[0][0] access used to raise.
        throw new ArrayIndexOutOfBoundsException("xi contains no entries; number of states unknown");
    }

    /**
     * Sums xi over all sequences and time steps and normalizes every row to
     * sum 1. A row whose sum is not positive is replaced by the uniform
     * distribution {@code 1 / N}.
     *
     * @param N  number of states
     * @param xi indexed {@code [sequence][time][from][to]}; sequences without
     *           entries are allowed
     * @return row-normalized {@code N x N} transition matrix
     */
    private static float[][] calcTransitions(int N, double[][][][] xi) {
        double[][] sum_xi_et = new double[N][N];
        for (double[][][] xi_e : xi) {
            for (double[][] xi_et : xi_e) {
                for (int i = 0; i < N; ++i) {
                    double[] sum_xi_eti = sum_xi_et[i];
                    double[] xi_eti = xi_et[i];
                    for (int j = 0; j < N; ++j) {
                        sum_xi_eti[j] += xi_eti[j];
                    }
                }
            }
        }

        float[][] trans_new = new float[N][N];
        for (int i = 0; i < N; ++i) {
            float[] trans_new_i = trans_new[i];
            double trans_new_i_sum = 0.0;
            for (int j = 0; j < N; ++j) {
                trans_new_i[j] = (float) sum_xi_et[i][j];
                // The row sum is taken over the float-rounded values.
                trans_new_i_sum += (double) trans_new_i[j];
            }
            if (trans_new_i_sum > 0.0) {
                for (int j = 0; j < N; ++j) {
                    trans_new_i[j] = (float) ((double) trans_new_i[j] / trans_new_i_sum);
                }
            } else {
                assert (Debug.println(System.err, "Sum of transitions equals zero"));
                Arrays.fill(trans_new_i, 1.0f / (float) N);
            }
        }
        return trans_new;
    }

    /**
     * Returns a copy of the initial state probabilities.
     *
     * @return a new array of length {@code N}, or {@code null} if the model
     *         has no initial probabilities
     */
    public float[] getInit() {
        if (this.init == null) {
            return null;
        }
        float[] tmp = new float[this.N];
        System.arraycopy(this.init, 0, tmp, 0, this.N);
        return tmp;
    }

    /**
     * Returns a deep copy of the transition matrix.
     *
     * @return a new {@code N x N} array
     */
    public float[][] getTransitions() {
        float[][] tmp = new float[this.N][this.N];
        for (int i = 0; i < this.N; ++i) {
            System.arraycopy(this.trans[i], 0, tmp[i], 0, this.N);
        }
        return tmp;
    }

    /**
     * Returns a new array holding the distributions of all states. The array
     * is a copy; the distribution objects themselves are shared.
     */
    public GaussianDistribution[] getDist() {
        GaussianDistribution[] gd = new GaussianDistribution[this.dist.length];
        System.arraycopy(this.dist, 0, gd, 0, this.dist.length);
        return gd;
    }

    /** Returns the distribution of the state with the given index (not a copy). */
    public GaussianDistribution getDist(int index) {
        return this.dist[index];
    }
}

