MindSpore error: RuntimeError: The 'getitem' operation does not support the type [Func, Int64].

1 System Environment

Hardware environment (Ascend/GPU/CPU): CPU
MindSpore version: 1.7.0
Execution mode (PyNative/Graph): Graph
Python version: 3.9.15
Operating system: any

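2 Error Information

2.1 Problem Description

In graph (static) mode, accessing an element of a SequentialCell by index inside the construct function fails: the "SequentialCell[-1]" syntax is not supported.

2.2 Error Message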

RuntimeError: The 'getitem' operation does not support the type [Func, Int64].

2.3 Script

import numpy as np
from mindspore import Tensor, context, nn

# Run in graph (static) mode, where the error occurs.
context.set_context(mode=context.GRAPH_MODE)


class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        # Indexing a SequentialCell here raises the RuntimeError in graph mode.
        logits = self.dense_relu_sequential[-1](x)
        return logits


# The input flattens to 16 * 32 = 512 features, matching the last Dense layer.
X = Tensor(np.ones((1, 16, 32), np.float32))
model = Network()
logits = model(X)

3 Root Cause Analysis

Graph mode does not support the getitem operation on a SequentialCell. If you need to access cells by index, use nn.CellList instead.

4 Solution

Replace the SequentialCell with a CellList.
A CellList can be used like a regular Python list, and the Cells it contains are already initialized.

import numpy as np
from mindspore import Tensor, context, nn

context.set_context(mode=context.GRAPH_MODE)


class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # CellList takes a Python list of cells and supports indexing in graph mode.
        self.dense_relu_sequential = nn.CellList(
            [nn.Dense(28*28, 512),
             nn.ReLU(),
             nn.Dense(512, 512),
             nn.ReLU(),
             nn.Dense(512, 10)]
        )

    def construct(self, x):
        x = self.flatten(x)
        # Index access now works: only the last Dense layer is applied.
        logits = self.dense_relu_sequential[-1](x)
        return logits


# The input flattens to 16 * 32 = 512 features, matching the last Dense layer.
X = Tensor(np.ones((1, 16, 32), np.float32))
model = Network()
logits = model(X)

The script runs without errors.