元理系院生の新入社員がPythonとJavaで色々頑張るブログ

プログラミングや機械学習について調べた事を書いていきます

【Java】スケーリング処理を用いた離散隠れマルコフモデルの実装

前回実装した離散隠れマルコフモデルは、フォワード、バックワードアルゴリズムの確率計算を行う際に、長い時系列を適用するとアンダーフローが発生する問題がありました。emoson.hateblo.jp

スケーリング処理はトレリス計算中に逐次スケーリング処理を施し、アンダーフローの発生を抑制します。


スケーリング処理を適用したHMMのコードは次のようになります。

package emoson.hmm;

public class HMM {
	private int N;			//状態数
	private int L;			//シンボル数
	private double[][] A;	//状態遷移確率
	private double[][][] B;	//出力確率
	private double[] pi;	//初期状態確率

	public HMM(double[][] A, double[][][] B, double[] pi){
		this.N = A.length;
		this.L = B[0][0].length;
		this.A = A;
		this.B = B;
		this.pi = pi;
	}
	
	//スケーリング係数の導出
	private double getC(double[] alpha){
		double sum = 0;
		for(int i=0; i<alpha.length; i++)sum += alpha[i];
		return 1.0 / sum;
	}
	
	private double forwardScaling(int[] o, double[][] alphaPrime, double[] C){
		int T = o.length;
		double[][] alphaStar = new double[T+1][this.N];
		for(int i=0; i<this.N; i++){
			alphaStar[0][i] = this.pi[i];
		}
		C[0] = this.getC(alphaStar[0]);
		for(int i=0; i<this.N; i++){
			alphaPrime[0][i] = alphaStar[0][i] * C[0];
		}

		for(int t=1; t<C.length; t++){
			for(int i=0; i<this.N; i++){
				double sum = 0;
				for(int j=0; j<this.N; j++){
					sum += alphaPrime[t-1][j] * this.A[j][i] * this.B[j][i][o[t-1]];
				}
				alphaStar[t][i] = sum;
			}
			C[t] = this.getC(alphaStar[t]);
			for(int i=0; i<this.N; i++){
				alphaPrime[t][i] = C[t] * alphaStar[t][i];
			}
		}

		double sum = 0;
		for(int t=0; t<C.length; t++){
			sum += Math.log(C[t]);
		}
		return sum;
	}

	private double backwardScaling(int[] o, double[][] betaPrime, double[] C){
		int T = o.length;
		double[][] betaStar = new double[T+1][this.N];
		for(int i=0; i<this.N; i++){
			betaStar[T][i] = 1.0;
			betaPrime[T][i] = C[T] * betaStar[T][i];
		}

		for(int t=T-1; t>=0; t--){
			for(int i=0; i<this.N; i++){
				double sum = 0;
				for(int j=0; j<this.N; j++){
					sum += this.A[i][j]*this.B[i][j][o[t]]*betaPrime[t+1][j];
				}
				betaStar[t][i] = sum;
				betaPrime[t][i] = C[t] * sum;
			}
		}
		double sum = 0;
		for(int t=0; t<T; t++){
			sum += Math.log(C[t]);
		}
		return sum;
	}

	private double getLikelihood(int[] simbol){
		int T = simbol.length;
		double[] C = new double[T+1];
		double[][] alphaPrime = new double[T+1][A.length];
		return this.forwardScaling(simbol, alphaPrime, C);
	}
	
	private double viterbi(int[] o, int[] sStar){
		int T = o.length;
		double[][] delta = new double[T+1][this.N];
		int[][] phi = new int[T+1][this.N];

		for(int i=0; i<this.N; i++){
			delta[0][i] = this.pi[i];
			phi[0][i] = 0;
		}

		for(int t=1; t<T+1; t++){
			for(int j=0; j<this.N; j++){
				double max = delta[t-1][0] * this.A[0][j] * this.B[0][j][o[t-1]];
				int argmax = 0;
				for(int i=1; i<this.N; i++){
					double v = delta[t-1][i] * this.A[i][j] * this.B[i][j][o[t-1]];
					if(v > max){
						max = v;
						argmax = i;
					}
				}
				delta[t][j] = max;
				phi[t][j] = argmax;
			}
		}

		double pStar = delta[T][0];
		double pstarmax = delta[T][0];
		int sstarmax = 0;
		for(int i=0; i<this.N; i++){
			if(delta[T][i] > pstarmax){
				pstarmax = delta[T][i];
				sstarmax = i;
			}
		}
		sStar[T] = sstarmax;
		pStar = pstarmax;

		for(int t=T-1; t>=0; t--){
			sStar[t] = phi[t+1][sStar[t+1]];
		}
		return pStar;
	}

