MindSpore如何对使用了自定义采样器的数据集进行分布式采样

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.2.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7.5
操作系统平台: Linux

2 报错信息

2.1 问题描述

数据比较特殊,需要根据数据集的样本数执行特殊的采样:如果样本数是2的整数倍,则每2条样本取其1;如果样本数是3的整数倍,则每3条样本取其1;如果样本是5的整数倍,则每5条样本取其1;否则,每条样本都采样。但是我同时需要执行8卡训练,按照官方文档的说明,设置了数据集的num_shards和shard_id参数,但是会出现报错。

2.2 脚本信息

import os
import cv2
import mindspore.dataset as ds

class MySampler(ds.Sampler):
    def __iter__(self):
        if self.num_samples % 2 == 0:
            interval = 2
        elif self.num_samples % 3 == 0:
            interval = 3
        elif self.num_samples % 5 == 0:
            interval = 5
        else:
            interval = 1
        for i in range(0, self.num_samples, interval):
            yield i


class MyDataset:
    def __init__(self, dataset_dir):
        self.data_file = [os.path.join(dataset_dir, filename) for filename in os.listdir(dataset_dir)]

    def __getitem__(self, index):
        image = cv2.imread(self.data_file[index])
        return image

    def __len__(self):
        return len(self.data_file)

dataset_size = 20
dataset = ds.GeneratorDataset(MyDataset("./image"), column_names=["image"],
                                sampler=ds.IterSampler(MySampler(dataset_size)),
                                num_shards=8, shard_id=0)
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data)

2.3 问题描述

dataset/test_s.py:31: in <module>
    dataset = ds.GeneratorDataset(MyDataset("./image"), column_names=["image"],
../../../mindspore/python/mindspore/dataset/engine/validators.py:1136: in new_method
    return method(self, *args, **kwargs)
../../../mindspore/python/mindspore/dataset/engine/datasets_user_defined.py:767: in __init__
    super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
../../../mindspore/python/mindspore/dataset/engine/datasets.py:2362: in __init__
    self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
../../../mindspore/python/mindspore/dataset/engine/samplers.py:55: in select_sampler
    raise ValueError(
E   ValueError: Conflicting arguments during sampler assignments. num_samples: None, num_shards: 8, shard_id: 0, shuffle: None.

3 根因分析

分析报错代码,当自定义的sampler是BuiltinSampler的派生类时,不允许配置[num_shards, shard_id, shuffle, num_samples]参数

if (isinstance(input_sampler, BuiltinSampler) and  
        (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):  
    raise ValueError(  
        'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'  
        ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))

4 解决方案

传递给ds.GeneratorDataset的采样器,避免使用继承于BuiltinSampler的采样器