WHAT'S UP?

New post every sometimes. Here's OKKAH NET.

Chainer公式サンプルMNISTを読み解いてみた

train_mnist.pyのコードを追ってみました。Chainerのバージョンは4.0.0です。

github.com


MNISTとは
手書き数字の画像セット。機械学習の分野で最も有名なデータセットの1つ。
データセットは70000枚。(訓練データ: 60000枚、テストデータ: 10000枚)


処理の流れ
main関数における大まかな処理の流れは
 ⑴ モデルの生成
 ⑵ Optimizerの生成
 ⑶ Iteratorの生成
 ⑷ Trainer、Updaterの生成
 ⑸ Extensionsの登録
 ⑹ 学習ループ
といった感じです。


Trainerについて
Trainerは学習に必要なもの全てをひとまとめにする機能を持っています。
全体図は以下のようになっています。

f:id:okkah:20180501193156p:plain

基本的にTrainerは、渡されたUpdater(必要ならばExtensionsも)を実行するだけですが、
Updaterは中にIteratorとOptimizerを持っています。

Iteratorはデータセットにアクセスする機能を持ち、
Optimizerは重み・バイアスを更新する機能を持ちます。

つまりUpdaterは
 ⑴ データセットからデータを取り出す。(Iterator)
 ⑵ モデルに渡してロスを計算する。(Model = Optimizer.target)
 ⑶ モデルのパラメータを更新する。(Optimizer)
といった一連の学習の主要部分を担っています。

Extensionsは、可視化やログの保存などの様々な便利な機能を持つので、
必要に応じて使いましょう。


実際にMNISTのサンプルコードを追ってみる
1. 必要なライブラリやモジュールをimport

from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

・__future__: Python2系とは互換性の無い3系の機能を、2系でも使用可能にする。
・argparse: コマンドライン引数を扱えるようにする。
・chainer.functions: パラメータを持たない関数。
・chainer.links: パラメータを持つ関数。
・training: trainer


2. モデルの生成

# Set up a neural network to train    
# Classifier reports softmax cross entropy loss and accuracy at every    
# iteration, which will be used by the PrintReport extension below.    
model = L.Classifier(MLP(args.unit, 10))    
if args.gpu >= 0:
    # Make a specified GPU current
    chainer.backends.cuda.get_device_from_id(args.gpu).use()
    model.to_gpu()  # Copy the model to the GPU
# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

model: Classifierインスタンス。このmodelにデータを渡すと順伝播が始まる。
MLP: レイヤー構成を定義するクラス。
・chainer.Chain: パラメータを持つlinksをまとめておくクラス。
 Optimizerが更新するパラメータを簡単に取得できるように一箇所にまとめる。
・__init__: モデルを構成するレイヤーを定義するメソッド。インスタンス化。
・__call__: クラスのインスタンスを関数として呼び出すメソッド。
L.Linear: 全結合層を実現するクラス。
 データがその層に入力されると、必要な入力ユニット数を自動的に計算し、
 (n_in) x (n_units)の大きさの行列を生成し、パラメータとして保持する。
・relu: 活性化関数(→ゼロから作るDeep Learning 3章)

モデルのインスタンスは以下のようになります。
f:id:okkah:20180501230144p:plain

3. コマンドライン引数の設定

    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

argparseモジュールを使うと、コマンドライン引数を扱うことができます。
これにより、Pythonを実行するときに様々なパラメータを指定することができます。

例)

$ python train_mnist.py -g 0 -e 10
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 10

GPU: 0 (0: GPUを使用、-1: CPUを使用)
epoch: 10 (学習の反復回数: 10回)
で実行することができます。(他はデフォルトのまま)



4. Optimizerの生成

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

モデルによって順伝播・逆伝播が実行され勾配が計算されますが、
その値を重み・バイアスに反映するのがOptimizerの仕事です。
役割上、モデルをラップするような形になります。

・Adam: パラメータの更新(→ゼロから作るDeep Learning 6章)


5. Iteratorの生成

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

Iteratorでデータセットにアクセスします。

・get_mnist: mnistのデータをtrainとtestに入れる。
・SerialIterator: データを順番に取り出すもっともシンプルなIterator
・batchsize: 何枚の画像データを一括りにして取り出すかどうか。
・repeat: 何周もデータを繰り返し読むかどうか。
・shuffle: 取り出すデータの順番をepochごとにランダムに変更するかどうか。

ここまでが初期化に関する処理です。


6. Trainer、Updaterの生成

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

ここから学習ループに関する処理に入ります。

・StandardUpdater: Updaterの処理を遂行するための最もシンプルなクラス。
・args.epoch: stop_trigger。学習をどのタイミングで終了するか。
・out: Extensionsで描画したグラフ画像の保存先。


7. Extensionsの登録

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

・Extensions: 可視化やログの保存などの様々な便利な機能を持つ。
 ・Evaluator: モデルの評価。
 ・dump_graph: グラフの保存。
 ・snapshot: Trainerを保存。
 ・LogReport: ログを保存。
 ・PlotReport: ロスを可視化して保存。
 ・PrintReport: ログを出力。
 ・ProgressBar: 学習の進行状況を出力。


8. 最後に

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()

・serializers.load_npz: snapshotから再開。
・trainer.run: 学習を開始。
・if __name__ == '__main__':
 直接実行されていればTrue、importされて実行されていればFalse。