元理系院生の新入社員がPythonとJavaで色々頑張るブログ

プログラミングや機械学習について調べた事を書いていきます

【Java】隠れマルコフモデルによる手書き数字認識

はじめに

 手書き数字認識をタスクとして隠れマルコフモデル(HMM)の動作確認を行います。手書き数字認識は文字認識の中では単純な識別問題なのですが、タスクの困難性が想像しやすくサンプルデータが用意しやすい利点があります。

隠れマルコフモデル

 隠れマルコフモデルは統計的に時系列をモデル化する手法であり、非決定性確率有限オートマトンとして定義できます。HMMを定義するパラメータは状態遷移確率,出力確率、初期状態確率あり、状態遷移確率は状態iから状態jに遷移する確率を N \times N(i,j=1,\dots,N)の行列にまとめたものです。また出力確率は状態iからjに遷移する際に、HMMがあるシンボルを出力する確率です。この際に出力値が連続値の場合連続分布のHMMとなり、シンボル(離散的な記号)を出力する場合離散型のHMMとなります。初期状態確率は字面通りにHMMに時系列を与える際に、状態iが選択される確率を次元数Nのベクトルでまとめたものです。
 以下に離散型のHMMのパラメータを示します。



T HMMに与える時系列長
o_{1},\dots,o_{T} HMMに与える時系列
N 状態数
L 記号の種類
S=\{s\} 状態集合
s_{t} 時刻[tの状態
i,j 状態番号
v=\{v_{1},\dots,v_{L}\} 出力記号の集合

また、状態遷移確率Aや出力確率B、初期状態確率\piは以下に定義されます。
A=\{a_{i,j} | a_{i,j}=P(s_{t+1}=j|s_{t}=i)\}  (i,j=1,\dots,N)
B=\{b_{i,j}(o_{t})=P(o_{t}|s_{t-1}=i,s_{t}=j) \} (i,j=1,\dots,N)
\pi=\{ \pi_{i}| \pi_{i}=P(s_{0}=i) \} (i=1,\dots,N)

HMMはこれらのパラメータを用いて次のように表記します。
 \lambda = (A, B, \pi)

隠れマルコフモデルで出来る事

隠れマルコフモデルは他のアルゴリズムと組み合わせる事で様々な事が出来ます。その中でもよく適用される問題は次の3つがあります。

  • 尤度評価

観測系列 O \lambdaが既知の時、時系列Oがモデルλから出力される確率 P(O|\lambda)を求めます。この時 P(O|\lambda)は尤度と呼びます。主にHMMを識別問題に使用し、新たな時系列がどのクラスに分類されるか評価する際に使用されます。尤度を算出する際はフォワードアルゴリズムを使用します。

  • モデルの学習

既知の学習用の時系列Oを用いて、未知のモデル \lambdaのパラメータ\pi, A, Bに対し尤度を最大となるように推定します。先の尤度評価を行う際にはまず尤度が最大となるパラメータを推定する必要があります。パラメータの推定にはEMアルゴリズムの一つであるバウムウェルチアルゴリズムを使用します。

  • 状態系列の推定

既知のモデル \lambdaが時系列 Oを出力する時、モデルの中でどのような状態変化を経たか推定します。推定にはビタビアルゴリズムを使用します。今回はビタビアルゴリズムを使用しない為、説明は次の機会に回します。

Forwardアルゴリズム

フォワードアルゴリズムとは \lambdaが時系列 Oを生成する尤度を求める問題を効率的に計算する動的計画法の一つです。 t=0からt=Tに向かって処理していくため前向きアルゴリズムと呼ばれています。動作は次のようになります。

  1.  \alpha_{0}(j)=\pi_{j} (j=1,\dots,N)
  2.  \alpha_{t}(j)= \sum_{i=1}^{N} \alpha_{t-1}a_{i,j}b_{i,j}(o_{t})(j=1,\dots,N, t=1,\dots,T)
  3.  P(O|\lambda)=\sum_{i=1}^{N} \alpha_{T}(i)

Backwardアルゴリズム

バックワードアルゴリズムはフォワードアルゴリズムとは逆にt=Tからt=1にむかって 逆向きに計算していきます。動作は次のようになります。

  1.  \beta_{T}(i)=1.0 (i=1,\dots,N)
  2.  \beta_{t}(i)=\sum_{j=1}^{N} a_{i,j}b_{i,j}(o_{t})\beta_{t+1}(j) (i=1,\dots,N , t=T-1,\dots,0)
  3.  P(O|\lambda)=\sum_{i=1}^{N} \pi_{i}\beta_{0}(i)

Forward, Backwardアルゴリズムは確率の積算が続く為、時系列が長くなると値が0に丸まってしまいます。その為、実装する際には対数確率を用いたりスケーリング処理を施します。

BaumWelchアルゴリズム

バウムウェルチアルゴリズムは学習時系列の生成確率がなるべく大きくなるパラメータを推定します。動作は次のようになります。

  1.  \xi_{t}(i,j)=P(s_{t-1}=i, s_{t}=j|O, \lambda) = \frac{\alpha_{t-1}(i)a_{i,j}b_{i,j}(o_{t}) \beta_{t}(j)}{P(O|\lambda)} (t=1,\dots,T)
  2.  \gamma_{t}(j)=P(s_{t}=j|O, \lambda) =\sum_{i=1}^{N} \xi_{t}(i,j) (t=1,\dots,T)
  3.  \hat{\pi_{i}}=\gamma_{0}(i)(i=1,\dots,N)
  4.  \hat{a_{i,j}}=\frac{\sum_{t=1}^{T} \xi_{t}(i,j)}{\sum_{t=1}^{T} \gamma_{t-1}(i)}
  5.  \hat{b}_{i,j}(o_{t})= \frac{\sum_{t=1}^{T} \delta(o_{t}, v_{k})\xi_{t}(i, j)}{\sum_{t=1}^{T} \xi_{t}(i, j)}

出力確率の推定時に分子に出てくる \deltaデルタ関数です。

手書き文字認識

 手書き数字認識は人の手によって書かれた文字がどの数字であるのか識別するパターン認識問題です。機械学習の学習性能や分類器の性能評価のタスクとしてよく利用されています。手書き文字の特徴ベクトルにはオフライン形式とオンライン形式があります。オフライン形式では手書き文字を画像として扱い、オンライン形式では手書き文字を運筆動作として扱います。具体的な例を出すと、オンライン手書き文字を取得する際にはペンタブレット等を利用し、筆記の際に移動するペンの座標値等を時系列データとして取得します。オフライン形式はオンライン形式に比べ既に紙などに筆記された文字を特徴として使用できる利点があり、オンライン形式の特徴量はオフラインでは得られない書き順等を特徴量として利用できます。
今回はHMMの動作確認のタスクとして文字認識を使用する為、時系列特徴量であるオンライン形式の特徴量を利用します。

手書き文字の収集

手書き文字の学習、テストデータは以下の記事に載せたGUIプログラムで収集しました。emoson.hateblo.jp

本当はたくさんの被験者を募り様々な運筆情報を収集したほうが、統計モデルとしての強みが活かされるのですが、今回はぼく一人で行いました。
筆記は0から9までの数字をそれぞれ40回筆記し、学習データとテストデータを完全に分離するオープンテストを行いました。また、サンプルデータが非常に少ない為、識別率を求める際には交差確認法でデータを分割して識別率の平均を求めました。

学習、認証方法

手書き文字は時間毎にマウスの座標値が連続値としてサンプリングされています。前回作成したHMMは離散型のHMMであった為、マウスの移動角を16分割して量子化を行いました。
また、文字毎に合計10個の \lambda_{n}=(A_{n}, B_{n}, \pi_{n}) (n=0,\dots,9)を学習しました。
識別クラスは対象の時系列 Oに対し、 arg\max_{n} \lambda_{n}としました。

実験プログラム

使用したプログラムを以下に掲載します。また、同じパッケージにHMMクラスがあります。
HMMクラスは以前実装したものを使用しました。emoson.hateblo.jp

package emoson.hmm;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class Recognition {
	//筆跡情報の量子化 移動角度に対しL分割する
	public static int[] getSinbol(double[][] vec, int T, int L){
		//移動角度についてL分割に量子化する
		double[][] sepVec = new double[T+1][2];
		for(int t=0; t<T+1; t++){
			double sumX=0, sumY=0;
			for(int _t=0; _t<(int)(vec.length/(T+1)); _t++){
				sumX += vec[t*(int)(vec.length/(T+1))+_t][0];
				sumY += vec[t*(int)(vec.length/(T+1))+_t][1];
			}
			sepVec[t][0] = sumX/vec.length/(T+1);
			sepVec[t][1] = sumY/vec.length/(T+1);
		}

		int[] simbol = new int[T];
		for(int t=0; t<T; t++){
			double rad = Math.atan((sepVec[t+1][1]-sepVec[t][1]) /(sepVec[t+1][0]-sepVec[t][0]));
			simbol[t] = (int)(rad / (Math.PI/(L/2)));
			if(simbol[t] < 0)simbol[t] +=L;
		}
		return simbol;
	}
	
	//交差確認用のデータを作成
	public static int[][][] makeTrainData(int[][] data, int groupCount, int Num){
		int L = data.length;
		int bSize = L / groupCount;

		List<int[]> testDatas = new ArrayList<int[]>();
		List<int[]> trainDatas = new ArrayList<int[]>();
		for(int i=0; i<L; i++){
			if(Num*bSize > i || (Num+1)*bSize <= i){
				testDatas.add(data[i]);
			}
			else{
				trainDatas.add(data[i]);
			}
		}

		int[][][] Datas = new int[2][][];
		Datas[0] = new int[trainDatas.size()][];
		Datas[1] = new int[testDatas.size()][];
		for(int i=0; i<Datas[0].length; i++)Datas[0][i] = trainDatas.get(i);
		for(int i=0; i<Datas[1].length; i++)Datas[1][i] = testDatas.get(i);

		return Datas;
	}
	//筆跡データの読み込み
	public static int[][] getWritingData(String dirName, int T, int L){
		File dir = new File("./"+dirName);
		int[][] o = new int[dir.listFiles().length][];

		for(int f=0; f<o.length; f++){
			File file = dir.listFiles()[f];
			FileInputStream fis;
			try {
				fis = new FileInputStream(file);
				byte[] b = new byte[(int) file.length()];
				fis.read(b);
				fis.close();
				String[] line = new String(b).split("\n");
				double[][] vec = new double[line.length-1][2];
				for(int i=0; i<vec.length; i++){
					vec[i][0] = Double.valueOf(line[i].split(",")[0]);
					vec[i][1] = Double.valueOf(line[i].split(",")[1]);
				}
				o[f] = getSinbol(vec, T, L);

			} catch (FileNotFoundException e) {
				// TODO 自動生成された catch ブロック
				e.printStackTrace();
			} catch (IOException e) {
				// TODO 自動生成された catch ブロック
				e.printStackTrace();
			}
		}
		return o;
	}

	public static void main(String[] args) {
		//int[][][] data = makeTrainData(getWritingData("./number/"+1, 20, 16), 5, 4);
		int[][][][] data = new int[10][2][][];
		int N = 10;							//状態数
		int T = 20;							//時系列長
		int L = 16;							//シンボル数
		double threshold = 1E-01;			//収束条件(小さいな値ほど多く学習する)
		int crossBatchGroup = 3;			//交差グループ数
		double[] rate = new double[crossBatchGroup];	
		
		for(int currentGroup=0; currentGroup<crossBatchGroup; currentGroup++){
			HMM[] hmm = new HMM[10];
			//字種毎にHMMを学習
			for(int num=0; num<hmm.length; num++){
				//System.out.println("字種"+num+"を学習中");
				data[num] = makeTrainData(getWritingData("./number/"+num, T, L), crossBatchGroup, currentGroup);
				hmm[num] = HMM.train(data[num][0], N, L, threshold, false);		
			}

			double correctAll = 0;				//正解数
			int dataCount = 0;					//テストデータ数

			//字種毎に識別
			for(int num=0; num<10; num++){
				int correct = 0;				//この字種に対する正解数

				//字種numの全てのテストデータに対して識別
				for(int f=0; f<data[num][1].length; f++){
					int rec = 0;				//識別結果
					double like = hmm[0].getLikelihood(data[num][1][f]);
					//各種HMMに対し尤度を算出
					for(int _num=1; _num<10; _num++){
						double _like = hmm[_num].getLikelihood(data[num][1][f]);
						if(_like < like){
							like = _like;
							rec = _num;
						}
					}
					//trueなら正解
					if(rec == num)correct++;
				}
				correctAll += correct;
				dataCount += data[num][1].length;
				//System.out.println("字種"+num+"の正解率\t"+((double)correct/(double)data[num][1].length));
			}
			//識別率の算出
			//System.out.println("正解率\t"+((double)correctAll/(double)dataCount)*100+"%");
			rate[currentGroup] = (double)correctAll/(double)dataCount;
		}
		double aveRate = 0;			//識別率
		for(int i=0; i<rate.length; i++){
			aveRate += rate[0];
		}
		System.out.println("識別率\t"+(aveRate/(double)rate.length));
	}
}


実行すると次のような結果が得られました

識別率	0.8481481481481481

この結果は手書き数字認識においてはあまり芳しくない結果ですが、概ね字種毎にモデル化ができている事が分かります。
また、モデル数や出力シンボル数等のパラメータを操作し、字種毎の識別率や不正解時の文字と誤って推定された文字などを調べてみると、いろいろな特徴を見ることが出来て面白いですが、今回はHMMの動作確認が目的なので考察はまた次の機会にまとめたいと思います。