【Java】複数の信号を学習する離散隠れマルコフモデル
前回実装したスケーリング処理を用いた離散隠れマルコフモデルは、EMアルゴリズムによって最適化パラメータを算出した際に、即時にHMMのパラメータを更新する為、複数の学習データを用いて学習することが出来ませんでした。emoson.hateblo.jp
そこで今回は、最初に全ての学習データに対して最適化パラメータを算出し、その後、更新パラメータの平均をHMMに適用する様にします。
新たなパラメータ更新法を適用したHMMのコードは次のようになります。
package emoson.hmm; public class HMM { private int N; //状態数 private int L; //シンボル数 private double[][] A; //状態遷移確率 private double[][][] B; //出力確率 private double[] pi; //初期状態確率 public HMM(double[][] A, double[][][] B, double[] pi){ this.N = A.length; this.L = B[0][0].length; this.A = A; this.B = B; this.pi = pi; } //スケーリング係数の導出 private double getC(double[] alpha){ double sum = 0; for(int i=0; i<alpha.length; i++)sum += alpha[i]; return 1.0 / sum; } //フォワードアルゴリズム private double forwardScaling(int[] o, double[][] alphaPrime, double[] C){ int T = o.length; double[][] alphaStar = new double[T+1][this.N]; for(int i=0; i<this.N; i++){ alphaStar[0][i] = this.pi[i]; } C[0] = this.getC(alphaStar[0]); for(int i=0; i<this.N; i++){ alphaPrime[0][i] = alphaStar[0][i] * C[0]; } for(int t=1; t<C.length; t++){ for(int i=0; i<this.N; i++){ double sum = 0; for(int j=0; j<this.N; j++){ sum += alphaPrime[t-1][j] * this.A[j][i] * this.B[j][i][o[t-1]]; } alphaStar[t][i] = sum; } C[t] = this.getC(alphaStar[t]); for(int i=0; i<this.N; i++){ alphaPrime[t][i] = C[t] * alphaStar[t][i]; } } double sum = 0; for(int t=0; t<C.length; t++){ sum += Math.log(C[t]); } return sum; } //バックワードアルゴリズム private double backwardScaling(int[] o, double[][] betaPrime, double[] C){ int T = o.length; double[][] betaStar = new double[T+1][this.N]; for(int i=0; i<this.N; i++){ betaStar[T][i] = 1.0; betaPrime[T][i] = C[T] * betaStar[T][i]; } for(int t=T-1; t>=0; t--){ for(int i=0; i<this.N; i++){ double sum = 0; for(int j=0; j<this.N; j++){ sum += this.A[i][j]*this.B[i][j][o[t]]*betaPrime[t+1][j]; } betaStar[t][i] = sum; betaPrime[t][i] = C[t] * sum; } } double sum = 0; for(int t=0; t<T; t++){ sum += Math.log(C[t]); } return sum; } //尤度の算出 public double getLikelihood(int[] simbol){ int T = simbol.length; double[] C = new double[T+1]; double[][] alphaPrime = new double[T+1][A.length]; double l = this.forwardScaling(simbol, alphaPrime, C); if(Double.isNaN(l))l=Double.POSITIVE_INFINITY; return l; } private double viterbi(int[] o, int[] sStar){ int T = o.length; double[][] delta = new double[T+1][this.N]; int[][] phi = new int[T+1][this.N]; for(int i=0; i<this.N; i++){ delta[0][i] = this.pi[i]; phi[0][i] = 0; } for(int t=1; t<T+1; t++){ for(int j=0; j<this.N; j++){ double max = delta[t-1][0] * this.A[0][j] * this.B[0][j][o[t-1]]; int argmax = 0; for(int i=1; i<this.N; i++){ double v = delta[t-1][i] * this.A[i][j] * this.B[i][j][o[t-1]]; if(v > max){ max = v; argmax = i; } } delta[t][j] = max; phi[t][j] = argmax; } } double pStar = delta[T][0]; double pstarmax = delta[T][0]; int sstarmax = 0; for(int i=0; i<this.N; i++){ if(delta[T][i] > pstarmax){ pstarmax = delta[T][i]; sstarmax = i; } } sStar[T] = sstarmax; pStar = pstarmax; for(int t=T-1; t>=0; t--){ sStar[t] = phi[t+1][sStar[t+1]]; } return pStar; } //複数のデータを学習するbaumWelchアルゴリズム private double baumWelchBatch(int[][] o){ double[] piHat = new double[this.N]; double[][] aHat = new double[this.N][this.N]; double[][][] bHat = new double[this.N][this.N][this.L]; //ファイル毎に学習 for(int f=0; f<o.length; f++){ int T = o[f].length; double[] C = new double[T+1]; double[][] alphaPrime = new double[T+1][A.length]; double[][] betaPrime = new double[T+1][A.length]; //フォワードバックワードアルゴリズム this.forwardScaling(o[f], alphaPrime, C); this.backwardScaling(o[f], betaPrime, C); //初期状態確率の算出 for(int i=0; i<this.N; i++){ piHat[i] += alphaPrime[0][i] * betaPrime[0][i]; } //状態遷移確率の算出 for(int i=0; i<this.N; i++){ for(int j=0; j<this.N; j++){ double sumA=0, sumB=0; for(int t=0; t<T; t++){ sumA += alphaPrime[t][i]*this.A[i][j]*this.B[i][j][o[f][t]]*betaPrime[t+1][j]; sumB += alphaPrime[t][i]*betaPrime[t][i] / C[t]; } //加算するようにする aHat[i][j] += sumA / sumB; } } //出力確率の算出 for(int i=0; i<this.N; i++){ for(int j=0; j<this.N; j++){ for(int v=0; v<this.L; v++){ double sumA=0, sumB=0; for(int t=0; t<T; t++){ double del = o[f][t] == v ? 1.0 : 0.0; sumA += del * alphaPrime[t][i]*this.A[i][j]*this.B[i][j][o[f][t]]*betaPrime[t+1][j]; sumB += alphaPrime[t][i] * this.A[i][j] * this.B[i][j][o[f][t]]*betaPrime[t+1][j]; } //加算するようにする if(sumB != 0.0)bHat[i][j][v] += sumA / sumB; } } } } //状態遷移確率等の更新値の平均を取る for(int i=0; i<this.N; i++){ piHat[i] /= (double)o.length; for(int j=0; j<this.N; j++){ aHat[i][j] /= (double)o.length; for(int l=0; l<this.L; l++){ bHat[i][j][l] /= (double)o.length; } } } //パラメータの更新 this.A = aHat; this.B = bHat; this.pi = piHat; //平均尤度の算出 double like = 0; for(int f=0; f<o.length; f++){ like += this.getLikelihood(o[f]); } return like / (double)o.length; } //HMMの学習 public static HMM train(int[][] o, int N, int L, int EPOCH_MAX, boolean showLog){ /* input: * o 学習ベクトル * N 状態数 * L シンボル数 * EPOCH_MAX 学習回数 * showLog trueならば学習時に尤度を表示 * * output: * HMM */ HMM hmm = new HMM(makeA(N), makeB(N, L), makePi(N)); for(int epoch=0; epoch<EPOCH_MAX; epoch++){ double err = hmm.baumWelchBatch(o); if(showLog)System.out.println(err); } return hmm; } //HMMの学習 public static HMM train(int[][] o, int N, int L, double threshold, boolean showLog){ /* input: * o 学習ベクトル * N 状態数 * L シンボル数 * threshold 学習収束条件 * showLog trueならば学習時に尤度を表示 * * output: * HMM */ HMM hmm = new HMM(makeA(N), makeB(N, L), makePi(N)); double err = Double.MAX_VALUE; while(true){ double _err = hmm.baumWelchBatch(o); if(showLog)System.out.println(err); if(Math.abs(err-_err) < threshold)break; err = _err; } return hmm; } //状態遷移配列の生成 private static double[][] makeA(int N){ double[][] a = new double[N][N]; for(int i=0; i<N; i++){ for(int j=0; j<N; j++){ a[i][j] = 1.0 / (double)(N-i); } } return a; } //出力確率配列の生成 private static double[][][] makeB(int N, int L){ double[][][] b = new double[N][N][L]; for(int i=0; i<N; i++){ for(int j=i; j<N; j++){ for(int l=0; l<L; l++){ b[i][j][l] = 1.0 /(double)L; } } } return b; } //初期状態確率配列の生成 private static double[] makePi(int N){ double[] pi = new double[N]; pi[0] = 1.0; return pi; } }
HMMの学習は次のように行います
//学習データ int[][] o = { {0,0,0,0,1,1,1,1,1,2,2,2,2,2,2,2}, {0,0,1,1,1,1,1,2,2,2,2,2,2}, {0,0,0,0,0,0,0,1,1,1,2,2,2,2,2,2}, }; //状態数 int N = 3; //シンボル数 int L = 3; //学習回数 int EPOCH_MAX = 100; HMM hmm = HMM.train(o, N, L, EPOCH_MAX, true);
また、新たな信号に対して尤度を求める際には次のようにします
int[] testO = {0,0,0,1,1,1,1,1,1,1,1,1,2,2,2}; double like = hmm.getLikelihood(testO); System.out.println("尤度\t"+like);
これでアンダフロー等のHMMを実装する際の諸問題を解決できたので、次回はいよいよ手書き数字認識等に適用したいと思います。