scikit-learnのSVMを使ってみた
アドベントカレンダー三日目です。
出かけてて投稿できなかったので同日投稿です。
最近scikit-learnのSVMを使ってみたので、それについて書いていきます
必要なライブラリのimport
今回は、
いつも通りnumpyと、
訓練データとテストデータを分ける train_test_split、
学習に使うデータの load_digits、
あとはSVMを使います。
import numpy as np from sklearn.cross_validation import train_test_split from sklearn.datasets import load_digits from sklearn import svm
データの取得と分割
学習に使うデータを読み込み、次にデータを分割します。
digits.dataでデータが、digits.targetでラベルがとれます。
digits = load_digits() X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
学習と評価
最初は何も考えずにデフォルトのまま学習を行います。
わかりやすくするために、明示的にrbfを指定しています。
clf.fit(x,y)で学習が行われます。
clf = svm.SVC(kernel='rbf')
clf.fit(X_train, y_train)
次に、clf.predict(X_test) のようにすることで、未知のデータに対して分類を行います。
今回は、精度の平均を出すために、labelと比較してmeanを使っています。
np.mean(clf.predict(X_test) == y_test)
>>> 0.51555555555555554
別のカーネルを使う
svm.SVCのカーネルを違うものにして、同じように精度を見てみましょう。
clf = svm.SVC(kernel='linear') clf.fit(X_train, y_train) np.mean(clf.predict(X_test) == y_test) >>> 0.96444444444444444
このように精度が大きく変化しました。
まとめ
SVCにも膨大な引数があるので、使う際にはドキュメントを一度見てみましょう。
sklearn.svm.SVC — scikit-learn 0.21.dev0 documentation
あと、scikit-learnはお手軽ですごい。