chainerのSerial_Iteratorについて
アドベントカレンダー13日目です。大幅に超過してますが許してください。
adventar.org
今日はchainerのSerial Iteratorについて書いていきます。
Serial_Iteratorとは
まずはIteratorの説明を見ていきます。
Base class of all dataset iterators.
Iterator iterates over the dataset, yielding a minibatch at each iteration. Minibatch is a list of examples. Each implementation should implement an iterator protocol (e.g., the __next__() method).
Note that, even if the iterator supports setting the batch size, it does not guarantee that each batch always contains the same number of examples. For example, if you let the iterator to stop at the end of the sweep, the last batch may contain a fewer number of examples.
The interface between the iterator and the underlying dataset is not fixed, and up to the implementation.
chainer.dataset.Iterator — Chainer 3.2.0 documentation
TupleDatasetなど、Datasetを引数にとり、バッチサイズ分だけ返してくれる関数です。
学習の際に使います。
内部での処理
まずは__init__を見ていきます。値を受け取るくらいですね。
def __init__(self, dataset, batch_size, repeat=True, shuffle=True): self.dataset = dataset self.batch_size = batch_size self._repeat = repeat self._shuffle = shuffle self.reset()
次に__next__を見ていきます。これが重要な部分になります。
def __next__(self): if not self._repeat and self.epoch > 0: raise StopIteration self._previous_epoch_detail = self.epoch_detail i = self.current_position i_end = i + self.batch_size N = len(self.dataset) if self._order is None: batch = self.dataset[i:i_end] else: batch = [self.dataset[index] for index in self._order[i:i_end]] if i_end >= N: if self._repeat: rest = i_end - N if self._order is not None: numpy.random.shuffle(self._order) if rest > 0: if self._order is None: batch.extend(self.dataset[:rest]) else: batch.extend([self.dataset[index] for index in self._order[:rest]]) self.current_position = rest else: self.current_position = 0 self.epoch += 1 self.is_new_epoch = True else: self.is_new_epoch = False self.current_position = i_end return batch
初期設定などはここで行われます。
shuffleを行う場合には、データセットの長さがからpermutationを行います。
def reset(self): if self._shuffle: self._order = numpy.random.permutation(len(self.dataset)) else: self._order = None self.current_position = 0 self.epoch = 0 self.is_new_epoch = False # use -1 instead of None internally. self._previous_epoch_detail = -1.
batchを返す準備は以下で行われます。スライシングする範囲の最初と最後を計算します。データセットのサイズを同時に計算します。shuffleを行わない場合には、データセットからスライシングしてそのままbatchに代入します。shuffleを行う場合には、permutationを使い、バッチサイズ分だけbatchに代入していきます。
i = self.current_position i_end = i + self.batch_size N = len(self.dataset) if self._order is None: batch = self.dataset[i:i_end] else: batch = [self.dataset[index] for index in self._order[i:i_end]]
スライシングがバッチサイズを上回ったときのための処理です。
上回った場合では、データセットのサイズまでにスライシングします。
エポックのインクリメントなども行われます。
if i_end >= N: if self._repeat: rest = i_end - N if self._order is not None: numpy.random.shuffle(self._order) if rest > 0: if self._order is None: batch.extend(self.dataset[:rest]) else: batch.extend([self.dataset[index] for index in self._order[:rest]]) self.current_position = rest else: self.current_position = 0 self.epoch += 1 self.is_new_epoch = True else: self.is_new_epoch = False self.current_position = i_end return batch
まとめ
ミニバッチを取り出す処理を行うわけですが、シャッフルできる機能があったり、リピート機能があったりして思っていたよりはいろんな機能があったことをしれてよかったです。次はこれを使うupdaterあたりを見ていこうと思います。