	public double baumWelch(int[] o){
		int T = o.length;
		double[] C = new double[T+1];
		double[][] alphaPrime = new double[T+1][A.length];
		double[][] betaPrime = new double[T+1][A.length];
		//フォワードバックワードアルゴリズム
		this.forwardScaling(o, alphaPrime, C);
		this.backwardScaling(o, betaPrime, C);

		double[] piHat = new double[this.N];
		for(int i=0; i<this.N; i++){
			piHat[i] = alphaPrime[0][i] * betaPrime[0][i];
		}

		double[][] aHat = new double[this.N][this.N];
		double[][][] bHat = new double[this.N][this.N][this.L];
		for(int i=0; i<this.N; i++){
			for(int j=0; j<this.N; j++){
				double sumA=0, sumB=0;
				for(int t=0; t<T; t++){
					sumA += alphaPrime[t][i]*this.A[i][j]*this.B[i][j][o[t]]*betaPrime[t+1][j];
					sumB += alphaPrime[t][i]*betaPrime[t][i] / C[t];
				}
				aHat[i][j] = sumA / sumB;

			}
		}

		for(int i=0; i<this.N; i++){
			for(int j=0; j<this.N; j++){
				for(int v=0; v<this.L; v++){
					double sumA=0, sumB=0;
					for(int t=0; t<T; t++){
						double del = o[t] == v ? 1.0 : 0.0;
						sumA += del * alphaPrime[t][i]*this.A[i][j]*this.B[i][j][o[t]]*betaPrime[t+1][j];
						sumB += alphaPrime[t][i] * this.A[i][j] * this.B[i][j][o[t]]*betaPrime[t+1][j];
					}
					if(sumB != 0.0)bHat[i][j][v] = sumA / sumB;

				}
			}
		}
		this.A = aHat;
		this.B = bHat;
		this.pi = piHat;
		return this.forwardScaling(o, alphaPrime, C);
	}
	
	public static void main(String[] args) {
		//状態遷移確率(Left to Right)
		double[][] A = {
				{0.3, 0.5, 0.2},
				{0.0, 0.4, 0.6},
				{0.0, 0.0, 1.0},
		};
		//シンボル出力確率(Left to Right)
		double[][][] B = {
				{{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}},
				{{0.0, 0.0}, {0.5, 0.5}, {0.5, 0.5}},
				{{0.0, 0.0}, {0.0, 0.0}, {0.5, 0.5}},
		};
		//初期状態確率(Left to Right)
		double[] pi = {1.0, 0.0, 0.0};
		
		//サンプル・データとして性質の反転した2つの時系列を用意
		int[][] o = {
				{0,0,0,0,1,1,1,1,1,1,1},
				{1,1,1,1,1,1,0,0,0,0,0,0}
		};

		//それぞれ学習
		HMM[] hmm = new HMM[2];
		//学習回数
		int EPOCHMAX = 10;
		
		for(int i=0; i<o.length; i++){
			System.out.println("●信号"+i+"を学習します");
			hmm[i] = new HMM(A, B, pi);
			for(int j=0; j<EPOCHMAX; j++)
				System.out.println("E\t"+hmm[i].baumWelch(o[i]));
		}
		
		//テストデータ
		int[][] test = {
				{0,0,0,0,1,1,1,1,1,1,1,1,1,1,1},	//信号0に似せた
				{0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1},	//信号0に似せた
				{1,1,1,0,0,0,},						//信号1に似せた
				{1,1,1,1,1,1,1,1,1,1,1,1,0,0,0},	//信号1に似せた
		};
		for(int i=0; i<test.length; i++){
			double[] li = new double[2];
			li[0] = hmm[0].getLikelihood(test[i]);
			li[1] = hmm[1].getLikelihood(test[i]);
			int n = 0;
			if(li[0] > li[1])n = 1;
			System.out.println("test\t"+i+"\t 識別結果\t"+n+
					"\t HMM0の尤度\t"+li[0]+"\t HMM1の尤度\t"+li[1]);
		}
	}
}

main関数で動作のテストを行っています。
状態数3、シンボル数を2とし2つのモデルに対し2種類の信号を学習させています。
その後、適当に作成した信号をモデルに与えて尤度を算出しました。
実行結果は次のようになります。

test	0	 識別結果	0	 HMM0の尤度	1.6746796462964881	 HMM1の尤度	180.89354940726736
test	1	 識別結果	0	 HMM0の尤度	4.732027840593439	 HMM1の尤度	110.70960962753293
test	2	 識別結果	1	 HMM0の尤度	72.42022891475443	 HMM1の尤度	1.9097022789346096
test	3	 識別結果	1	 HMM0の尤度	72.42022982714094	 HMM1の尤度	3.853832763144159

ちゃんと識別できていますね。
しかし、このままでは一つのモデルに対し一つの信号しか学習できないので、次は大量のデータを学習できるようにHMMを拡張します。