JavaでErgodic離散隠れマルコフモデル
この記事に書かれている事
javaによるErgodicマルコフの実装
Ergodic離散隠れマルコフモデル
前回(と言っても随分昔の話ですが)、Left to Right型の隠れマルコフモデルを実装しましたが、今回はより一般的なErgodic隠れマルコフモデルの実装を行います。
実装する前に隠れマルコフモデルのパラメータとアルゴリズムに付いて少しだけ触れます。
隠れマルコフモデルのパラメータ
隠れマルコフモデルのパラメータは主に状態遷移確率行列、出力確率行列、初期状態確率ベクトルの3つです。
状態遷移確率行列は、"状態Iから状態Jへ遷移する確率"の組み合わせを行列化したものです。
出力確率行列は、"状態IがKを出力する確率"の組み合わせを行列化したものです。
また、初期状態確率はHMMにシンボル系列が与えられた時に、ある状態が初期状態として選択される確率を並べたものです。
IとかJなどの変数名は今適当に考えたので特に意味は無いです
以上のパラメータが与えられた隠れマルコフモデルに対し、長さNのシンボル系列Xを与える事で色々な事が出来ます。
シンボル系列は観測したデータの事です。音声を正規化したものだったり、DNAの塩基配列だったり、対象とする系列データそのものだって考えておけば大丈夫です。
隠れマルコフモデルで出来ること(の一部)
隠れマルコフモデルでは主に次のような事が出来ます。
- シンボル系列の評価
- 状態系列の復号化
- パラメータの推定
シンボル系列の評価
シンボル系列の評価とは、学習し終えた(パラメータが既知の)隠れマルコフモデルに対しシンボル系列を与えた時に、シンボル系列がその隠れマルコフモデルから得られる確率を評価します。
具体的には、ヒトの塩基配列を学習したHMMに対しある塩基配列を与えた時に、その塩基配列がどの程度ヒトに似ているかを評価します。
状態系列の復号化
状態系列の復号化では、学習し終えた(パラメータが既知の)隠れマルコフモデルに対しシンボル系列を与えた時に、隠れマルコフモデルがそのシンボル系列を出力するにあたって、どのような状態遷移を行ったかを推定します。
隠れマルコフモデルは各状態毎に各シンボルの出力確率が異なる為、状態の遷移に応じてシンボル系列の傾向が偏っていると考えられます。その傾向の偏りから状態の遷移を推測します。
具体的には、「こんにちは」と言う発声を一文字一文字(本当は音韻とか様々な事を考慮します)状態化し、学習したLeft to Right型の隠れマルコフモデルに対し、別の人が発話した「こんにちは」と言う音声を与えた時、その音声特徴系列を「こ/ん/に/ち/は」とセパレートする事が出来ます。
こちらも隠れマルコフモデル単体では実現する事が出来ず、動的計画法の1つであるViterbiアルゴリズムを適用しています。
Javaによる実装
数式を貼ろうかと思ったのですが、hatenaのlatexの動作がまだよく分かってないので、その辺りを調べたらいつか追記します。
public class HMM { private int C; //状態数 private int M; //出力シンボル数 private double[][] A; //状態遷移確率行列 c*c private double[][] B; //出力確率行列 c*m private double[] roe; //初期状態確率ベクトル c //状態数と出力シンボル数のみを指定して初期化 private HMM(int stateNum, int outNum){ this.C = stateNum; this.M = outNum; this.A = new double[this.C][this.C]; this.B = new double[this.C][this.M]; this.roe = new double[this.C]; //状態遷移確率の初期化 for(int i=0; i<this.A.length; i++){ for(int j=0; j<this.A[i].length; j++){ this.A[i][j] = 1.0 / (double)this.A[i].length; } } //出力確率の初期化 for(int i=0; i<this.B.length; i++){ for(int j=0; j<this.B[i].length; j++){ this.B[i][j] = 1.0 / (double)this.B[i].length; } } //初期状態確率の初期化 for(int i=0; i<this.roe.length; i++){ this.roe[i] = 1.0 / (double)this.roe.length; } } //初期パラメータを与えて初期化 private HMM(double[][] _A, double[][] _B, double[] _roe){ this.A = _A; this.B = _B; this.roe = _roe; this.C = this.A.length; this.M = this.B[0].length; } //ForwardAlgorithm private double forward(double[][] alpha, int[] x, int[] omega, double[][] a, double[][] b){ //初期化 int N = x.length; for(int i=0; i<this.C; i++){ alpha[0][i] = this.roe[i] * b[omega[i]][x[0]]; } //再帰的計算 for(int t=1; t<N; t++){ for(int j=0; j<this.C; j++){ double sum = 0; for(int i=0; i<this.C; i++){ sum += alpha[t-1][i] * a[i][j]; } alpha[t][j] = sum * b[omega[j]][x[t]]; } } //確率計算 double p = 0; for(int i=0; i<this.C; i++){ p += alpha[N-1][i]; } return p; } //backwardAlgorithm private double backward(double[][] beta, int [] x, int[] omega, double[][] a, double[][] b){ //初期化 int N = x.length; for(int i=0; i<this.C; i++){ beta[N-1][i] = 1.0; } //再帰的計算 for(int t=N-2; t>=0; t--){ for(int i=0; i<this.C; i++){ double sum = 0; for(int j=0; j<this.C; j++){ sum += a[i][j] * b[omega[j]][x[t+1]] * beta[t+1][j]; } beta[t][i] = sum; } } //確率計算 double p = 0; for(int i=0; i<this.C; i++){ p += beta[0][i]; } return p; } //max関数(配列の中から最大値を取得) private double max(double[] x){ double a = x[0]; for(int i=1; i<x.length; i++){ if(a < x[i]){ a = x[i]; } } return a; } //arg max関数(配列の中から最大値となるインデックスを取得) private int argmax(double[] x){ int n = 0; double a = x[0]; for(int i=1; i<x.length; i++){ if(a < x[i]){ n = i; a = x[i]; } } return n; } //ビタビアルゴリズム private void viterbi(double[][] psi, int[][] PSI, int[] I, int[] S_star, int[] x, int[] omega, double[][] a, double[][] b){ //初期化 int N = x.length; for(int i=0; i<this.C; i++){ psi[0][i] = this.roe[i] * b[omega[i]][x[0]]; PSI[0][i] = 0; } //再帰的計算 for(int t=1; t<N; t++){ for(int j=0; j<this.C; j++){ double[] _psia = new double[this.C]; for(int i=0; i<this.C; i++){ _psia[i] = psi[t-1][i] * a[i][j]; } psi[t][j] = max(_psia) * b[omega[j]][x[t]]; PSI[t][j] = argmax(_psia); } } //終了 I[N-1] = argmax(psi[N-1]); S_star[N-1] = omega[I[N-1]]; //復元 for(int t=N-2; t>=0; t--){ I[t] = PSI[t+1][I[t+1]]; S_star[t] = omega[I[t]]; } } //デルタ関数 private double delta(int a, int b){ if(a == b)return 1.0; return 0.0; } //baumwelchAlgorithm private void baumWelch(int[] x, int[] omega){ //初期化 int N = x.length; //観測系列長 double[][] a = this.A; //状態遷移確率行列 double[][] b = this.B; //出力確率行列 double[][] a_hat = new double[this.C][this.C]; //更新後の状態遷移確率行列 double[][] b_hat = new double[N][this.C]; //更新後の出力確率行列 double[] roe_hat = new double[this.C]; //更新後の初期状態確率ベクトル double[][] alpha = new double[N][this.C]; //ForwardAlgorithmで使用するトレリス double[][] beta = new double[N][this.C]; //BackWardAlgorithmで使用するトレリス this.forward(alpha, x, omega, this.A, this.B); this.backward(beta, x, omega, this.A, this.B); //再帰的計算 状態遷移確率 for(int i=0; i<this.C; i++){ for(int j=0; j<this.C; j++){ double sumA=0, sumB=0; for(int t=0; t<N-1; t++){ sumA += alpha[t][i]*a[i][j]*b[omega[j]][x[t+1]]*beta[t+1][j]; sumB += alpha[t][i]*beta[t][i]; } a_hat[i][j] = sumA / sumB; } } //再帰的計算 出力確率 for(int j=0; j<this.C; j++){ for(int k=0; k<this.M; k++){ double sumA=0, sumB=0; for(int t=0; t<N; t++){ sumA += delta(x[t], x[k])*alpha[t][j]*beta[t][j]; sumB += alpha[t][j]*beta[t][j]; } b_hat[j][k] = sumA / sumB; } } //再帰的計算 初期状態確率 for(int i=0; i<this.C; i++){ double sumA = 0; for(int j=0; j<this.C; j++){ sumA += alpha[N-1][j]; } roe_hat[i] = alpha[0][i]*beta[0][i] / sumA; } //パラメータの更新 this.A = a_hat; this.B = b_hat; this.roe = roe_hat; } //ForwardAlgorithmのテスト public static void fowardTest(){ int C = 3; int M = 2; double[][] a = { {0.1, 0.7, 0.2}, {0.2, 0.1, 0.7}, {0.7, 0.2, 0.1}, }; double[][] b = { {0.9, 0.1}, {0.6, 0.4}, {0.1, 0.9}, }; double[] roe = {1.0/3.0, 1.0/3.0, 1.0/3.0}; int[] omega = {0, 1, 2}; HMM hmm = new HMM(a, b, roe); int[] x = {0,1,0}; double[][] alpha = new double[x.length][C]; double p = hmm.forward(alpha, x, omega, hmm.A, hmm.B); System.out.println(p); for(int i=0; i<alpha.length; i++){ for(int j=0; j<alpha[i].length; j++){ System.out.print(alpha[i][j]+"\t"); } System.out.println(); } } //ViterbiAlgorithmのテスト public static void viterbiTest(){ int C = 3; int M = 2; double[][] a = { {0.1, 0.7, 0.2}, {0.2, 0.1, 0.7}, {0.7, 0.2, 0.1}, }; double[][] b = { {0.9, 0.1}, {0.6, 0.4}, {0.1, 0.9}, }; double[] roe = {1.0/3.0, 1.0/3.0, 1.0/3.0}; int[] omega = {0, 1, 2}; HMM hmm = new HMM(a, b, roe); int[] x = {0,1,0}; double[][] psi = new double[x.length][C]; int[][] PSI = new int[x.length][C]; int[] I = new int[x.length]; int[] S_star = new int[x.length]; hmm.viterbi(psi, PSI, I, S_star, x, omega, hmm.A, hmm.B); System.out.println("####\t psi \t###"); for(int i=0; i<x.length; i++){ for(int j=0; j<C; j++){ System.out.print(psi[i][j]+"\t"); } System.out.println(); } System.out.println("####\t PSI \t###"); for(int i=0; i<x.length; i++){ for(int j=0; j<C; j++){ System.out.print(PSI[i][j]+"\t"); } System.out.println(); } System.out.println("####\t S* \t###"); for(int i=0; i<x.length; i++){ System.out.print(S_star[i]+"\t"); } System.out.println(); } //パラメータの可視化 public void show(){ System.out.println("####\t A \t###"); for(int i=0; i<this.C; i++){ for(int j=0; j<C; j++){ System.out.print(this.A[i][j]+"\t"); } System.out.println(); } System.out.println("####\t B \t###"); for(int i=0; i<this.C; i++){ for(int j=0; j<this.M; j++){ System.out.print(this.B[i][j]+"\t"); } System.out.println(); } System.out.println("####\t roe \t###"); for(int i=0; i<this.C; i++){ System.out.print(this.roe[i]+"\t"); } System.out.println(); } }
使い方は動作確認用に追加したtestメソッドを参照してください。
また、baumwelchアルゴリズムは現状ではスケーリング関数や対数確率にしていない為、長いシンボル系列を与えるとアンダーフローが発生すると思います。