元理系院生の新入社員がPythonとJavaで色々頑張るブログ

プログラミングや機械学習について調べた事を書いていきます

JavaでErgodic離散隠れマルコフモデル

この記事に書かれている事

javaによるErgodicマルコフの実装

Ergodic離散隠れマルコフモデル

前回(と言っても随分昔の話ですが)、Left to Right型の隠れマルコフモデルを実装しましたが、今回はより一般的なErgodic隠れマルコフモデルの実装を行います。

実装する前に隠れマルコフモデルのパラメータとアルゴリズムに付いて少しだけ触れます。

隠れマルコフモデルのパラメータ

隠れマルコフモデルのパラメータは主に状態遷移確率行列、出力確率行列、初期状態確率ベクトルの3つです。
状態遷移確率行列は、"状態Iから状態Jへ遷移する確率"の組み合わせを行列化したものです。
出力確率行列は、"状態IがKを出力する確率"の組み合わせを行列化したものです。
また、初期状態確率はHMMにシンボル系列が与えられた時に、ある状態が初期状態として選択される確率を並べたものです。
IとかJなどの変数名は今適当に考えたので特に意味は無いです

以上のパラメータが与えられた隠れマルコフモデルに対し、長さNのシンボル系列Xを与える事で色々な事が出来ます。
シンボル系列は観測したデータの事です。音声を正規化したものだったり、DNAの塩基配列だったり、対象とする系列データそのものだって考えておけば大丈夫です。

隠れマルコフモデルで出来ること(の一部)

隠れマルコフモデルでは主に次のような事が出来ます。

  1. シンボル系列の評価
  2. 状態系列の復号化
  3. パラメータの推定
シンボル系列の評価

シンボル系列の評価とは、学習し終えた(パラメータが既知の)隠れマルコフモデルに対しシンボル系列を与えた時に、シンボル系列がその隠れマルコフモデルから得られる確率を評価します。

具体的には、ヒトの塩基配列を学習したHMMに対しある塩基配列を与えた時に、その塩基配列がどの程度ヒトに似ているかを評価します。

この評価は隠れマルコフモデルに対して、Forwardアルゴリズムを適用することで実現できます。

状態系列の復号化

状態系列の復号化では、学習し終えた(パラメータが既知の)隠れマルコフモデルに対しシンボル系列を与えた時に、隠れマルコフモデルがそのシンボル系列を出力するにあたって、どのような状態遷移を行ったかを推定します。

隠れマルコフモデルは各状態毎に各シンボルの出力確率が異なる為、状態の遷移に応じてシンボル系列の傾向が偏っていると考えられます。その傾向の偏りから状態の遷移を推測します。

具体的には、「こんにちは」と言う発声を一文字一文字(本当は音韻とか様々な事を考慮します)状態化し、学習したLeft to Right型の隠れマルコフモデルに対し、別の人が発話した「こんにちは」と言う音声を与えた時、その音声特徴系列を「こ/ん/に/ち/は」とセパレートする事が出来ます。

こちらも隠れマルコフモデル単体では実現する事が出来ず、動的計画法の1つであるViterbiアルゴリズムを適用しています。

パラメータの推定

先の2つの"出来ること"では隠れマルコフモデルのパラメータが既知である前提がありました。
パラメータの推定ではシンボル系列から隠れマルコフモデルのパラメータを学習することが出来ます。

隠れマルコフモデルに大量の時系列データを与えて学習する~と言った処理は、大抵これに当たります。

この処理も例に漏れず、EMアルゴリズムの1つであるbaumwelchアルゴリズムを用いて実現しています。

Javaによる実装

数式を貼ろうかと思ったのですが、hatenaのlatexの動作がまだよく分かってないので、その辺りを調べたらいつか追記します。

public class HMM {
	private int C;				//状態数
	private int M;				//出力シンボル数
	private double[][] A;		//状態遷移確率行列	c*c
	private double[][] B;		//出力確率行列	c*m
	private double[] roe;		//初期状態確率ベクトル	c

	//状態数と出力シンボル数のみを指定して初期化
	private HMM(int stateNum, int outNum){
		this.C = stateNum;
		this.M = outNum;
		this.A = new double[this.C][this.C];
		this.B = new double[this.C][this.M];
		this.roe = new double[this.C];

		//状態遷移確率の初期化
		for(int i=0; i<this.A.length; i++){
			for(int j=0; j<this.A[i].length; j++){
				this.A[i][j] = 1.0 / (double)this.A[i].length;
			}
		}
		//出力確率の初期化
		for(int i=0; i<this.B.length; i++){
			for(int j=0; j<this.B[i].length; j++){
				this.B[i][j] = 1.0 / (double)this.B[i].length;
			}
		}
		//初期状態確率の初期化
		for(int i=0; i<this.roe.length; i++){
			this.roe[i] = 1.0 / (double)this.roe.length;
		}
	}

