ChainerのTupleDatasetについて
アドベントカレンダー12日目です。
adventar.org
今日はChainerのTupleDatasetについて書いていきます。
TupleDatasetとは
複数のデータセットからTupleのデータセットを作成します。
https://docs.chainer.org/en/stable/reference/generated/chainer.datasets.TupleDataset.html#chainer.datasets.TupleDataset
Chainerの公式Exampleの中で有名なMNISTでも用いられています。
MNISTでは、データセットの準備の際にget_mnist()と呼んでいます。
# Load the MNIST dataset
train, test = chainer.datasets.get_mnist()
get_mnist()の中を見に行くと、このように書かれています。
if withlabel: labels = raw['y'].astype(label_dtype) return tuple_dataset.TupleDataset(images, labels)
ラベルがある場合には、画像とラベルを引数にTupleDatasetを呼んでいます。
内部での処理
まずdatasetsが渡されているか確認します。datasets[0]の長さをとり、enumerateを使い、datasetsの長さが全て一致しているか確認します。問題が無ければ、self.datasetsとlengthに値をセットします。
def __init__(self, *datasets): if not datasets: raise ValueError('no datasets are given') length = len(datasets[0]) for i, dataset in enumerate(datasets): if len(dataset) != length: raise ValueError( 'dataset of the index {} has a wrong length'.format(i)) self._datasets = datasets self._length = length
getitemでは、datasetsの分for分を回し、dataset[index]をbatchesに代入し、
isinstanceがtrueの場合は、batches[0]の長さを確認しfor文を回して、tupleで値を返す。
falseの場合は、そのままbatchesをtupleにして値を返す。
def __getitem__(self, index): batches = [dataset[index] for dataset in self._datasets] if isinstance(index, slice): length = len(batches[0]) return [tuple([batch[i] for batch in batches]) for i in six.moves.range(length)] else: return tuple(batches)
まとめ
TupleDataset自体はそこまで複雑な処理は行っておらず、基本的な長さの確認をしたりして、データを扱いやすくするための関数でした。これを理解することで、SerialIteratorなども読みやすくなると思います。