【Java】隠れマルコフモデルによる手書き数字認識
はじめに
手書き数字認識をタスクとして隠れマルコフモデル(HMM)の動作確認を行います。手書き数字認識は文字認識の中では単純な識別問題なのですが、タスクの困難性が想像しやすくサンプルデータが用意しやすい利点があります。
隠れマルコフモデル
隠れマルコフモデルは統計的に時系列をモデル化する手法であり、非決定性確率有限オートマトンとして定義できます。HMMを定義するパラメータは状態遷移確率,出力確率、初期状態確率あり、状態遷移確率は状態から状態に遷移する確率をの行列にまとめたものです。また出力確率は状態からに遷移する際に、HMMがあるシンボルを出力する確率です。この際に出力値が連続値の場合連続分布のHMMとなり、シンボル(離散的な記号)を出力する場合離散型のHMMとなります。初期状態確率は字面通りにHMMに時系列を与える際に、状態が選択される確率を次元数のベクトルでまとめたものです。
以下に離散型のHMMのパラメータを示します。
HMMに与える時系列長
HMMに与える時系列
状態数
記号の種類
状態集合
時刻[の状態
状態番号
出力記号の集合
また、状態遷移確率や出力確率、初期状態確率は以下に定義されます。
HMMはこれらのパラメータを用いて次のように表記します。
隠れマルコフモデルで出来る事
隠れマルコフモデルは他のアルゴリズムと組み合わせる事で様々な事が出来ます。その中でもよく適用される問題は次の3つがあります。
- 尤度評価
観測系列とが既知の時、時系列Oがモデルλから出力される確率を求めます。この時は尤度と呼びます。主にHMMを識別問題に使用し、新たな時系列がどのクラスに分類されるか評価する際に使用されます。尤度を算出する際はフォワードアルゴリズムを使用します。
- モデルの学習
既知の学習用の時系列を用いて、未知のモデルのパラメータに対し尤度を最大となるように推定します。先の尤度評価を行う際にはまず尤度が最大となるパラメータを推定する必要があります。パラメータの推定にはEMアルゴリズムの一つであるバウムウェルチアルゴリズムを使用します。
- 状態系列の推定
既知のモデルが時系列を出力する時、モデルの中でどのような状態変化を経たか推定します。推定にはビタビアルゴリズムを使用します。今回はビタビアルゴリズムを使用しない為、説明は次の機会に回します。
Forwardアルゴリズム
フォワードアルゴリズムとはが時系列を生成する尤度を求める問題を効率的に計算する動的計画法の一つです。からに向かって処理していくため前向きアルゴリズムと呼ばれています。動作は次のようになります。
手書き文字認識
手書き数字認識は人の手によって書かれた文字がどの数字であるのか識別するパターン認識問題です。機械学習の学習性能や分類器の性能評価のタスクとしてよく利用されています。手書き文字の特徴ベクトルにはオフライン形式とオンライン形式があります。オフライン形式では手書き文字を画像として扱い、オンライン形式では手書き文字を運筆動作として扱います。具体的な例を出すと、オンライン手書き文字を取得する際にはペンタブレット等を利用し、筆記の際に移動するペンの座標値等を時系列データとして取得します。オフライン形式はオンライン形式に比べ既に紙などに筆記された文字を特徴として使用できる利点があり、オンライン形式の特徴量はオフラインでは得られない書き順等を特徴量として利用できます。
今回はHMMの動作確認のタスクとして文字認識を使用する為、時系列特徴量であるオンライン形式の特徴量を利用します。
手書き文字の収集
手書き文字の学習、テストデータは以下の記事に載せたGUIプログラムで収集しました。emoson.hateblo.jp
本当はたくさんの被験者を募り様々な運筆情報を収集したほうが、統計モデルとしての強みが活かされるのですが、今回はぼく一人で行いました。
筆記は0から9までの数字をそれぞれ40回筆記し、学習データとテストデータを完全に分離するオープンテストを行いました。また、サンプルデータが非常に少ない為、識別率を求める際には交差確認法でデータを分割して識別率の平均を求めました。
実験プログラム
使用したプログラムを以下に掲載します。また、同じパッケージにHMMクラスがあります。
HMMクラスは以前実装したものを使用しました。emoson.hateblo.jp
package emoson.hmm; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.util.ArrayList; import java.util.List; public class Recognition { //筆跡情報の量子化 移動角度に対しL分割する public static int[] getSinbol(double[][] vec, int T, int L){ //移動角度についてL分割に量子化する double[][] sepVec = new double[T+1][2]; for(int t=0; t<T+1; t++){ double sumX=0, sumY=0; for(int _t=0; _t<(int)(vec.length/(T+1)); _t++){ sumX += vec[t*(int)(vec.length/(T+1))+_t][0]; sumY += vec[t*(int)(vec.length/(T+1))+_t][1]; } sepVec[t][0] = sumX/vec.length/(T+1); sepVec[t][1] = sumY/vec.length/(T+1); } int[] simbol = new int[T]; for(int t=0; t<T; t++){ double rad = Math.atan((sepVec[t+1][1]-sepVec[t][1]) /(sepVec[t+1][0]-sepVec[t][0])); simbol[t] = (int)(rad / (Math.PI/(L/2))); if(simbol[t] < 0)simbol[t] +=L; } return simbol; } //交差確認用のデータを作成 public static int[][][] makeTrainData(int[][] data, int groupCount, int Num){ int L = data.length; int bSize = L / groupCount; List<int[]> testDatas = new ArrayList<int[]>(); List<int[]> trainDatas = new ArrayList<int[]>(); for(int i=0; i<L; i++){ if(Num*bSize > i || (Num+1)*bSize <= i){ testDatas.add(data[i]); } else{ trainDatas.add(data[i]); } } int[][][] Datas = new int[2][][]; Datas[0] = new int[trainDatas.size()][]; Datas[1] = new int[testDatas.size()][]; for(int i=0; i<Datas[0].length; i++)Datas[0][i] = trainDatas.get(i); for(int i=0; i<Datas[1].length; i++)Datas[1][i] = testDatas.get(i); return Datas; } //筆跡データの読み込み public static int[][] getWritingData(String dirName, int T, int L){ File dir = new File("./"+dirName); int[][] o = new int[dir.listFiles().length][]; for(int f=0; f<o.length; f++){ File file = dir.listFiles()[f]; FileInputStream fis; try { fis = new FileInputStream(file); byte[] b = new byte[(int) file.length()]; fis.read(b); fis.close(); String[] line = new String(b).split("\n"); double[][] vec = new double[line.length-1][2]; for(int i=0; i<vec.length; i++){ vec[i][0] = Double.valueOf(line[i].split(",")[0]); vec[i][1] = Double.valueOf(line[i].split(",")[1]); } o[f] = getSinbol(vec, T, L); } catch (FileNotFoundException e) { // TODO 自動生成された catch ブロック e.printStackTrace(); } catch (IOException e) { // TODO 自動生成された catch ブロック e.printStackTrace(); } } return o; } public static void main(String[] args) { //int[][][] data = makeTrainData(getWritingData("./number/"+1, 20, 16), 5, 4); int[][][][] data = new int[10][2][][]; int N = 10; //状態数 int T = 20; //時系列長 int L = 16; //シンボル数 double threshold = 1E-01; //収束条件(小さいな値ほど多く学習する) int crossBatchGroup = 3; //交差グループ数 double[] rate = new double[crossBatchGroup]; for(int currentGroup=0; currentGroup<crossBatchGroup; currentGroup++){ HMM[] hmm = new HMM[10]; //字種毎にHMMを学習 for(int num=0; num<hmm.length; num++){ //System.out.println("字種"+num+"を学習中"); data[num] = makeTrainData(getWritingData("./number/"+num, T, L), crossBatchGroup, currentGroup); hmm[num] = HMM.train(data[num][0], N, L, threshold, false); } double correctAll = 0; //正解数 int dataCount = 0; //テストデータ数 //字種毎に識別 for(int num=0; num<10; num++){ int correct = 0; //この字種に対する正解数 //字種numの全てのテストデータに対して識別 for(int f=0; f<data[num][1].length; f++){ int rec = 0; //識別結果 double like = hmm[0].getLikelihood(data[num][1][f]); //各種HMMに対し尤度を算出 for(int _num=1; _num<10; _num++){ double _like = hmm[_num].getLikelihood(data[num][1][f]); if(_like < like){ like = _like; rec = _num; } } //trueなら正解 if(rec == num)correct++; } correctAll += correct; dataCount += data[num][1].length; //System.out.println("字種"+num+"の正解率\t"+((double)correct/(double)data[num][1].length)); } //識別率の算出 //System.out.println("正解率\t"+((double)correctAll/(double)dataCount)*100+"%"); rate[currentGroup] = (double)correctAll/(double)dataCount; } double aveRate = 0; //識別率 for(int i=0; i<rate.length; i++){ aveRate += rate[0]; } System.out.println("識別率\t"+(aveRate/(double)rate.length)); } }
実行すると次のような結果が得られました
識別率 0.8481481481481481
この結果は手書き数字認識においてはあまり芳しくない結果ですが、概ね字種毎にモデル化ができている事が分かります。
また、モデル数や出力シンボル数等のパラメータを操作し、字種毎の識別率や不正解時の文字と誤って推定された文字などを調べてみると、いろいろな特徴を見ることが出来て面白いですが、今回はHMMの動作確認が目的なので考察はまた次の機会にまとめたいと思います。