機械学習とその他

機械学習したいマン

chainerのSerial_Iteratorについて

アドベントカレンダー13日目です。大幅に超過してますが許してください。
adventar.org
今日はchainerのSerial Iteratorについて書いていきます。

Serial_Iteratorとは

まずはIteratorの説明を見ていきます。

Base class of all dataset iterators.

Iterator iterates over the dataset, yielding a minibatch at each iteration. Minibatch is a list of examples. Each implementation should implement an iterator protocol (e.g., the __next__() method).

Note that, even if the iterator supports setting the batch size, it does not guarantee that each batch always contains the same number of examples. For example, if you let the iterator to stop at the end of the sweep, the last batch may contain a fewer number of examples.

The interface between the iterator and the underlying dataset is not fixed, and up to the implementation.

chainer.dataset.Iterator — Chainer 3.2.0 documentation

TupleDatasetなど、Datasetを引数にとり、バッチサイズ分だけ返してくれる関数です。
学習の際に使います。

内部での処理

まずは__init__を見ていきます。値を受け取るくらいですね。

def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle

        self.reset()

次に__next__を見ていきます。これが重要な部分になります。

def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        self._previous_epoch_detail = self.epoch_detail

        i = self.current_position
        i_end = i + self.batch_size
        N = len(self.dataset)

        if self._order is None:
            batch = self.dataset[i:i_end]
        else:
            batch = [self.dataset[index] for index in self._order[i:i_end]]

        if i_end >= N:
            if self._repeat:
                rest = i_end - N
                if self._order is not None:
                    numpy.random.shuffle(self._order)
                if rest > 0:
                    if self._order is None:
                        batch.extend(self.dataset[:rest])
                    else:
                        batch.extend([self.dataset[index]
                                      for index in self._order[:rest]])
                self.current_position = rest
            else:
                self.current_position = 0

            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return batch

初期設定などはここで行われます。
shuffleを行う場合には、データセットの長さがからpermutationを行います。

def reset(self):
        if self._shuffle:
            self._order = numpy.random.permutation(len(self.dataset))
        else:
            self._order = None

        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

        # use -1 instead of None internally.
        self._previous_epoch_detail = -1.


batchを返す準備は以下で行われます。スライシングする範囲の最初と最後を計算します。データセットのサイズを同時に計算します。shuffleを行わない場合には、データセットからスライシングしてそのままbatchに代入します。shuffleを行う場合には、permutationを使い、バッチサイズ分だけbatchに代入していきます。

i = self.current_position
i_end = i + self.batch_size
N = len(self.dataset)

if self._order is None:
    batch = self.dataset[i:i_end]
else:
    batch = [self.dataset[index] for index in self._order[i:i_end]]

スライシングがバッチサイズを上回ったときのための処理です。
上回った場合では、データセットのサイズまでにスライシングします。
エポックのインクリメントなども行われます。

if i_end >= N:
            if self._repeat:
                rest = i_end - N
                if self._order is not None:
                    numpy.random.shuffle(self._order)
                if rest > 0:
                    if self._order is None:
                        batch.extend(self.dataset[:rest])
                    else:
                        batch.extend([self.dataset[index]
                                      for index in self._order[:rest]])
                self.current_position = rest
            else:
                self.current_position = 0

            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return batch

まとめ

ミニバッチを取り出す処理を行うわけですが、シャッフルできる機能があったり、リピート機能があったりして思っていたよりはいろんな機能があったことをしれてよかったです。次はこれを使うupdaterあたりを見ていこうと思います。