機械学習とその他

機械学習したいマン

ChainerのTupleDatasetについて

アドベントカレンダー12日目です。
adventar.org
今日はChainerのTupleDatasetについて書いていきます。

TupleDatasetとは

複数のデータセットからTupleのデータセットを作成します。
https://docs.chainer.org/en/stable/reference/generated/chainer.datasets.TupleDataset.html#chainer.datasets.TupleDataset

Chainerの公式Exampleの中で有名なMNISTでも用いられています。
MNISTでは、データセットの準備の際にget_mnist()と呼んでいます。

# Load the MNIST dataset
train, test = chainer.datasets.get_mnist()

get_mnist()の中を見に行くと、このように書かれています。

if withlabel:
        labels = raw['y'].astype(label_dtype)
        return tuple_dataset.TupleDataset(images, labels)

ラベルがある場合には、画像とラベルを引数にTupleDatasetを呼んでいます。

内部での処理

まずdatasetsが渡されているか確認します。datasets[0]の長さをとり、enumerateを使い、datasetsの長さが全て一致しているか確認します。問題が無ければ、self.datasetsとlengthに値をセットします。

def __init__(self, *datasets):
        if not datasets:
            raise ValueError('no datasets are given')
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            if len(dataset) != length:
                raise ValueError(
                    'dataset of the index {} has a wrong length'.format(i))
        self._datasets = datasets
        self._length = length

getitemでは、datasetsの分for分を回し、dataset[index]をbatchesに代入し、
isinstanceがtrueの場合は、batches[0]の長さを確認しfor文を回して、tupleで値を返す。
falseの場合は、そのままbatchesをtupleにして値を返す。

def __getitem__(self, index):
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, slice):
            length = len(batches[0])
            return [tuple([batch[i] for batch in batches])
                    for i in six.moves.range(length)]
        else:
                return tuple(batches)

まとめ

TupleDataset自体はそこまで複雑な処理は行っておらず、基本的な長さの確認をしたりして、データを扱いやすくするための関数でした。これを理解することで、SerialIteratorなども読みやすくなると思います。