1 System Environment
Hardware environment (Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore version: mindspore=2.2.10
Execution mode (PyNative/Graph): any
Python version: Python=3.8
OS platform: Linux
2 Error Information
2.1 Problem Description
When per_batch_map is set and batch_size=1, an error is raised. can_collator is defined as follows:
import mindspore as ms
from mindspore import ops

def can_collator(batch, BatchInfo):
    """
    batch: [
        image       [batch_size, channel, maxHinbatch, maxWinbatch]
        image_mask  [batch_size, channel, maxHinbatch, maxWinbatch]
        label       [batch_size, maxLabelLen]
        label_mask  [batch_size, maxLabelLen]
        ...
    ]
    """
    max_width, max_height, max_length = 0, 0, 0
    bs, channel = len(batch), batch[0][0].shape[0]
    proper_items = []
    for item in batch:
        # skip samples that would push the padded area past 1600 * 320
        if (
            item[0].shape[1] * max_width > 1600 * 320
            or item[0].shape[2] * max_height > 1600 * 320
        ):
            continue
        max_height = (
            item[0].shape[1] if item[0].shape[1] > max_height else max_height
        )
        max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
        max_length = len(item[1]) if len(item[1]) > max_length else max_length
        proper_items.append(item)
    # allocate zero-padded outputs sized to the largest kept sample
    images = ops.zeros((len(proper_items), channel, max_height, max_width), dtype=ms.float32)
    image_masks = ops.zeros((len(proper_items), 1, max_height, max_width), dtype=ms.float32)
    labels = ops.zeros((len(proper_items), max_length), dtype=ms.int64)
    label_masks = ops.zeros((len(proper_items), max_length), dtype=ms.int64)
    for i in range(len(proper_items)):
        _, h, w = proper_items[i][0].shape
        images[i][:, :h, :w] = proper_items[i][0]
        image_masks[i][:, :h, :w] = 1
        l = len(proper_items[i][1])
        labels[i][:l] = proper_items[i][1]
        label_masks[i][:l] = 1
    return images, image_masks, labels, label_masks
2.2 Script
dataloader = ds.batch(
    batch_size,
    drop_remainder=drop_remainder,
    num_parallel_workers=min(
        num_workers, 2
    ),  # set small workers for lite computation. ToDo: increase for batch-wise mapping
    # input_columns=["image", "label"],
    output_columns=["images", "image_masks", "labels", "label_masks"],
    per_batch_map=can_collator,  # uncomment to use inner-batch transformation
)
return dataloader
2.3 Error Log
size=8
col_name=['images', 'image_masks', 'labels', 'label_masks']
Traceback (most recent call last):
  File "/home/ma-user/work/prtest/mindocr/tests/ut/test_can_dataset.py", line 111, in <module>
    for data in data_loader.create_dict_iterator():
  File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/engine/iterators.py", line 145, in __next__
    data = self._get_next()
  File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/engine/iterators.py", line 270, in _get_next
    raise err
  File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/engine/iterators.py", line 253, in _get_next
    return {k: self._transform_md_to_output(t) for k, t in self._iterator.GetNextAsMap().items()}
RuntimeError: Exception thrown from user defined Python function in dataset.
Python Call stack:
TypeError: Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/transforms/py_transforms_util.py", line 198, in __call__
    result = self.transform(*args)
TypeError: can_collator() takes 2 positional arguments but 3 were given
Dataset Pipeline Error Message:
[ERROR] Execute user Python code failed, check 'Python Call stack' above.
3 Root Cause Analysis
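The error message is TypeError: can_collator() takes 2 positional arguments but 3 were given.
The input dataset has two columns, image and label, so can_collator is called with three arguments: the batch of image-column data, the batch of label-column data, and BatchInfo.
It does not receive the whole batch as a single argument, as the code above assumed.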
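To see why three arguments arrive, here is a minimal plain-Python mock of the calling convention, assuming (as the API description in this section states) that the callable receives one list per input column followed by BatchInfo. call_per_batch_map, old_collator, and new_collator are hypothetical names, and a SimpleNamespace stands in for BatchInfo; this is an illustration, not MindSpore's actual implementation.

```python
from types import SimpleNamespace

import numpy as np


def call_per_batch_map(rows, fn):
    # Hypothetical mock of the calling convention, not MindSpore's code.
    # rows: list of samples, each an (image, label) tuple, i.e. two columns.
    columns = [list(col) for col in zip(*rows)]  # transpose: one list per column
    batch_info = SimpleNamespace(batch_num=0)    # stand-in for BatchInfo
    return fn(*columns, batch_info)              # BatchInfo always comes last


def old_collator(batch, BatchInfo):
    # Original signature: expects the whole batch as one argument.
    return batch


def new_collator(image, label, BatchInfo):
    # Fixed signature: one parameter per input column, then BatchInfo.
    return image, label


rows = [
    (np.zeros((1, 4, 5), np.float32), np.array([1, 2])),
    (np.zeros((1, 6, 3), np.float32), np.array([3])),
]

try:
    call_per_batch_map(rows, old_collator)
except TypeError as err:
    err_msg = str(err)
    print(err_msg)  # old_collator() takes 2 positional arguments but 3 were given

images, labels = call_per_batch_map(rows, new_collator)
print(len(images), len(labels))  # 2 2
```

The two-column dataset produces two lists plus BatchInfo, which reproduces the same TypeError shape as the log above.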
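For details, see the API description:
per_batch_map (Callable[[List[numpy.ndarray], …, List[numpy.ndarray], BatchInfo], (List[numpy.ndarray], …, List[numpy.ndarray])], optional) - A callable that takes (list[numpy.ndarray], …, list[numpy.ndarray], BatchInfo) as input arguments and returns (list[numpy.ndarray], list[numpy.ndarray], …) as the new data columns. Each list[numpy.ndarray] in the input holds one batch of numpy.ndarray from a given data column, and the number of list[numpy.ndarray] arguments must match the number of column names passed in input_columns. In the returned (list[numpy.ndarray], list[numpy.ndarray], …), the number of list[numpy.ndarray] should be the same as the input; if the number of output columns differs from the number of input columns, output_columns must be specified. The last input argument of the callable is always BatchInfo, which provides information about the dataset; see example (2) in the API documentation for usage.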
In short: the callable takes (list[numpy.ndarray], …, list[numpy.ndarray], BatchInfo) as input arguments and returns (list[numpy.ndarray], list[numpy.ndarray], …) as the new data columns.
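The contract above can be illustrated without MindSpore. The sketch below is a plain-Python mock, not MindSpore's implementation: mock_batch and FakeBatchInfo are hypothetical names, and FakeBatchInfo only imitates the get_batch_num()/get_epoch_num() getters of the real BatchInfo. It shows one list per input column going in, BatchInfo as the last argument, and two input columns mapping to four outputs, which the mock accepts only when four output column names are supplied.

```python
import numpy as np


class FakeBatchInfo:
    # Stand-in for mindspore.dataset.BatchInfo; only two getters are mocked.
    def __init__(self, batch_num, epoch_num=0):
        self._batch_num = batch_num
        self._epoch_num = epoch_num

    def get_batch_num(self):
        return self._batch_num

    def get_epoch_num(self):
        return self._epoch_num


def mock_batch(rows, batch_size, per_batch_map, input_columns, output_columns=None):
    # Hypothetical helper that mimics the documented per_batch_map contract.
    out_names = output_columns or input_columns  # default: same columns out
    results = []
    for batch_num, start in enumerate(range(0, len(rows), batch_size)):
        chunk = rows[start:start + batch_size]
        # one list per input column, in input_columns order
        col_lists = [[row[name] for row in chunk] for name in input_columns]
        outputs = per_batch_map(*col_lists, FakeBatchInfo(batch_num))  # BatchInfo last
        if len(outputs) != len(out_names):
            raise ValueError(
                f"per_batch_map returned {len(outputs)} columns, "
                f"expected {len(out_names)}: {out_names}"
            )
        results.append(dict(zip(out_names, outputs)))
    return results


seen_batch_nums = []


def two_in_four_out(image, label, batch_info):
    # two input columns -> four output columns
    seen_batch_nums.append(batch_info.get_batch_num())
    image_masks = [np.ones_like(img) for img in image]
    label_masks = [np.ones_like(lab) for lab in label]
    return image, image_masks, label, label_masks


rows = [{"image": np.zeros((1, 2, 2), np.float32), "label": np.array([i])} for i in range(3)]
batches = mock_batch(
    rows, 2, two_in_four_out,
    input_columns=["image", "label"],
    output_columns=["images", "image_masks", "labels", "label_masks"],
)
print(len(batches), seen_batch_nums)  # 2 [0, 1]

try:
    # same callable, but output_columns omitted: 4 outputs vs 2 column names
    mock_batch(rows, 2, two_in_four_out, input_columns=["image", "label"])
except ValueError as err:
    mismatch_msg = str(err)
    print(mismatch_msg)
```

The last batch holds a single row (3 rows, batch size 2), and the callable is still invoked with one list per column.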
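4 Solution
Change the parameter list of can_collator as follows:
def can_collator(image, label, BatchInfo):
The code inside the function must also be updated to match these parameters: iterate over the image and label lists instead of over a single batch argument.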
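One way to adapt the body is sketched below, assuming the two columns are image and label and each image is a [channel, H, W] array. It keeps the original padding and 1600 * 320 filtering logic but switches from mindspore.ops tensors to NumPy, because the API describes inputs and outputs as lists of numpy.ndarray; the NumPy switch is an editorial choice, not part of the original solution.

```python
import numpy as np


def can_collator(image, label, BatchInfo):
    """per_batch_map for a two-column (image, label) dataset.

    image: list of arrays shaped [channel, H, W] (sizes may differ)
    label: list of 1-D integer arrays (lengths may differ)
    Returns four lists (images, image_masks, labels, label_masks) with one
    array per kept sample, zero-padded to the largest size in the batch.
    """
    max_width, max_height, max_length = 0, 0, 0
    channel = image[0].shape[0]
    proper_items = []
    for img, lab in zip(image, label):
        # same filter as the original: skip samples that would push the
        # padded area past 1600 * 320
        if (
            img.shape[1] * max_width > 1600 * 320
            or img.shape[2] * max_height > 1600 * 320
        ):
            continue
        max_height = max(max_height, img.shape[1])
        max_width = max(max_width, img.shape[2])
        max_length = max(max_length, len(lab))
        proper_items.append((img, lab))

    images, image_masks, labels, label_masks = [], [], [], []
    for img, lab in proper_items:
        _, h, w = img.shape
        padded = np.zeros((channel, max_height, max_width), dtype=np.float32)
        padded[:, :h, :w] = img
        mask = np.zeros((1, max_height, max_width), dtype=np.float32)
        mask[:, :h, :w] = 1
        l = len(lab)
        lab_padded = np.zeros((max_length,), dtype=np.int64)
        lab_padded[:l] = lab
        lab_mask = np.zeros((max_length,), dtype=np.int64)
        lab_mask[:l] = 1
        images.append(padded)
        image_masks.append(mask)
        labels.append(lab_padded)
        label_masks.append(lab_mask)
    return images, image_masks, labels, label_masks
```

Because the function returns four lists from two input columns, the batch call still needs the four names in output_columns. Passing input_columns=["image", "label"] explicitly (it is commented out in the original script) also makes the parameter order unambiguous.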