深層学習とその他

機械学習したいマン

scikit-learnのconfusion matrixとclassification reportについて

アドベントカレンダー四日目です。
昨日に引き続き、scikit-learnの話をしていきましょう。

confusion matrix

必要なライブラリのimport

前回同様numpyと、
訓練データとテストデータを分ける train_test_split、
学習に使うデータの load_digits、
SVM
新たにconfusion matrixとclassification reportを読み込みます。

import numpy as np
from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_digits
from sklearn import svm
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report


データの取得と分割

同じ流れですが、学習に使うデータを読み込み、次にデータを分割します。
digits.dataでデータが、digits.targetでラベルがとれます。

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)


学習と評価

とりあえず学習します。

clf = svm.SVC(kernel='linear')
clf.fit(X_train, y_train)

そして、未知データを分類します。

predict_label = clf.predict(X_test)


confusion matrixの表示

confusion matrixを表示します。
これを見ることで、正しいラベルにどれくらい分類されたかが確認できます。
こちらのサイトが参考になるので見てみてください。
混同行列(Confusion Matrix)

print(confusion_matrix(predict_label, y_test))

[[40  0  0  0  0  0  1  0  0  0]
 [ 0 39  0  0  0  0  0  0  2  0]
 [ 0  0 47  0  0  0  0  0  0  0]
 [ 0  0  0 45  0  0  0  0  0  0]
 [ 0  0  0  0 44  0  0  2  0  0]
 [ 0  0  0  1  0 42  0  0  0  0]
 [ 0  0  0  0  0  0 43  0  0  0]
 [ 0  0  0  0  0  0  0 42  0  0]
 [ 0  3  0  1  0  0  0  0 53  2]
 [ 0  0  0  1  0  1  0  1  1 39]]


classification report

今度はclassification reportをみてみましょう。
予測結果の評価に使います。
F値 - 機械学習の「朱鷺の杜Wiki」

print(classification_report(predict_label, y_test))

             precision    recall  f1-score   support

          0       1.00      0.98      0.99        41
          1       0.93      0.95      0.94        41
          2       1.00      1.00      1.00        47
          3       0.94      1.00      0.97        45
          4       1.00      0.96      0.98        46
          5       0.98      0.98      0.98        43
          6       0.98      1.00      0.99        43
          7       0.93      1.00      0.97        42
          8       0.95      0.90      0.92        59
          9       0.95      0.91      0.93        43

avg / total       0.96      0.96      0.96       450


まとめ

いい感じの関数がscikit-learnだと一瞬で使えてすごいですね。
今度はこれを図で描画したいですね。