	//初期パラメータを与えて初期化
	private HMM(double[][] _A, double[][] _B, double[] _roe){
		this.A = _A;
		this.B = _B;
		this.roe = _roe;
		this.C = this.A.length;
		this.M = this.B[0].length;
	}

	//ForwardAlgorithm
	private double forward(double[][] alpha, int[] x, int[] omega, double[][] a, double[][] b){
		//初期化
		int N = x.length;
		for(int i=0; i<this.C; i++){
			alpha[0][i] = this.roe[i] * b[omega[i]][x[0]];
		}

		//再帰的計算
		for(int t=1; t<N; t++){
			for(int j=0; j<this.C; j++){
				double sum = 0;
				for(int i=0; i<this.C; i++){
					sum += alpha[t-1][i] * a[i][j];
				}
				alpha[t][j] = sum * b[omega[j]][x[t]];
			}
		}

		//確率計算
		double p = 0;
		for(int i=0; i<this.C; i++){
			p += alpha[N-1][i];
		}
		return p;
	}

	//backwardAlgorithm
	private double backward(double[][] beta, int [] x, int[] omega, double[][] a, double[][] b){
		//初期化
		int N = x.length;
		for(int i=0; i<this.C; i++){
			beta[N-1][i] = 1.0;
		}

		//再帰的計算
		for(int t=N-2; t>=0; t--){
			for(int i=0; i<this.C; i++){
				double sum = 0;
				for(int j=0; j<this.C; j++){
					sum += a[i][j] * b[omega[j]][x[t+1]] * beta[t+1][j];
				}
				beta[t][i] = sum;
			}
		}

		//確率計算
		double p = 0;
		for(int i=0; i<this.C; i++){
			p += beta[0][i];
		}
		return p;
	}

	//max関数(配列の中から最大値を取得)
	private double max(double[] x){
		double a = x[0];
		for(int i=1; i<x.length; i++){
			if(a < x[i]){
				a = x[i];
			}
		}
		return a;
	}

	//arg max関数(配列の中から最大値となるインデックスを取得)
	private int argmax(double[] x){
		int n = 0;
		double a = x[0];
		for(int i=1; i<x.length; i++){
			if(a < x[i]){
				n = i;
				a = x[i];
			}
		}
		return n;
	}

	//ビタビアルゴリズム
	private void viterbi(double[][] psi, int[][] PSI, int[] I, int[] S_star, int[] x, int[] omega, double[][] a, double[][] b){
		//初期化
		int N = x.length;
		for(int i=0; i<this.C; i++){
			psi[0][i] = this.roe[i] * b[omega[i]][x[0]];
			PSI[0][i] = 0;
		}

		//再帰的計算
		for(int t=1; t<N; t++){
			for(int j=0; j<this.C; j++){
				double[] _psia = new double[this.C];
				for(int i=0; i<this.C; i++){
					_psia[i] = psi[t-1][i] * a[i][j];
				}
				psi[t][j] = max(_psia) * b[omega[j]][x[t]];
				PSI[t][j] = argmax(_psia);
			}
		}

		//終了
		I[N-1] = argmax(psi[N-1]);
		S_star[N-1] = omega[I[N-1]];

		//復元
		for(int t=N-2; t>=0; t--){
			I[t] = PSI[t+1][I[t+1]];
			S_star[t] = omega[I[t]];
		}
	}

	//デルタ関数
	private double delta(int a, int b){
		if(a == b)return 1.0;
		return 0.0;
	}

