1 System environment
Hardware environment (Ascend/GPU/CPU): Ascend
MindSpore version: mindspore=2.2.0
Execution mode (PyNative/Graph): any
Python version: Python=3.7.5
Operating system: Linux
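2 Error information
2.1 Problem description
The data is unusual and needs sampling that depends on the number of samples in the dataset: if the sample count is a multiple of 2, take one of every 2 samples; otherwise, if it is a multiple of 3, take one of every 3; otherwise, if it is a multiple of 5, take one of every 5; otherwise, take every sample. I also need to train on 8 devices, so I set the dataset's num_shards and shard_id parameters as described in the official documentation, but this raises an error.
2.2 Script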
import os
import cv2
import mindspore.dataset as ds


class MySampler(ds.Sampler):
    def __iter__(self):
        if self.num_samples % 2 == 0:
            interval = 2
        elif self.num_samples % 3 == 0:
            interval = 3
        elif self.num_samples % 5 == 0:
            interval = 5
        else:
            interval = 1
        for i in range(0, self.num_samples, interval):
            yield i


class MyDataset:
    def __init__(self, dataset_dir):
        self.data_file = [os.path.join(dataset_dir, filename) for filename in os.listdir(dataset_dir)]

    def __getitem__(self, index):
        image = cv2.imread(self.data_file[index])
        return image

    def __len__(self):
        return len(self.data_file)


dataset_size = 20
dataset = ds.GeneratorDataset(MyDataset("./image"), column_names=["image"],
                              sampler=ds.IterSampler(MySampler(dataset_size)),
                              num_shards=8, shard_id=0)
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(data)
2.3 Error message
dataset/test_s.py:31: in <module>
dataset = ds.GeneratorDataset(MyDataset("./image"), column_names=["image"],
../../../mindspore/python/mindspore/dataset/engine/validators.py:1136: in new_method
return method(self, *args, **kwargs)
../../../mindspore/python/mindspore/dataset/engine/datasets_user_defined.py:767: in __init__
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
../../../mindspore/python/mindspore/dataset/engine/datasets.py:2362: in __init__
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
../../../mindspore/python/mindspore/dataset/engine/samplers.py:55: in select_sampler
raise ValueError(
E ValueError: Conflicting arguments during sampler assignments. num_samples: None, num_shards: 8, shard_id: 0, shuffle: None.
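3 Root cause analysis
The code that raises the error shows that when the sampler passed in is an instance of a class derived from BuiltinSampler, none of the num_shards, shard_id, shuffle, or num_samples arguments may be set: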
if (isinstance(input_sampler, BuiltinSampler) and
        (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
    raise ValueError(
        'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
        ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
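To look at this condition in isolation, the sketch below mirrors it with stand-in classes. Everything here is hypothetical scaffolding for illustration: the local BuiltinSampler, DerivedSampler, PlainSampler, and the helper would_conflict are not MindSpore's own definitions, and only the boolean expression is copied from the snippet above.

```python
class BuiltinSampler:
    """Stand-in for MindSpore's base class (assumption: only used for isinstance)."""


class DerivedSampler(BuiltinSampler):
    """Plays the role of any sampler that inherits from BuiltinSampler."""


class PlainSampler:
    """An ordinary Python iterable with no MindSpore base class."""

    def __iter__(self):
        return iter(range(4))


def would_conflict(input_sampler, num_samples=None, shuffle=None,
                   num_shards=None, shard_id=None):
    """Return True where the quoted check would raise ValueError."""
    return (isinstance(input_sampler, BuiltinSampler) and
            any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))


print(would_conflict(DerivedSampler(), num_shards=8, shard_id=0))
print(would_conflict(DerivedSampler()))
print(would_conflict(PlainSampler(), num_shards=8, shard_id=0))
```

Only the first call meets both halves of the condition: the sampler is a BuiltinSampler instance and at least one of the four arguments is not None.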
4 Solution
When num_shards and shard_id are set, do not pass ds.GeneratorDataset a sampler that inherits from BuiltinSampler.
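One possible shape for such a sampler is sketched below, as a plain Python iterable with no MindSpore base class. The class name IntervalShardSampler and the choice to split the indices across shards inside the sampler itself are assumptions made for illustration, not something the original post or MindSpore prescribes. The sketch has not been run against MindSpore; it shows only the index logic.

```python
class IntervalShardSampler:
    """Hypothetical sampler: a plain iterable, not derived from BuiltinSampler."""

    def __init__(self, num_samples, num_shards=1, shard_id=0):
        self.num_samples = num_samples
        self.num_shards = num_shards
        self.shard_id = shard_id

    def _interval(self):
        # Same rule as the original script: the first divisor among 2, 3, 5 wins.
        for k in (2, 3, 5):
            if self.num_samples % k == 0:
                return k
        return 1

    def __iter__(self):
        picked = range(0, self.num_samples, self._interval())
        # Assumption: sharding is done here by giving every num_shards-th
        # picked index to this shard, starting at position shard_id.
        for i in picked[self.shard_id::self.num_shards]:
            yield i


print(list(IntervalShardSampler(20, num_shards=8, shard_id=0)))
print(list(IntervalShardSampler(20, num_shards=8, shard_id=1)))
```

If GeneratorDataset accepts this iterable through its sampler argument, it would be passed as sampler=IntervalShardSampler(dataset_size, 8, shard_id), with num_shards and shard_id left unset on the dataset. Note that with 20 samples and 8 shards the shards are uneven: some receive two indices and others one.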