深層学習とその他

機械学習したいマン

scikit-learnのSVMを使ってみた

アドベントカレンダー三日目です。
出かけてて投稿できなかったので同日投稿です。
最近scikit-learnのSVMを使ってみたので、それについて書いていきます

必要なライブラリのimport

今回は、
いつも通りnumpyと、
訓練データとテストデータを分ける train_test_split、
学習に使うデータの load_digits、
あとはSVMを使います。

import numpy as np
from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_digits
from sklearn import svm


データの取得と分割

学習に使うデータを読み込み、次にデータを分割します。
digits.dataでデータが、digits.targetでラベルがとれます。

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)


学習と評価

最初は何も考えずにデフォルトのまま学習を行います。
わかりやすくするために、明示的にrbfを指定しています。
clf.fit(x,y)で学習が行われます。

clf = svm.SVC(kernel='rbf')
clf.fit(X_train, y_train)

次に、clf.predict(X_test) のようにすることで、未知のデータに対して分類を行います。
今回は、精度の平均を出すために、labelと比較してmeanを使っています。

np.mean(clf.predict(X_test) == y_test)
>>> 0.51555555555555554


別のカーネルを使う

svm.SVCカーネルを違うものにして、同じように精度を見てみましょう。

clf = svm.SVC(kernel='linear')
clf.fit(X_train, y_train)

np.mean(clf.predict(X_test) == y_test)
>>> 0.96444444444444444

このように精度が大きく変化しました。

まとめ

SVCにも膨大な引数があるので、使う際にはドキュメントを一度見てみましょう。
sklearn.svm.SVC — scikit-learn 0.21.dev0 documentation

あと、scikit-learnはお手軽ですごい。