こんにちは、技術開発の三浦です。今回は機械学習フレームワークPyTorchについて紹介します!
PyTorch
PyTorchはオープンソースの機械学習フレームワークで特に深層学習の分野でよく活用されています。深層学習の分野ではGoogleが開発したTensorFlowもよく活用されていますが、PyTorchは研究分野で、TensorFlowは産業分野でよく使われる傾向がある、といった記事もあります。
私はTensorFlowは何度か使ったことがあるのですが、PyTorchには触れたことがありません。せっかくなので触れてみようと思い、PyTorchの公式Tutorialにトライしてみました。
Tutorial
こちらが公式のTutorialです。
結構いっぱいコンテンツがあります。Object DetectionやGANもあるのでこれをこなしていけばPyTorchで色んなことが出来るようになりそうです。
Deep Learning with PyTorch:A 60 Minute Blitz
そんな中でとりあえずPyTorchの基本が学べそうなこちらのTutorialをやってみました。
最初にTensorの扱い方、自動微分で勾配を計算する、Neural Networkの組み方、そして最後にCNNでCIFAR10画像セットののclass分類をする流れになっています。ここまでやってうっすらPyTorchの使い方が見えてきた感じです。
Training a Classifier
さてここからが本題です。CNNでCIFAR10のclass分類をするのですが、Tutorialで扱うネットワークの構造は以下のようになっています。
畳み込み層(kernel=5x5, channel:3→6)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
畳み込み層(kernel=5x5, channel:6→16)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
全結合(400→120)
↓
全結合(120→84)
↓
全結合(84→10)
PyTorchで書くとこうなります。
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
この構造でbatch_size=4、epoch=2でテストデータに対してAccuracyが54%になりました。ネットワークの構造を変えてAccuracyを上げることが出来るか簡単に実験してみました。
ネットワークを深くしたらいい?
精度を上げるためのアプローチとしてネットワークを深くすることが有効だと聞いたことがあります。なので畳み込み層〜プーリング層を1つ追加し以下のようにしてみました。
畳み込み層(kernel=5x5, channel:3→6)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
畳み込み層(kernel=3x3, channel:6→16)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
畳み込み層(kernel=3x3, channel:16→32)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
全結合(128→120)
↓
全結合(120→84)
↓
全結合(84→10)
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) #32 - 2 self.pool = nn.MaxPool2d(2, 2) #30 / 2 self.conv2 = nn.Conv2d(6, 16, 3) #15 - 2 self.conv3 = nn.Conv2d(16, 32, 3) self.fc1 = nn.Linear(32 * 2 * 2, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = self.pool(F.relu(self.conv3(x))) x = x.view(-1, 32 * 2 * 2) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
これで精度上がるかな・・・
Accuracy:49% !
げ、めちゃくちゃ下がってしまった!!
考えられる理由
よくよく見てみると理由はおそらく最初の全結合層に入力されるニューロンの数がオリジナルだと400なのに対し、今回のは128に大幅に下がっていることではないでしょうか。画像から抽出する特徴量が減ってしまったため、精度が下がったのではないかと思います。今回はpaddingなども指定していないので、畳込みやプーリング処理で画像のサイズはどんどん小さくなっていきます。ならばチャンネル数を増やすことで特徴量を増やしてみます。
チャンネル数を増やす
次のように作り変えてみました。
畳み込み層(kernel=3x3, channel:3-→6)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
畳み込み層(kernel=2x2, channel:6→18)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
畳み込み層(kernel=2x2, channel:18→72)
relu活性
プーリング(MaxPool, kernel=2x2)
↓
全結合(558→120)
↓
全結合(120→84)
↓
全結合(84→10)
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 3) #32 - 2 self.pool = nn.MaxPool2d(2, 2) #30 / 2 self.conv2 = nn.Conv2d(6, 18, 2) #15 - 2 self.conv3 = nn.Conv2d(18, 72, 2) self.fc1 = nn.Linear(72 * 3 * 3, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = self.pool(F.relu(self.conv3(x))) x = x.view(-1, 72 * 3 * 3) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
全結合層に入力されるニューロンの数は558に増えました。このネットワークでAccuracyを測定すると58%にアップしました。作戦成功です!ちなみに各classごとのAccuracyの変化を見てみると、 オリジナルでは
accuracy of plane : 45 % accuracy of car : 70 % accuracy of bird : 44 % accuracy of cat : 32 % accuracy of deer : 27 % accuracy of dog : 57 % accuracy of frog : 64 % accuracy of horse : 54 % accuracy of ship : 73 % accuracy of truck : 72 %
だったのに対し、今回のネットワークでは
accuracy of plane : 79 % accuracy of car : 74 % accuracy of bird : 35 % accuracy of cat : 41 % accuracy of deer : 50 % accuracy of dog : 41 % accuracy of frog : 67 % accuracy of horse : 69 % accuracy of ship : 63 % accuracy of truck : 58 %
のようになりました。planeのAccuracyが大きく上がっていて、truckの方が結構下がりましたね。ネットワークの構造でうまく捉えられるものや捉えられないものがあるということなのでしょうか。面白い結果ですが、CNNの奥深さ・難しさを感じる結果とも言えます。
最後に
本当はpaddingを指定したりoptimizerを変更したりbatch_sizeを変更したり色々な精度向上のアプローチがあるはずですが、今回は単純にネットワークの構造を変える方法で精度向上に取り組みました。引き続きPyTorchに触ってみようと思います!