Error when MindSpore calls the per_batch_map function in Dataset.batch()

1 System Environment

Hardware environment (Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore version: mindspore=2.2.10
Execution mode (PyNative/Graph): any
Python version: Python=3.8
OS platform: Linux

2 Error Information

2.1 Problem Description

When per_batch_map is set and batch_size=1, an error is raised. can_collator is defined as follows:

import mindspore as ms
from mindspore import ops


def can_collator(batch, BatchInfo):
    """
    batch: [
        image [batch_size, channel, maxHinbatch, maxwinbatch]
        image_mask [batch_size, channel, maxHinbatch, maxwinbatch]
        label [batch_size, maxLabelLen]
        label_mask [batch_size, maxLabelLen]
        ...
    ]
    """
    max_width, max_height, max_length = 0, 0, 0
    bs, channel = len(batch), batch[0][0].shape[0]
    proper_items = []
    for item in batch:
        # Skip samples whose size, combined with the current maximum, is too large
        if (
            item[0].shape[1] * max_width > 1600 * 320
            or item[0].shape[2] * max_height > 1600 * 320
        ):
            continue
        max_height = (
            item[0].shape[1] if item[0].shape[1] > max_height else max_height
        )
        max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
        max_length = len(item[1]) if len(item[1]) > max_length else max_length
        proper_items.append(item)
    images = ops.zeros((len(proper_items), channel, max_height, max_width), dtype=ms.float32)
    image_masks = ops.zeros((len(proper_items), 1, max_height, max_width), dtype=ms.float32)
    labels = ops.zeros((len(proper_items), max_length), dtype=ms.int64)
    label_masks = ops.zeros((len(proper_items), max_length), dtype=ms.int64)

    # Copy each sample into the top-left corner of its zero-padded slot
    for i in range(len(proper_items)):
        _, h, w = proper_items[i][0].shape
        images[i][:, :h, :w] = proper_items[i][0]
        image_masks[i][:, :h, :w] = 1
        l = len(proper_items[i][1])
        labels[i][:l] = proper_items[i][1]
        label_masks[i][:l] = 1
    return images, image_masks, labels, label_masks

2.2 Script

dataloader = ds.batch(
    batch_size,
    drop_remainder=drop_remainder,
    num_parallel_workers=min(
        num_workers, 2
    ), # set small workers for lite computation. ToDo: increase for batch-wise mapping
    # input_columns=["image","label"],
    output_columns=["images", "image_masks", "labels", "label_masks"],
    per_batch_map=can_collator,  # uncomment to use inner-batch transformation
)

return dataloader

2.3 Error Message

size=8
col_name=['images', 'image_masks', 'labels', 'label_masks']
Traceback (most recent call last):
    File "/home/ma-user/work/prtest/mindocr/tests/ut/test_can_dataset.py", line 111, in <module>
        for data in data_loader.create_dict_iterator():
    File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/engine/iterators.py", line 145, in __next__
        data = self._get_next()
    File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/engine/iterators.py", line 270, in _get_next
        raise err
    File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/engine/iterators.py", line 253, in _get_next
        return {k: self._transform_md_to_output(t) for k, t in self._iterator.GetNextAsMap().items()}
RuntimeError: Exception thrown from user defined Python function in dataset.
Python Call stack:

TypeError: Traceback (most recent call last):
    File "/home/ma-user/anaconda3/envs/Mindspore/lib/python3.7/site-packages/mindspore/dataset/transforms/py_transforms_util.py", line 198, in __call__
        result = self.transform(*args)
TypeError: can_collator() takes 2 positional arguments but 3 were given
Dataset Pipeline Error Message:
[ERROR] Execute user Python code failed, check 'Python Call stack' above.

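3 Root Cause Analysis

The error TypeError: can_collator() takes 2 positional arguments but 3 were given shows that the function received one more argument than it declares.
The input dataset has two columns, image and label, so Dataset.batch() calls can_collator with three arguments: the batch of image-column data, the batch of label-column data, and BatchInfo.
It does not pass the whole batch as a single argument, which is what the function above assumed.
For details, see the API description: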

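per_batch_map (Callable[[List[numpy.ndarray], …, List[numpy.ndarray], BatchInfo], (List[numpy.ndarray], …, List[numpy.ndarray])], optional) - A callable that takes (list[numpy.ndarray], …, list[numpy.ndarray], BatchInfo) as input parameters and returns (list[numpy.ndarray], list[numpy.ndarray], …) as the new data columns. Each list[numpy.ndarray] in the input parameters represents a batch of numpy.ndarray from one data column, and the number of list[numpy.ndarray] arguments should match the number of column names passed in input_columns. The returned (list[numpy.ndarray], list[numpy.ndarray], …) should contain the same number of lists as the input; if the number of output columns differs from the number of input columns, output_columns must be specified. The last input parameter of the callable is always BatchInfo, which provides information about the dataset; see example (2) in the API documentation for usage.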
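To make this calling convention concrete, the sketch below mimics in plain Python what the API description says happens: the rows of a batch are split into one list per input column, and a BatchInfo object is appended as the last argument. It is a simplified simulation, not MindSpore's actual implementation; simulate_per_batch_map, FakeBatchInfo, wrong_collator and right_collator are hypothetical names used only for illustration.

```python
from typing import Any, Callable, List, Sequence


class FakeBatchInfo:
    """Hypothetical stand-in for MindSpore's BatchInfo (illustration only)."""

    def __init__(self, batch_num: int) -> None:
        self._batch_num = batch_num

    def get_batch_num(self) -> int:
        return self._batch_num


def simulate_per_batch_map(
    rows: Sequence[Sequence[Any]],
    per_batch_map: Callable[..., Any],
    batch_num: int = 0,
) -> Any:
    """Call per_batch_map with one list per column, then a BatchInfo-like object."""
    num_columns = len(rows[0])
    # Transpose rows into columns: columns[c] collects column c of every row.
    columns: List[List[Any]] = [[row[c] for row in rows] for c in range(num_columns)]
    # Each column becomes its own positional argument; BatchInfo always comes last.
    return per_batch_map(*columns, FakeBatchInfo(batch_num))


def wrong_collator(batch, batch_info):
    """Mirrors the original signature: expects the whole batch as one argument."""
    return (batch,)


def right_collator(images, labels, batch_info):
    """One parameter per input column, plus BatchInfo."""
    return images, labels


rows = [("image0", "label0"), ("image1", "label1")]  # two columns: image, label

images, labels = simulate_per_batch_map(rows, right_collator)
print(images)  # ['image0', 'image1']
print(labels)  # ['label0', 'label1']

try:
    simulate_per_batch_map(rows, wrong_collator)
except TypeError as err:
    print(err)  # wrong_collator() takes 2 positional arguments but 3 were given
```

With two input columns, the one-argument-per-column call hands wrong_collator three positional arguments, which reproduces the same kind of TypeError as the one in section 2.3.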


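4 Solution

Change the parameter list of can_collator as follows:
def can_collator(image, label, BatchInfo):
The function body must also be updated to work with these per-column parameters.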
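A possible rewrite of the whole function is sketched below, assuming the two input columns are image (one [channel, H, W] array per row) and label (one 1-D array of token ids per row). Following the API description, it takes one list per column plus BatchInfo and returns four lists, each holding one padded numpy.ndarray per kept row, to match the four names in output_columns. Unlike the original, it pads with NumPy instead of mindspore.ops.zeros, because the API describes inputs and outputs as numpy.ndarray; that substitution is a choice made for this sketch. The size filter is carried over unchanged from the original code.

```python
import numpy as np


def can_collator(images, labels, batch_info):
    """Pad a batch of variable-size images and label sequences.

    images: list of np.ndarray, each [channel, H, W] (one entry per row)
    labels: list of 1-D np.ndarray of token ids (one entry per row)
    batch_info: BatchInfo supplied by Dataset.batch(); unused here.
    Returns four lists matching output_columns
    ["images", "image_masks", "labels", "label_masks"].
    """
    max_width, max_height, max_length = 0, 0, 0
    proper_items = []
    for image, label in zip(images, labels):
        # Size filter carried over from the original collator.
        if (
            image.shape[1] * max_width > 1600 * 320
            or image.shape[2] * max_height > 1600 * 320
        ):
            continue
        max_height = max(max_height, image.shape[1])
        max_width = max(max_width, image.shape[2])
        max_length = max(max_length, len(label))
        proper_items.append((image, label))

    channel = images[0].shape[0]
    out_images, out_image_masks, out_labels, out_label_masks = [], [], [], []
    for image, label in proper_items:
        _, h, w = image.shape
        # Zero-pad the image to the largest H and W kept in this batch.
        padded_image = np.zeros((channel, max_height, max_width), dtype=np.float32)
        padded_image[:, :h, :w] = image
        image_mask = np.zeros((1, max_height, max_width), dtype=np.float32)
        image_mask[:, :h, :w] = 1
        # Zero-pad the label to the longest label kept in this batch.
        n = len(label)
        padded_label = np.zeros((max_length,), dtype=np.int64)
        padded_label[:n] = label
        label_mask = np.zeros((max_length,), dtype=np.int64)
        label_mask[:n] = 1
        out_images.append(padded_image)
        out_image_masks.append(image_mask)
        out_labels.append(padded_label)
        out_label_masks.append(label_mask)
    return out_images, out_image_masks, out_labels, out_label_masks


if __name__ == "__main__":
    demo_images = [np.ones((1, 2, 3), np.float32), np.ones((1, 4, 2), np.float32)]
    demo_labels = [np.array([1, 2]), np.array([3, 4, 5])]
    outputs = can_collator(demo_images, demo_labels, None)
    print([col[0].shape for col in outputs])  # [(1, 4, 3), (1, 4, 3), (3,), (3,)]
```

In the pipeline it can be passed as in section 2.2. Adding input_columns=["image", "label"] (the line commented out there) makes the argument order explicit, and output_columns is still required because the function returns four columns from two input columns.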