	//baumwelchAlgorithm
	private void baumWelch(int[] x, int[] omega){
		//初期化
		int N = x.length;								//観測系列長
		double[][] a = this.A;							//状態遷移確率行列
		double[][] b = this.B;							//出力確率行列
		double[][] a_hat = new double[this.C][this.C];	//更新後の状態遷移確率行列
		double[][] b_hat = new double[N][this.C];		//更新後の出力確率行列
		double[] roe_hat = new double[this.C];			//更新後の初期状態確率ベクトル

		double[][] alpha = new double[N][this.C];		//ForwardAlgorithmで使用するトレリス
		double[][] beta = new double[N][this.C];		//BackWardAlgorithmで使用するトレリス

		this.forward(alpha, x, omega, this.A, this.B);
		this.backward(beta, x, omega, this.A, this.B);

		//再帰的計算		状態遷移確率
		for(int i=0; i<this.C; i++){
			for(int j=0; j<this.C; j++){
				double sumA=0, sumB=0;
				for(int t=0; t<N-1; t++){
					sumA += alpha[t][i]*a[i][j]*b[omega[j]][x[t+1]]*beta[t+1][j];
					sumB += alpha[t][i]*beta[t][i];
				}
				a_hat[i][j] = sumA / sumB;
			}
		}
		//再帰的計算		出力確率
		for(int j=0; j<this.C; j++){
			for(int k=0; k<this.M; k++){
				double sumA=0, sumB=0;
				for(int t=0; t<N; t++){
					sumA += delta(x[t], x[k])*alpha[t][j]*beta[t][j];
					sumB += alpha[t][j]*beta[t][j];
				}
				b_hat[j][k] = sumA / sumB;
			}
		}
		//再帰的計算		初期状態確率
		for(int i=0; i<this.C; i++){
			double sumA = 0;
			for(int j=0; j<this.C; j++){
				sumA += alpha[N-1][j];
			}
			roe_hat[i] = alpha[0][i]*beta[0][i] / sumA;
		}

		//パラメータの更新
		this.A = a_hat;
		this.B = b_hat;
		this.roe = roe_hat;
	}

	//ForwardAlgorithmのテスト
	public static void fowardTest(){
		int C = 3;
		int M = 2;

		double[][] a = {
				{0.1, 0.7, 0.2},
				{0.2, 0.1, 0.7},
				{0.7, 0.2, 0.1},
		};

		double[][] b = {
				{0.9, 0.1},
				{0.6, 0.4},
				{0.1, 0.9},
		};
		double[] roe = {1.0/3.0, 1.0/3.0, 1.0/3.0};
		int[] omega = {0, 1, 2};

		HMM hmm = new HMM(a, b, roe);
		int[] x = {0,1,0};
		double[][] alpha = new double[x.length][C];
		double p = hmm.forward(alpha, x, omega, hmm.A, hmm.B);
		System.out.println(p);
		for(int i=0; i<alpha.length; i++){
			for(int j=0; j<alpha[i].length; j++){
				System.out.print(alpha[i][j]+"\t");
			}
			System.out.println();
		}
	}

	//ViterbiAlgorithmのテスト
	public static void viterbiTest(){
		int C = 3;
		int M = 2;

		double[][] a = {
				{0.1, 0.7, 0.2},
				{0.2, 0.1, 0.7},
				{0.7, 0.2, 0.1},
		};

		double[][] b = {
				{0.9, 0.1},
				{0.6, 0.4},
				{0.1, 0.9},
		};
		double[] roe = {1.0/3.0, 1.0/3.0, 1.0/3.0};
		int[] omega = {0, 1, 2};

		HMM hmm = new HMM(a, b, roe);
		int[] x = {0,1,0};

		double[][] psi = new double[x.length][C];
		int[][] PSI = new int[x.length][C];
		int[] I = new int[x.length];
		int[] S_star = new int[x.length];

		hmm.viterbi(psi, PSI, I, S_star, x, omega, hmm.A, hmm.B);
		System.out.println("####\t psi \t###");
		for(int i=0; i<x.length; i++){
			for(int j=0; j<C; j++){
				System.out.print(psi[i][j]+"\t");
			}
			System.out.println();
		}

		System.out.println("####\t PSI \t###");
		for(int i=0; i<x.length; i++){
			for(int j=0; j<C; j++){
				System.out.print(PSI[i][j]+"\t");
			}
			System.out.println();
		}

		System.out.println("####\t S* \t###");
		for(int i=0; i<x.length; i++){
			System.out.print(S_star[i]+"\t");
		}
		System.out.println();

	}

	//パラメータの可視化
	public void show(){
		System.out.println("####\t A \t###");
		for(int i=0; i<this.C; i++){
			for(int j=0; j<C; j++){
				System.out.print(this.A[i][j]+"\t");
			}
			System.out.println();
		}

		System.out.println("####\t B \t###");
		for(int i=0; i<this.C; i++){
			for(int j=0; j<this.M; j++){
				System.out.print(this.B[i][j]+"\t");
			}
			System.out.println();
		}

		System.out.println("####\t roe \t###");
		for(int i=0; i<this.C; i++){
			System.out.print(this.roe[i]+"\t");
		}
		System.out.println();
	}
}

使い方は動作確認用に追加したtestメソッドを参照してください。
また、baumwelchアルゴリズムは現状ではスケーリング関数や対数確率にしていない為、長いシンボル系列を与えるとアンダーフローが発生すると思います。