元理系院生の新入社員がPythonとJavaで色々頑張るブログ

プログラミングや機械学習について調べた事を書いていきます

【Java】複数の信号を学習する離散隠れマルコフモデル

前回実装したスケーリング処理を用いた離散隠れマルコフモデルは、EMアルゴリズムによって最適化パラメータを算出した際に、即時にHMMのパラメータを更新する為、複数の学習データを用いて学習することが出来ませんでした。emoson.hateblo.jp

そこで今回は、最初に全ての学習データに対して最適化パラメータを算出し、その後、更新パラメータの平均をHMMに適用する様にします。


新たなパラメータ更新法を適用したHMMのコードは次のようになります。

package emoson.hmm;

public class HMM {
	private int N;			//状態数
	private int L;			//シンボル数
	private double[][] A;	//状態遷移確率
	private double[][][] B;	//出力確率
	private double[] pi;	//初期状態確率

	public HMM(double[][] A, double[][][] B, double[] pi){
		this.N = A.length;
		this.L = B[0][0].length;
		this.A = A;
		this.B = B;
		this.pi = pi;
	}

	//スケーリング係数の導出
	private double getC(double[] alpha){
		double sum = 0;
		for(int i=0; i<alpha.length; i++)sum += alpha[i];
		return 1.0 / sum;
	}

	//フォワードアルゴリズム
	private double forwardScaling(int[] o, double[][] alphaPrime, double[] C){
		int T = o.length;
		double[][] alphaStar = new double[T+1][this.N];
		for(int i=0; i<this.N; i++){
			alphaStar[0][i] = this.pi[i];
		}
		C[0] = this.getC(alphaStar[0]);
		for(int i=0; i<this.N; i++){
			alphaPrime[0][i] = alphaStar[0][i] * C[0];
		}

		for(int t=1; t<C.length; t++){
			for(int i=0; i<this.N; i++){
				double sum = 0;
				for(int j=0; j<this.N; j++){
					sum += alphaPrime[t-1][j] * this.A[j][i] * this.B[j][i][o[t-1]];
				}
				alphaStar[t][i] = sum;
			}
			C[t] = this.getC(alphaStar[t]);
			for(int i=0; i<this.N; i++){
				alphaPrime[t][i] = C[t] * alphaStar[t][i];
			}
		}

		double sum = 0;
		for(int t=0; t<C.length; t++){
			sum += Math.log(C[t]);
		}
		return sum;
	}

	//バックワードアルゴリズム
	private double backwardScaling(int[] o, double[][] betaPrime, double[] C){
		int T = o.length;
		double[][] betaStar = new double[T+1][this.N];
		for(int i=0; i<this.N; i++){
			betaStar[T][i] = 1.0;
			betaPrime[T][i] = C[T] * betaStar[T][i];
		}

		for(int t=T-1; t>=0; t--){
			for(int i=0; i<this.N; i++){
				double sum = 0;
				for(int j=0; j<this.N; j++){
					sum += this.A[i][j]*this.B[i][j][o[t]]*betaPrime[t+1][j];
				}
				betaStar[t][i] = sum;
				betaPrime[t][i] = C[t] * sum;
			}
		}
		double sum = 0;
		for(int t=0; t<T; t++){
			sum += Math.log(C[t]);
		}
		return sum;
	}

	//尤度の算出
	public double getLikelihood(int[] simbol){
		int T = simbol.length;
		double[] C = new double[T+1];
		double[][] alphaPrime = new double[T+1][A.length];
		double l = this.forwardScaling(simbol, alphaPrime, C);
		if(Double.isNaN(l))l=Double.POSITIVE_INFINITY;
		return l;
	}

	private double viterbi(int[] o, int[] sStar){
		int T = o.length;
		double[][] delta = new double[T+1][this.N];
		int[][] phi = new int[T+1][this.N];

		for(int i=0; i<this.N; i++){
			delta[0][i] = this.pi[i];
			phi[0][i] = 0;
		}

		for(int t=1; t<T+1; t++){
			for(int j=0; j<this.N; j++){
				double max = delta[t-1][0] * this.A[0][j] * this.B[0][j][o[t-1]];
				int argmax = 0;
				for(int i=1; i<this.N; i++){
					double v = delta[t-1][i] * this.A[i][j] * this.B[i][j][o[t-1]];
					if(v > max){
						max = v;
						argmax = i;
					}
				}
				delta[t][j] = max;
				phi[t][j] = argmax;
			}
		}

		double pStar = delta[T][0];
		double pstarmax = delta[T][0];
		int sstarmax = 0;
		for(int i=0; i<this.N; i++){
			if(delta[T][i] > pstarmax){
				pstarmax = delta[T][i];
				sstarmax = i;
			}
		}
		sStar[T] = sstarmax;
		pStar = pstarmax;

		for(int t=T-1; t>=0; t--){
			sStar[t] = phi[t+1][sStar[t+1]];
		}
		return pStar;
	}

