MindSpore除了提供如下各类Sampler,也支持用户通过自定义Sampler实现自定义的采样操作,具体可以参考:自定义采样器
MindSpore会依据自定义Sampler中__iter__返回的索引值对样本进行采样。
在CPU + MindSpore 1.1.1环境下,定义如下自定义Sampler:
1 class IdentitySampler(ds.Sampler):
2
3 """Sample person identities evenly in each batch.
4
5 Args:
6
7 train_color_label, train_thermal_label: labels of two modalities
8
9 color_pos, thermal_pos: positions of each identity
10
11 batchSize: batch size
12
13 """
14
15
16 def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
17
18 uni_label = np.unique(train_color_label)
19
20 self.n_classes = len(uni_label)
21
22 N = np.maximum(len(train_color_label), len(train_thermal_label))
23
24 for j in range(int(N/(batchSize*num_pos))+1):
25
26 batch_idx = np.random.choice(uni_label, batchSize, replace = False)
27
28 for i in range(batchSize):
29
30 sample_color = np.random.choice(color_pos[batch_idx], num_pos)
31
32 sample_thermal = np.random.choice(thermal_pos[batch_idx], num_pos)
33
34 if j ==0 and i==0:
35
36 index1= sample_color
37
38 index2= sample_thermal
39
40 else:
41
42 index1 = np.hstack((index1, sample_color))
43
44 index2 = np.hstack((index2, sample_thermal))
45
46 self.index1 = index1
47
48 self.index2 = index2
49
50 self.N = N
51
52 self.num_samples = N
53
54 def __iter__(self):
55
56 # return iter(np.arange(len(self.index1)))
57
58 for i in range(len(self.index1)):
59
60 yield i
61
62 def __len__(self):
63
64 return self.N
在GeneratorDataset中使用自定义的Sampler,具体如下:
trainset = ds.GeneratorDataset(trainset_generator, ["color", "thermal","color_label", "thermal_label"], sampler=sampler).map(
operations=transform_train, input_columns=["color", "thermal"]
)
......
model = Model(net, loss_fn=criterion1, optimizer=optimizer_P, metrics=None)
model.train(1, trainset, callbacks=cb)
出现如下报错信息:
Traceback (most recent call last):
File "e:\PythonProject\DDAG_mindspore\train_ddag.py", line 284, in <module>
model.train(1, trainset, callbacks=cb)
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\train\model.py", line 578, in train
dataset_size = train_dataset.get_dataset_size()
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 1455, in get_dataset_size
runtime_getter = self._init_size_getter()
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 1400, in _init_size_getter
ir_tree, api_tree = self.create_ir_tree()
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 157, in create_ir_tree
dataset = copy.deepcopy(self)
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\copy.py", line 161, in deepcopy
y = copier(memo)
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 2367, in __deepcopy__
new_op.children = copy.deepcopy(self.children, memodict)
y = copier(x, memo)
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\copy.py", line 215, in _deepcopy_list
append(deepcopy(a, memo))
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\copy.py", line 161, in deepcopy
y = copier(memo)
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\datasets.py", line 3832, in __deepcopy__
sampler_instance = new_op.sampler.create()
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\samplers.py", line 90, in create
c_child_sampler = self.create_child()
File "D:\ProgramData\Anaconda3\envs\mindspore_cpu\lib\site-packages\mindspore\dataset\engine\samplers.py", line 102, in create_child
if self.child_sampler is not None:
AttributeError: 'IdentitySampler' object has no attribute 'child_sampler'
原因分析:
用户自定义的 IdentitySampler 虽然继承了父类 ds.Sampler,但是在代码第16行的构造函数 __init__ 中没有显式调用父类的构造函数,导致父类中的 child_sampler 成员没有被创建。
MindSpore中支持链式采样器,即多个采样器通过 child_sampler 串联起来形成一条链(chain)。框架在创建采样器时会在 create_child() 中访问 self.child_sampler,因此缺少该成员时就会抛出上述 AttributeError。
解决办法:
在自定义的 IdentitySampler 采样器中,于代码第17行(即构造函数 __init__ 的第一条语句处)使用 super().__init__() 调用父类的构造函数。