MindSpore dataset loading error: 'IdentitySampler' object has no attribute 'child_sampler'

In addition to its built-in samplers, MindSpore lets users define a custom Sampler to implement their own sampling logic. For details, see: Custom Samplers.

MindSpore samples the dataset according to the index values returned by the custom Sampler's __iter__ method.
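As a rough illustration of this contract (plain Python, not MindSpore's actual implementation), the sketch below shows a sampler whose `__iter__` yields indices, and a helper that fetches the matching items from an indexable source. The names `EvenIndexSampler` and `take_samples` are hypothetical and exist only for this example.

```python
class EvenIndexSampler:
    """Hypothetical sampler that yields every second index."""

    def __init__(self, dataset_len):
        self.dataset_len = dataset_len

    def __iter__(self):
        # Each yielded value is an index into the data source.
        for i in range(0, self.dataset_len, 2):
            yield i


def take_samples(source, sampler):
    """Hypothetical stand-in for the framework: fetch source[i] for each index."""
    return [source[i] for i in sampler]


data = ["a", "b", "c", "d", "e"]
picked = take_samples(data, EvenIndexSampler(len(data)))
print(picked)
```

The sampler never touches the data itself; it only decides which positions are read and in what order.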

The following custom Sampler was defined on CPU with MindSpore 1.1.1:

import numpy as np
import mindspore.dataset as ds


class IdentitySampler(ds.Sampler):
    """Sample person identities evenly in each batch.

        Args:
            train_color_label, train_thermal_label: labels of two modalities
            color_pos, thermal_pos: positions of each identity
            batchSize: batch size
    """

    def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
        # NOTE: the parent constructor is not called anywhere in this method.
        uni_label = np.unique(train_color_label)
        self.n_classes = len(uni_label)

        N = np.maximum(len(train_color_label), len(train_thermal_label))
        for j in range(int(N / (batchSize * num_pos)) + 1):
            # Pick batchSize distinct identities for this batch.
            batch_idx = np.random.choice(uni_label, batchSize, replace=False)
            for i in range(batchSize):
                # Draw num_pos positions for the i-th chosen identity in each modality.
                sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
                sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)

                if j == 0 and i == 0:
                    index1 = sample_color
                    index2 = sample_thermal
                else:
                    index1 = np.hstack((index1, sample_color))
                    index2 = np.hstack((index2, sample_thermal))

        self.index1 = index1
        self.index2 = index2
        self.N = N
        self.num_samples = N

    def __iter__(self):
        # return iter(np.arange(len(self.index1)))
        for i in range(len(self.index1)):
            yield i

    def __len__(self):
        return self.N

The custom Sampler is then passed to GeneratorDataset as follows:

trainset = ds.GeneratorDataset(
    trainset_generator,
    ["color", "thermal", "color_label", "thermal_label"],
    sampler=sampler,
).map(operations=transform_train, input_columns=["color", "thermal"])

......

model = Model(net, loss_fn=criterion1, optimizer=optimizer_P, metrics=None)
model.train(1, trainset, callbacks=cb)

The following error is reported:

Traceback (most recent call last):
  File "e:\PythonProject\DDAG_mindspore\train_ddag.py", line 284, in <module>
    model.train(1, trainset, callbacks=cb)
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\train\model.py", line 578, in train
    dataset_size = train_dataset.get_dataset_size()
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 1455, in get_dataset_size
    runtime_getter = self._init_size_getter()
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 1400, in _init_size_getter
    ir_tree, api_tree = self.create_ir_tree()
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 157, in create_ir_tree
    dataset = copy.deepcopy(self)
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\copy.py", line 161, in deepcopy
    y = copier(memo)
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 2367, in __deepcopy__
    new_op.children = copy.deepcopy(self.children, memodict)
    y = copier(x, memo)
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\copy.py", line 215, in _deepcopy_list
    append(deepcopy(a, memo))
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\copy.py", line 161, in deepcopy
    y = copier(memo)
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 3832, in __deepcopy__
    sampler_instance = new_op.sampler.create()
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\samplers.py", line 90, in create
    c_child_sampler = self.create_child()
  File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\samplers.py", line 102, in create_child
    if self.child_sampler is not None:
AttributeError: 'IdentitySampler' object has no attribute 'child_sampler'

Cause analysis:

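Although the user-defined IdentitySampler inherits from the parent class ds.Sampler, its constructor __init__ never explicitly calls the parent class constructor. As a result, the child_sampler member that the parent class would normally set up is missing from the instance.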

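MindSpore supports chained samplers: multiple samplers are linked together through child_sampler to form a chain. As the traceback shows, when the dataset is copied, create() calls create_child(), which reads self.child_sampler. Because that attribute was never set, Python raises the AttributeError.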
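The mechanism can be reproduced without MindSpore. The sketch below is a simplified stand-in, not MindSpore source code: `BaseSampler`, `BrokenSampler`, and `chain_length` are hypothetical names chosen for this example. It shows a parent class that creates `child_sampler` in its constructor, a chain walked through that attribute, and a subclass that overrides `__init__` without calling the parent constructor.

```python
class BaseSampler:
    """Simplified stand-in for the parent sampler class (an assumption, not MindSpore code)."""

    def __init__(self):
        # The parent constructor is what creates child_sampler.
        self.child_sampler = None

    def add_child(self, sampler):
        self.child_sampler = sampler

    def chain_length(self):
        # Walk the chain by following child_sampler until it ends.
        length, node = 1, self.child_sampler
        while node is not None:
            length += 1
            node = node.child_sampler
        return length


class BrokenSampler(BaseSampler):
    def __init__(self, indices):
        # super().__init__() is missing, so child_sampler is never set.
        self.indices = indices


parent = BaseSampler()
parent.add_child(BaseSampler())
print(parent.chain_length())

try:
    BrokenSampler([0, 1, 2]).chain_length()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)
print(error_message)
```

Any method inherited from the parent that touches `child_sampler` fails in the same way, which is why the error surfaces deep inside the library rather than in the user's own code.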

Solution:

In the custom IdentitySampler, call the parent class constructor with super().__init__() at the start of __init__.
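A minimal sketch of the fix follows. To keep it runnable without MindSpore, it uses a simplified stand-in base class (`BaseSampler`, hypothetical) whose `create_child` mirrors the attribute check seen in the traceback. `FixedIdentitySampler` and its two-argument constructor are also simplifications for this example, not the original class signature.

```python
class BaseSampler:
    """Simplified stand-in for ds.Sampler (an assumption for illustration)."""

    def __init__(self):
        self.child_sampler = None

    def create_child(self):
        # Mirrors the attribute access that failed in the traceback.
        if self.child_sampler is not None:
            return self.child_sampler
        return None


class FixedIdentitySampler(BaseSampler):
    def __init__(self, index1, index2):
        super().__init__()  # sets child_sampler before any other state
        self.index1 = index1
        self.index2 = index2

    def __iter__(self):
        for i in range(len(self.index1)):
            yield i

    def __len__(self):
        return len(self.index1)


sampler = FixedIdentitySampler([3, 7, 1], [4, 2, 9])
child = sampler.create_child()  # no AttributeError now
indices = list(sampler)
print(child, indices)
```

In the real class, the base would be ds.Sampler and the rest of __init__ stays unchanged; the only addition is the super().__init__() line at the top.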