スポンサーリンク

「ライブラリーを使わずにPythonでニューラルネットワークを構築してみる」を写経してみる

前回、Chainer(GPU使用)のインストール(OSはWindows 7)がむちゃくちゃ簡単だったのでびっくりした。

以下の本を読んだ後、実際にチュートリアルをがんがんやってみたい今日この頃。


ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装 単行本(ソフトカバー) – 2016/9/24
斎藤 康毅  (著)
3672円

今回は、以下のサイトを写経してみる。

ライブラリーを使わずにPythonでニューラルネットワークを構築してみる
kiminaka
2016年07月31日に更新
http://qiita.com/kiminaka/items/9ae195739093277490fe

ソースコード
https://github.com/dennybritz/nn-from-scratch

ちなみに、上記サイトは、以下のブログを読んでいて、リンクから飛んだ。

2016-03-02
【ディープラーニング】10時間でChainerの基本を身につける
ディープラーニング勉強記
http://esu-ko.hatenablog.com/entry/2016/03/02/%E3%80%90%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0%E3%80%9110%E6%99%82%E9%96%93%E3%81%A7Chainer%E3%81%AE%E5%9F%BA%E6%9C%AC%E3%82%92%E8%BA%AB%E3%81%AB%E3%81%A4

(環境)
Intel(R) Core(TM) i7-4770 CPU @ 3.40GHz
RAM 32.0 GB
GPU NVIDIA GeForce GTX 660 (4GB)
Windows 7 Pro
Anaconda 4.1.1(64-bit)
Python 3.5.2
Tensorflow 1.1 GPU
http://twosquirrel.mints.ne.jp/?p=17040
http://twosquirrel.mints.ne.jp/?p=17040

(1)データの生成

cmd.exeで、jupyter notebook してから、Newでipynbファイル(python3)を作成してから、

%matplotlib inline

 

import numpy as np
from sklearn import datasets, linear_model
import matplotlib.pyplot as plt

 

# データを生成してプロットする
np.random.seed(0)
X, y = datasets.make_moons(200, noise=0.20)
plt.scatter(X[:,0], X[:,1], s=40, c=y, cmap=plt.cm.Spectral)

たしかに、上記データは直線では分類が不可能。

(2)ロジスティック回帰

(3)ニューラルネットワークを学習させる

入力層のノードを2つ、出力層のノードも2つ。隠れ層を1層として、ノードを5個にする。
https://qiita-image-store.s3.amazonaws.com/0/107310/ccb32887-677e-a041-dbe0-b2693381d053.png

今回は、隠れ層のアクティベーション関数に、tanhx関数を用いる。

tanh(x)= \frac{ e^{x}- e^{-x} } { e^{x}+ e^{-x} }

tanh(x)の微分が、1-{tanh(x)}2 であり、便利らしい。

また、今回はアウトプットに確率を与えたいので、アウトプット層のアクティベーション関数にSoftmax関数を用いる。

y_{i}=  \frac{ e^{ a_{i}} } { \sum_{k=1}^K e^{a_{k}}}\,\,   (i = 1, ... , K )

(3)ニューラルネットワークの予測のしくみ

forward propagationを用いる。

(4)パラメーターを学習させる

Loss関数として、交差エントロピー最小化を用いる。

Loss最小値を計算するために、Gradient Descent(勾配降下)を用いる。

back-propagation

このあたりから、かなり理解があやしくなってくる。上記の本を繰り返し読むしかあるまい。。。

(5)実際にコードを書いてみる

疲れたのでコピペしよー。

https://github.com/dennybritz/nn-from-scratch/blob/master/ann_classification.py

今日はこのあたりで。。。