世に機械学習,パターン認識に関する書式は多くありますが,ユーザ操作やセンシングデータを元に リアルタイムにパターン認識を行うために必要な参考書があまりないので,ここに基礎的な内容を実際に プログラミングしてみることで対話設計において必要な入力処理の基礎知識を身につけるのが目的です. 実装が容易なkNN,単純ベイズを利用して同じ学習データからリアルタイムに入力データの 識別をしてみます.このページと合わせて別ページのインタラクションデザインの為のweka入門もお薦めします.
https://docs.google.com/spreadsheets/d/11eyes2zlBXisAPCbIg95_LazhKp5fpzvkSUyaCBw7nI/edit?usp=sharingまずは下準備をします.以後OFを利用してコードを記述していきます.下に示すような学習データを準備し, 画面上にプロット(点をうつ)し,マウス位置も同様にプロットします.
データ番号 | クラス | 特徴量1 \(x_k\) | 特徴量2 \(y_k\) |
---|---|---|---|
k | 性別 | 身長[cm] | 体重[kg] |
1 | M | 164 | 60 |
2 | M | 178 | 80 |
3 | M | 168 | 69 |
4 | M | 170 | 58 |
5 | M | 165 | 68 |
6 | F | 160 | 47 |
7 | F | 155 | 45 |
8 | F | 164 | 60 |
9 | F | 170 | 62 |
10 | F | 148 | 40 |
ofApp.h
#pragma once #include "ofMain.h" class ofApp : public ofBaseApp{ public: void setup(); void update(); void draw(); };
ofApp.cpp
#include "ofApp.h" int data[10][3] = { 'M', 164, 60, 'M', 178, 80, 'M', 168, 69, 'M', 170, 58, 'M', 165, 68, 'F', 160, 47, 'F', 155, 45, 'F', 164, 60, 'F', 170, 62, 'F', 148, 40, }; //-------------------------------------------------------------- void ofApp::setup(){ ofBackground(100, 100, 100); ofHideCursor(); } //-------------------------------------------------------------- void ofApp::update(){ } void ofApp::draw(){ ofSetColor(0); for( int i = 0; i < 10; i++ ){ if( data[i][0] == 'F' ){ ofSetColor(255,0,0); } else{ ofSetColor(0,0,255); } ofCircle(data[i][1], data[i][2], 1.0); } ofDrawBitmapString("("+ ofToString(ofGetMouseX())+ ","+ ofToString(ofGetMouseY())+")", 20,20); ofCircle(ofGetMouseX(), ofGetMouseY(), 1.0); }
馬場が適当に入れたデータです.二次元の特徴量を持つ学習データになります.パターン認識では性別をクラス(ラベル), 身長,体重をそれぞれ特徴量と呼びます.main.cppの箇所で画面サイズを300x300程度にして実行してみましょう. すると,現在のマウス位置を緑色,男性データを青,女性データを赤でプロットするものになります. ここで,マウス位置を入力データとして考えてください.つまりx座標が身長,y座標が体重のデータになるわけです. マウスを動かすことで入力データの特徴量が変化します.この時に既存の学習データから,入力データが 男性なのか女性なのかを判断するプログラムをこれから記述していきます.学習データ自体はあまり面白みがありませんが, 解りやすい学習データということで勘弁ください.
まずは得られた学習データと入力データをそのまま比べる手法です.学習データにいろいろ手を加えず, そのまま入力データと比較するものなので機械側でなにか学習しているわけではありません.入力データに 対して「こういう事例が近くにあるから,入力データはこういうものなんでない?」という流れになります. まずはk-NN(k-Nearest Neighbor)を理解,コーディングしてみます. 実際の近さには,2点間の距離を計算すれば良いだけです.この計算方法にはいくつかの手法がありますが, 馴染みのあるユークリッド距離を計算してみます.今回は2次元ですが,一般に次元をnとした場合, のように計算式を示すことができます.入力データと学習データのすべての距離を計算し最も距離が近いものを 第1候補,2番めに距離が近いものを第2候補等とし,k=1のときは第1候補のみ,k=3のときは第3候補まで調べ, もっとも近い3つのデータから多数決をとる,といった手法になります. 今回はもっともシンプルにk=1の時の計算を行ってみます.マウス操作を行い,もっとも近い学習データが 男性の場合は青,女性の場合は赤で表示しなおしてみましょう.今回の事例では計算式は \( d_k = \sqrt[]{(x-x_k)^2+(y-y_k)^2} \) のように表せます.\(x,y\)はマウス座標,\(x_k, y_k\)は学習データのそれぞれ身長体重です. またkは1〜nまでの学習データの番号を指します.
ofApp.h
#pragma once #include "ofMain.h" class ofApp : public ofBaseApp{ public: void setup(); void update(); void draw(); int num_candidate; };
ofApp.cpp
#include "ofApp.h" int data[10][3] = { 'M', 164, 60, 'M', 178, 80, 'M', 168, 69, 'M', 170, 58, 'M', 165, 68, 'F', 160, 47, 'F', 155, 45, 'F', 164, 60, 'F', 170, 62, 'F', 148, 40, }; void ofApp::setup(){ ofBackground(100, 100, 100); ofHideCursor(); } void ofApp::update(){ int min = 10000; float d; for(int i = 0; i < 10; i++ ){ d = sqrt( (data[i][1]-ofGetMouseX())*(data[i][1]-ofGetMouseX())+ (data[i][2]-ofGetMouseY())*(data[i][2]-ofGetMouseY()) ); if( d < min ){ min = d; num_candidate = i; } } } void ofApp::draw(){ ofSetColor(0); for( int i = 0; i < 10; i++ ){ if( data[i][0] == 'F' ){ ofSetColor(255,0,0); } else{ ofSetColor(0,0,255); } ofCircle(data[i][1], data[i][2], 1.0); } if( data[num_candidate][0] == 'M' ){ ofSetColor(0, 0, 255); } else{ ofSetColor(255, 0, 0); } ofDrawBitmapString("("+ ofToString(ofGetMouseX())+ ","+ ofToString(ofGetMouseY())+")", 20,20); ofCircle(ofGetMouseX(), ofGetMouseY(), 1.0); }
kNNの場合は得られた学習データにいろいろ手を加えることなく,そのまま入力データとの近さを計算している だけでした.このようなやり方以外に,入力データを識別するためには一般的な手法として統計的アプローチが あります.つまり各特徴量は統計的な分布に従うと仮定し,入力データの起こりうる確率を計算することで, 識別を行うやり方です.ここではナイーブベイズ(単純ベイズ)と呼ばれる,特徴量が正規分布に従う前提での パターン識別事例を実装してみます.
すこし確率の話になります.ここで求めようとしているのは,入力データがあった時にそれが男性であるか, 女性であるかを確率的に数値にすることです.これができれば生起確率の高い方を識別結果とすればよいこと になります.この場合,学習データを元に分布関数(確率密度関数)を作成し,入力データはその分布において どの程度の生起確率を持つのかを計算すればよさそうです.平均値や分散等を学習データから計算することは 出来ますが,これらを利用して入力データがどの程度の可能性で有り得る値なのかを考えたい場合, 一般的には正規分布等を用いることで計算ができます.正規分布による確率密度関数は $ p(x) = \dfrac {1} {\sqrt {2\pi }\sigma }exp(\dfrac{(x-\mu)^2}{2\sigma^2}) $ と表すことができるので,この $p(x)$を利用して,入力データの各クラス毎の生起確率を計算すればOKです. なお $\mu, \sigma$はそれぞれ平均値と標準偏差となります.
まずは各クラス毎の平均と標準偏差を計算します.つまり,各クラスとは男性または女性のことで, それぞれの平均身長,体重,及び標準偏差を計算することになります.実際に計算してみます. $\mu_M, \sigma_M$を男性の学習データ平均体重と標準偏差.$\mu_F, \sigma_F$を女性の学習データ 平均体重と標準偏差とします.平均値と標準偏差は次の式の通りです. \[ \mu = \dfrac{1}{N}\sum_{k=1}^n x_i \\ \sigma = \sqrt{ \dfrac{1}{N}\sum_{k=1}^n (x_i-\mu)^2} \]
以上から実際に各偏差値,標準偏差を計算すると下記のとおりになります. エクセルで実際に計算してみましょう.
データ番号 | クラス | 特徴量1 \(x_k\) | 特徴量2 \(y_k\) |
---|---|---|---|
k | 性別 | 身長[cm] | 体重[kg] |
1 | M | 164 | 60 |
2 | M | 178 | 80 |
3 | M | 168 | 69 |
4 | M | 170 | 58 |
5 | M | 165 | 68 |
平均 | 169 | 67 | |
標準偏差 | 4.98 | 7.80 | |
6 | F | 160 | 47 |
7 | F | 155 | 45 |
8 | F | 164 | 60 |
9 | F | 170 | 62 |
10 | F | 148 | 40 |
平均 | 159.4 | 50.8 | |
標準偏差 | 7.53 | 8.66 |
事象A,Bに関する同時確率の定義は次のとおりです. $ P(A,B) = P(A)P(B|A)=P(B)P(A|B) $ つまり,事象Aが起きた後,事象Bが起きる確率は $ P(A|B)=\dfrac{P(B|A)P(A)}{P(B)} $ で表すことができます.いま入力データはマウス座標,学習データは男性or女性に クラスがわかれています.座標入力が来た際,それが男性である確率は,実際に学習 データ内にまったく同じ入力データがあれば直接計算できますが,現実的にまったく 同じデータが必ず学習データ内にあるとは限りません.そこでこのベイズの定理を 利用することで,事象B(学習データ入力)が起きた結果それが事象A(男性or女性)で ある確率を求めることができます.また実際にはP(B)は男性,女性に入力に応じて 一定なので比較する際には無視することができます.つまり, $ P(Male| (x,y)), P(Female| (x,y)) $ のどちらが大きいかを,計算できるようになります.例えばMaleは下記のように計算できます. \[ P(Male|(x,y)) = P(Male)P((x,y)|Male)) = P(Male)P(x|Male)P(y||Male) \] $ P(Male)=5/10 $ であり,正規分布による確率密度関数を利用すると \[ P(x|Male) = \dfrac {1} {\sqrt {2\pi }\sigma_{M身長}}exp(\dfrac{(x-\mu_{M身長})^2}{2\sigma_{M身長}^2}) \\ P(y|Male) = \dfrac {1} {\sqrt {2\pi }\sigma_{M体重}}exp(\dfrac{(x-\mu_{M体重})^2}{2\sigma_{M体重}^2}) \] にて計算できます.以上でナイーブベイズによる識別器を実装する準備ができたので,実際に プログラムを記述してみます.
ofApp.h
#pragma once #include "ofMain.h" class ofApp : public ofBaseApp{ public: void setup(); void update(); void draw(); float u_MT; float u_MW; float u_FT; float u_FW; float s_MT; float s_MW; float s_FT; float s_FW; float p_M; float p_F; };
ofApp.cpp
#include "ofApp.h" int data[10][3] = { 'M', 164, 60, 'M', 178, 80, 'M', 168, 69, 'M', 170, 58, 'M', 165, 68, 'F', 160, 47, 'F', 155, 45, 'F', 164, 60, 'F', 170, 62, 'F', 148, 40, }; void ofApp::setup(){ ofBackground(100, 100, 100); ofHideCursor(); u_MT = u_MW = u_FT = u_FW = 0; for( int i = 0; i < 5; i++ ){ u_MT = u_MT + data[i][1]; u_MW = u_MW + data[i][2]; } u_MT = u_MT/5; u_MW = u_MW/5; for( int i = 0; i < 5; i++ ){ s_MT = s_MT + (data[i][1]-u_MT)*(data[i][1]-u_MT); s_MW = s_MW + (data[i][2]-u_MW)*(data[i][2]-u_MW); } s_MT = sqrt(s_MT/5); s_MW = sqrt(s_MW/5); for( int i = 5; i < 10; i++ ){ u_FT = u_FT + data[i][1]; u_FW = u_FW + data[i][2]; } u_FT = u_FT/5; u_FW = u_FW/5; for( int i = 5; i < 10; i++ ){ s_FT = s_FT + (data[i][1]-u_FT)*(data[i][1]-u_FT); s_FW = s_FW + (data[i][2]-u_FW)*(data[i][2]-u_FW); } s_FT = sqrt(s_FT/5); s_FW = sqrt(s_FW/5); } void ofApp::update(){ p_M = (5/10.0)* ( (1/(sqrt(2*3.14)*s_MT)) *exp(-1*(pow((ofGetMouseX()-u_MT),2))/(2.0*s_MT*s_MT)))* ( (1/(sqrt(2*3.14)*s_MW)) *exp(-1*(pow((ofGetMouseY()-u_MW),2))/(2.0*s_MW*s_MW))); p_F = (5/10.0)* (1/(sqrt(2*3.14)*s_FT)*exp(-1*(pow(ofGetMouseX()-u_FT,2))/(2*s_FT*s_FT)))* (1/(sqrt(2*3.14)*s_FW)*exp(-1*(pow(ofGetMouseY()-u_FW,2))/(2*s_FW*s_FW))); } void ofApp::draw(){ ofSetColor(0); for( int i = 0; i < 10; i++ ){ if( data[i][0] == 'F' ){ ofSetColor(255,0,0); } else{ ofSetColor(0,0,255); } ofCircle(data[i][1], data[i][2], 1.0); } if( p_M > p_F ){ ofSetColor(0,0,255); } else if( p_M < p_F){ ofSetColor(255,0,0); } else{ ofSetColor(0,255,0); } ofDrawBitmapString("("+ ofToString(ofGetMouseX())+ ","+ ofToString(ofGetMouseY())+")", 20,20); ofDrawBitmapString("u_MT="+ofToString(u_MT), 20,200); ofDrawBitmapString("u_MW="+ofToString(u_MW), 150,200); ofDrawBitmapString("s_MT="+ofToString(s_MT), 20,220); ofDrawBitmapString("s_MW="+ofToString(s_MW), 150,220); ofDrawBitmapString("u_FT="+ofToString(u_FT), 20,240); ofDrawBitmapString("u_FW="+ofToString(u_FW), 150,240); ofDrawBitmapString("s_FT="+ofToString(s_FT), 20,260); ofDrawBitmapString("s_FW="+ofToString(s_FW), 150,260); ofDrawBitmapString("P_M="+ofToString((double)p_M), 20,280); ofDrawBitmapString("P_F="+ofToString((double)p_F), 150,280); ofCircle(ofGetMouseX(), ofGetMouseY(), 1.0); }
いかがでしたでしょうか?k-NNは非常に直感的なアルゴリズムである一方,データ量が増大する程 処理に時間がかかります.またナイーブベイズではデータが増大した場合も平均値と標準偏差を計算するだけなので, 処理時間は変わることはないでしょう.ただし,統計データに基づくため,少し変わった特徴量をもつ同クラスのデータが 埋もれてしまう欠点があります. この他にもSVMやニューラルネットワーク,HMM等のアルゴリズムも有りますが,このページでは実装が容易なk-NNと 単純ベイズに関して説明しました.興味をもった人は機械学習に関する書籍は 最初に紹介した教科書以外にもたくさんあるので,いろいろ読んでみるとよいか と思います.