	//複数のデータを学習するbaumWelchアルゴリズム
	private double baumWelchBatch(int[][] o){
		double[] piHat = new double[this.N];
		double[][] aHat = new double[this.N][this.N];
		double[][][] bHat = new double[this.N][this.N][this.L];
		
		//ファイル毎に学習
		for(int f=0; f<o.length; f++){
			int T = o[f].length;
			double[] C = new double[T+1];
			double[][] alphaPrime = new double[T+1][A.length];
			double[][] betaPrime = new double[T+1][A.length];

			//フォワードバックワードアルゴリズム
			this.forwardScaling(o[f], alphaPrime, C);
			this.backwardScaling(o[f], betaPrime, C);

			//初期状態確率の算出
			for(int i=0; i<this.N; i++){
				piHat[i] += alphaPrime[0][i] * betaPrime[0][i];
			}
			
			//状態遷移確率の算出
			for(int i=0; i<this.N; i++){
				for(int j=0; j<this.N; j++){
					double sumA=0, sumB=0;
					for(int t=0; t<T; t++){
						sumA += alphaPrime[t][i]*this.A[i][j]*this.B[i][j][o[f][t]]*betaPrime[t+1][j];
						sumB += alphaPrime[t][i]*betaPrime[t][i] / C[t];
					}
					//加算するようにする
					aHat[i][j] += sumA / sumB;

				}
			}

			//出力確率の算出
			for(int i=0; i<this.N; i++){
				for(int j=0; j<this.N; j++){
					for(int v=0; v<this.L; v++){
						double sumA=0, sumB=0;
						for(int t=0; t<T; t++){
							double del = o[f][t] == v ? 1.0 : 0.0;
							sumA += del * alphaPrime[t][i]*this.A[i][j]*this.B[i][j][o[f][t]]*betaPrime[t+1][j];
							sumB += alphaPrime[t][i] * this.A[i][j] * this.B[i][j][o[f][t]]*betaPrime[t+1][j];
						}
						//加算するようにする
						if(sumB != 0.0)bHat[i][j][v] += sumA / sumB;

					}
				}
			}
			
		}
		
		//状態遷移確率等の更新値の平均を取る
		for(int i=0; i<this.N; i++){
			piHat[i] /= (double)o.length;
			for(int j=0; j<this.N; j++){
				aHat[i][j] /= (double)o.length;
				for(int l=0; l<this.L; l++){
					bHat[i][j][l] /= (double)o.length;
				}
			}
		}
		
		//パラメータの更新
		this.A = aHat;
		this.B = bHat;
		this.pi = piHat;
		
		//平均尤度の算出
		double like = 0;
		for(int f=0; f<o.length; f++){
			like += this.getLikelihood(o[f]);
		}
		return like / (double)o.length;
	}
	
	//HMMの学習
	public static HMM train(int[][] o, int N, int L, int EPOCH_MAX, boolean showLog){
		/*	input:
		 * 		o			学習ベクトル
		 * 		N			状態数
		 * 		L			シンボル数
		 * 		EPOCH_MAX	学習回数
		 * 		showLog		trueならば学習時に尤度を表示
		 * 	
		 * 	output:
		 * 		HMM
		 */
		HMM hmm = new HMM(makeA(N), makeB(N, L), makePi(N));
		for(int epoch=0; epoch<EPOCH_MAX; epoch++){
			double err = hmm.baumWelchBatch(o);
			if(showLog)System.out.println(err);
		}
		return hmm;
	}
	
	//HMMの学習
	public static HMM train(int[][] o, int N, int L, double threshold, boolean showLog){
		/*	input:
		 * 		o			学習ベクトル
		 * 		N			状態数
		 * 		L			シンボル数
		 * 		threshold	学習収束条件
		 * 		showLog		trueならば学習時に尤度を表示
		 * 	
		 * 	output:
		 * 		HMM
		 */
		HMM hmm = new HMM(makeA(N), makeB(N, L), makePi(N));
		double err = Double.MAX_VALUE;
		while(true){
			double _err = hmm.baumWelchBatch(o);
			if(showLog)System.out.println(err);
			if(Math.abs(err-_err) < threshold)break;
			err = _err;
		}
		return hmm;
	}
	
	//状態遷移配列の生成
	private static double[][] makeA(int N){
		double[][] a = new double[N][N];
		for(int i=0; i<N; i++){
			for(int j=0; j<N; j++){
				a[i][j] = 1.0 / (double)(N-i);
			}
		}
		return a;
	}

	//出力確率配列の生成
	private static double[][][] makeB(int N, int L){
		double[][][] b = new double[N][N][L];
		for(int i=0; i<N; i++){
			for(int j=i; j<N; j++){
				for(int l=0; l<L; l++){
					b[i][j][l] = 1.0 /(double)L;
				}
			}
		}
		return b;
	}
	
	//初期状態確率配列の生成
	private static double[] makePi(int N){
		double[] pi = new double[N];
		pi[0] = 1.0;
		return pi;
	}
}

HMMの学習は次のように行います

//学習データ
int[][] o = {
	{0,0,0,0,1,1,1,1,1,2,2,2,2,2,2,2},
	{0,0,1,1,1,1,1,2,2,2,2,2,2},
	{0,0,0,0,0,0,0,1,1,1,2,2,2,2,2,2},
};
//状態数
int N = 3;
//シンボル数
int L = 3;
//学習回数
int EPOCH_MAX = 100;
HMM hmm = HMM.train(o, N, L, EPOCH_MAX, true);

また、新たな信号に対して尤度を求める際には次のようにします

int[] testO = {0,0,0,1,1,1,1,1,1,1,1,1,2,2,2};
double like = hmm.getLikelihood(testO);
System.out.println("尤度\t"+like);

これでアンダフロー等のHMMを実装する際の諸問題を解決できたので、次回はいよいよ手書き数字認識等に適用したいと思います。