この文書の現在のバージョンと選択したバージョンの差分を表示します。
両方とも前のリビジョン 前のリビジョン 次のリビジョン | 前のリビジョン | ||
4_モデル学習_keras [2017/11/09] adash333 [1. モデルの設定] |
4_モデル学習_keras [2018/10/07] (現在) |
||
---|---|---|---|
ライン 14: | ライン 14: | ||
<code> | <code> | ||
#4 モデル学習(Keras) | #4 モデル学習(Keras) | ||
- | history = model.fit(X_train, y_train, | + | history = model.fit(x_train, y_train, |
batch_size=batch_size, epochs=epochs, | batch_size=batch_size, epochs=epochs, | ||
- | verbose=1, validation_data=(X_test, y_test)) | + | verbose=1, validation_data=(x_test, y_test)) |
</code> | </code> | ||
- | 作成中 | ||
===== 開発環境 ===== | ===== 開発環境 ===== | ||
ライン 39: | ライン 38: | ||
{{:pasted:20171106-063759.png}} | {{:pasted:20171106-063759.png}} | ||
- | ==== 1. モデルの設定 ==== | + | ==== 1. モデルの学習 ==== |
以下のコードを入力して、Shift + Enterを押します。 | 以下のコードを入力して、Shift + Enterを押します。 | ||
ライン 58: | ライン 57: | ||
- | 順に解説していきます。 | + | model.fit()関数により、モデルの学習を実行しています。 |
+ | 引数については、KerasのDocumentationそのままとなりますが、以下に記載します。 | ||
- | KerasでのModel学習の手順は上記でおしまいです。 | + | ---- |
- | <wrap hi> | + | x: 入力データ,Numpy 配列,あるいは Numpy 配列のリスト (モデルに複数の入力がある場合)\\ |
- | 初めての場合は、次は、とりあえず、</wrap>[[(4)モデル学習(Keras)]]<wrap hi>に進んでください。 | + | y: ラベル,Numpy 配列. |
- | </wrap> | + | |
+ | batch_size: 整数.設定したサンプル数ごとに勾配の更新を行います。今回は、<wrap hi>[[(3)モデル設定(Keras)]]</wrap>のところで、batch_size = 128と設定していましたので、128が用いられています。 | ||
+ | |||
+ | epochs: 整数で,モデルを訓練するエポック数。今回は、<wrap hi>[[(3)モデル設定(Keras)]]</wrap>のところで、epochs = 3と設定していましたので、3回学習が行われています。 | ||
+ | |||
+ | verbose: 0とすると標準出力にログを出力しません. 1の場合はログをプログレスバーで標準出力,2 の場合はエポックごとに1行のログを出力します | ||
+ | |||
+ | validation_data=(x_test, y_test): ホールドアウト検証用データとして使うデータのタプル (x_val, y_val) か (x_val, y_val, val_sample_weights)。設定すると validation_split を無視します。 | ||
+ | |||
+ | ---- | ||
+ | |||
+ | |||
+ | KerasでのModel学習の手順は上記でおしまいです。 | ||
- | ===== kerasで損失関数(=目的関数)の利用方法 ===== | + | 初めての方は、次は、<wrap hi>[[(5)結果の出力(Keras)]]</wrap>に進んでください。 |
- | 作成中 | ||
(参考) | (参考) | ||
- | 損失関数の利用方法について\\ | + | Keras チュートリアル\\ |
- | https://keras.io/ja/losses/\\ | + | sasayabaku |
- | https://keras.io/ja/objectives/ | + | 2017年08月16日に更新\\ |
+ | https://qiita.com/sasayabaku/items/64a01363bcd5c44feb0b | ||
- | 機械学習における誤差関数、損失関数、etcについて\\ | + | ===== kerasのSequentialモデルのfitメソッドについて ===== |
- | http://otasuke.goo-net.com/qa8944219.html | + | |
- | ===== Optimizerについて ===== | + | https://keras.io/ja/models/sequential/\\ |
- | optimizer(最適化)について\\ | + | {{:pasted:20171110-035359.png}} |
- | https://keras.io/ja/optimizers/ | + | |
- | ===== 参考文献 ===== | + | |
- | 初めてKerasプログラミングをやるときの超おすすめ本。\\ | + | |
- | <html> | + | fit()関数は、固定のエポック数でモデルを訓練します。 |
- | <iframe style="width:120px;height:240px;" marginwidth="0" marginheight="0" scrolling="no" frameborder="0" src="//rcm-fe.amazon-adsystem.com/e/cm?lt1=_blank&bc1=000000&IS2=1&bg1=FFFFFF&fc1=000000&lc1=0000FF&t=twosquirrel-22&o=9&p=8&l=as4&m=amazon&f=ifr&ref=as_ss_li_til&asins=4873117585&linkId=13a7db2c19cc5f40d6ab48906de8abd1"></iframe> | + | 戻り値は、History オブジェクト。History.history 属性は、実行に成功したエポックにおける訓練の損失値と評価関数値の記録と,(適用可能ならば)検証における損失値と評価関数値も記録しています。 |
- | | + | model.fit()の返り値を出力を変数に格納すると学習過程のパラメータの推移をプロットできます。 |
- | <iframe style="width:120px;height:240px;" marginwidth="0" marginheight="0" scrolling="no" frameborder="0" src="//rcm-fe.amazon-adsystem.com/e/cm?lt1=_blank&bc1=000000&IS2=1&bg1=FFFFFF&fc1=000000&lc1=0000FF&t=twosquirrel-22&o=9&p=8&l=as4&m=amazon&f=ifr&ref=as_ss_li_til&asins=4839962510&linkId=d722909965b5eab4196d370757843f6f"></iframe> | + | 上記の例では、Historyに格納しているので、以下のようなコードで、lossやaccuracyのグラフを出力することができます。 |
- | </html> | + | |
- | ===== リンク ===== | + | <code> |
+ | import matplotlib.pyplot as plt | ||
+ | %matplotlib inline | ||
+ | loss = history.history['loss'] | ||
+ | val_loss = history.history['val_loss'] | ||
- | 次 [[(4)モデル学習(Keras)]] | + | # lossのグラフ |
+ | plt.plot(range(3), loss, marker='.', label='loss') | ||
+ | plt.plot(range(3), val_loss, marker='.', label='val_loss') | ||
+ | plt.legend(loc='best', fontsize=10) | ||
+ | plt.grid() | ||
+ | plt.xlabel('epoch') | ||
+ | plt.ylabel('loss') | ||
+ | plt.show() | ||
+ | </code> | ||
- | 前 [[(2)データ準備(Keras)]] | + | {{:pasted:20171110-040150.png}} |
+ | <code> | ||
+ | import matplotlib.pyplot as plt | ||
+ | %matplotlib inline | ||
+ | acc = history.history['acc'] | ||
+ | val_acc = history.history['val_acc'] | ||
- | <wrap hi>Keras2でMNIST目次</wrap>\\ | + | # accuracyのグラフ |
- | [[Kerasプログラミングの全体図]] | + | plt.plot(range(3), acc, marker='.', label='acc') |
- | -[[(1)Kerasを使用するためのimport文]] | + | plt.plot(range(3), val_acc, marker='.', label='val_acc') |
- | -[[(2)データ準備(Keras)]] | + | plt.legend(loc='best', fontsize=10) |
- | -[[(3)モデル設定(Keras)]] | + | plt.grid() |
- | -[[(4)モデル学習(Keras)]] | + | plt.xlabel('epoch') |
- | -[[(5)結果の出力(Keras)]] | + | plt.ylabel('acc') |
- | -[[(6)学習結果の保存(Keras)]] | + | plt.show() |
- | -[[(7)推測(Keras)]] | + | </code> |
+ | {{:pasted:20171110-040318.png}} | ||
+ | |||
+ | |||
+ | |||
+ | ===== Optimizerについて ===== | ||
+ | optimizer(最適化)について\\ | ||
+ | https://keras.io/ja/optimizers/ | ||
===== 参考文献 ===== | ===== 参考文献 ===== |