ChainerでMNIST

  • ニューラルネットワークのライブラリのChainerですが、去年のうちに大分変更がありました.
  • というかバージョンアップ早すぎてびびる
  • この記事書く際にふとリファレンス見たらいつの間にか1.20.0のドキュメントができてた(GitHubのリリースノートの最新は現時点ではまだ1.19.0)
    Chainer1.19.0版MNISTのコードを紹介します.
# -*- coding: utf-8 -*-
from __future__ import print_function

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

#Network definition
class MLP(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    unit = 1000
    batchsize = 100
    epoch = 20

    model = L.Classifier(MLP(unit, 10))

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out='result')

    trainer.extend(extensions.Evaluator(test_iter, model))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(epoch, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.ProgressBar())

    trainer.run()

if __name__ == "__main__":
    main()

ざっくりとした説明とかはこちら
ChainerでMNIST