Forward pass fails after using dataset.create_dict_iterator(): RuntimeError: Illegal AnfNode for evaluating, node: @Batch

1 System Environment

Hardware environment (Ascend/GPU/CPU): CPU
MindSpore version: mindspore=1.10
Execution mode (PyNative/Graph): any
Python version: Python=3.8
Operating system: Linux

2 Error Information

2.1 Problem Description

When model(x) is placed outside the dataset.create_dict_iterator() loop, the program runs correctly. The code is as follows:

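import numpy as np
import mindspore

def test(dataset, model, mask_ratio=0.75):
    batch_loss = 0
    model.set_train(False)
    i = 0
    # Random input with the same shape as the model input; it is NOT read from the dataset.
    x = mindspore.Tensor(np.random.rand(1, 40, 300).astype(np.float32), dtype=mindspore.float32)
    # Forward pass outside the dataset loop: this version runs without errors.
    loss, pred, mask = model(x, mask_ratio)
    for data in dataset.create_dict_iterator():
        # input = data['data']
        # x = mindspore.Tensor(np.random.rand(1, 40, 300).astype(np.float32), dtype=mindspore.float32)
        # loss, pred, mask = model(x, mask_ratio)
        batch_loss += loss
        i += 1
    # Average over the i iterated batches (max() guards against an empty dataset).
    return batch_loss / max(i, 1)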

However, as soon as model(x) is moved inside the dataset.create_dict_iterator() loop, the program fails, even though x is not taken from the dataset:

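import numpy as np
import mindspore

def test(dataset, model, mask_ratio=0.75):
    batch_loss = 0
    model.set_train(False)
    i = 0
    x = mindspore.Tensor(np.random.rand(1, 40, 300).astype(np.float32), dtype=mindspore.float32)
    # loss, pred, mask = model(x, mask_ratio)
    for data in dataset.create_dict_iterator():
        # input = data['data']
        # x is still random data, not a dataset sample.
        x = mindspore.Tensor(np.random.rand(1, 40, 300).astype(np.float32), dtype=mindspore.float32)
        # Forward pass inside the dataset loop: this raises the RuntimeError shown below.
        loss, pred, mask = model(x, mask_ratio)
        batch_loss += loss
        i += 1
    # Average over the i iterated batches (max() guards against an empty dataset).
    return batch_loss / max(i, 1)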

2.2 Error Message

RuntimeError: Illegal AnfNode for evaluating.node: @BatchNorm.20:фdout(type:Parameter), fg: BatchNorm.20 conf: Node: 0x18b20bbdb00/@BatchNorm.20:фdout-uid(2115), Context: 0x18b20b4fc70/{FuncGraph: BatchNorm.20 Args: [0]: AbstractTensor(shape: (1, 32, 40, 99), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), [1]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), [2]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), [3]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), [4]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), [5]: AbstractTuple{element[0]: AbstractTensor(shape: (1, 32, 40, 99), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[1]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[2]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[3]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[4]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue)}, [6]: AbstractTuple{element[0]: AbstractTensor(shape: (1, 32, 40, 99), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[1]: 
AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[2]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[3]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue), element[4]: AbstractTensor(shape: (32), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x18b1728ef90, value: AnyValue)}, Parent: { Args: }}, FuncGraph: 0x18b20bb0cb0/BatchNorm.20

3 Root Cause Analysis

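The error is raised while MindSpore evaluates the BatchNorm subgraph of the network, and it only appears when the forward pass runs inside the dataset iteration loop. Judging from the error message, the cause is most likely a data type mismatch: the data handled by the dataset pipeline is not of the type the computation expects, because the custom dataset class creates mindspore.Tensor objects instead of NumPy arrays.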
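To catch this kind of type mismatch early, one can inspect what the dataset source actually returns before wrapping it in a MindSpore dataset. The sketch below is a hypothetical helper (check_numpy_only is not a MindSpore API); it assumes a random-access source that implements __getitem__ and __len__, and reports every returned field that is not a NumPy array, for example a mindspore.Tensor.

```python
import numpy as np


def check_numpy_only(source, num_checks=3):
    """Hypothetical helper: list fields of a dataset source that are not NumPy arrays.

    Assumes `source` supports len() and integer indexing, and that each sample
    is either a single value or a tuple of column values.
    Returns a list of (sample_index, column_index, type_name) tuples.
    """
    problems = []
    for index in range(min(num_checks, len(source))):
        sample = source[index]
        # Treat a single returned value as a one-column sample.
        if not isinstance(sample, tuple):
            sample = (sample,)
        for column, value in enumerate(sample):
            if not isinstance(value, np.ndarray):
                problems.append((index, column, type(value).__name__))
    return problems
```

An empty list means every checked field is already a NumPy array; for a source like the one described above, the entries would carry the type name of the offending objects (for mindspore.Tensor, "Tensor").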

4 Solution

The custom dataset class created MindSpore Tensors when the data was constructed. Replacing them with NumPy arrays solves the problem.
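A minimal sketch of such a dataset class, assuming a random-access source that is later wrapped with mindspore.dataset.GeneratorDataset. The class name NumpyFeatureSource, the per-sample shape (1, 40, 300) (borrowed from the input used in the code above) and the random data are illustrative assumptions; the point is only that samples are built and returned as NumPy float32 arrays, never as mindspore.Tensor objects.

```python
import numpy as np


class NumpyFeatureSource:
    """Hypothetical random-access source for mindspore.dataset.GeneratorDataset.

    All samples are stored and returned as NumPy arrays (not mindspore.Tensor),
    leaving the conversion to Tensors to the MindSpore dataset pipeline.
    """

    def __init__(self, num_samples=8, sample_shape=(1, 40, 300), seed=0):
        rng = np.random.default_rng(seed)
        # Placeholder data built as a NumPy float32 array, not as mindspore.Tensor.
        self._data = rng.random((num_samples, *sample_shape)).astype(np.float32)

    def __getitem__(self, index):
        # Return one sample as a tuple of NumPy arrays (one array per column).
        return (self._data[index],)

    def __len__(self):
        return len(self._data)
```

The source could then be wrapped as mindspore.dataset.GeneratorDataset(NumpyFeatureSource(), column_names=["data"], shuffle=False), so that data['data'] in the loop above holds the sample; by default, create_dict_iterator() yields the columns as MindSpore Tensors, so the model still receives Tensor inputs.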