MindSpore报错TypeError: For 'CellList', each cell should be subclass of Cell, but got NoneType.

1 报错描述

1.1 系统环境

Hardware Environment(Ascend/GPU/CPU): Ascend Software
Environment: -- MindSpore version (source or binary): 1.8.0
Python version (e.g., Python 3.7.5): 3.7.6
OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic -- GCC/Compiler version (if compiled from source):

1.2 基本信息

1.2.1 脚本

训练脚本是通过构建CellList的单算子网络,实现cell列表容器。脚本如下:

from mindspore import nn, Tensor
import numpy as np


class ListNoneExample(nn.Cell):
    def __init__(self):
        super(ListNoneExample, self).__init__()
        self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])

    def construct(self, x):
        output = []
        for op in self.lst:
            output.append(op(x))
        return output

input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
example = ListNoneExample()
output = example(input)
print(output)

1.2.2 报错

这里报错信息如下:

Traceback (most recent call last):
  File "C:/Users/user1/PycharmProjects/q2_map/new/I3OGVW.py", line 31, in <module>
    example = ListNoneExample()
  File "C:/Users/user1/PycharmProjects/q2_map/new/I3OGVW.py", line 19, in __init__
    self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
  File "C:\Users\user1\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 310, in __init__
    self.extend(args[0])
  File "C:\Users\user1\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 405, in extend
    if _valid_cell(cell, cls_name):
  File "C:\Users\user1\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 39, in _valid_cell
    raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')
TypeError: For 'CellList', each cell should be subclass of Cell, but got NoneType.

2 原因分析

我们看报错信息,在TypeError中,写到For ‘CellList’, each cell should be subclass of Cell, but got NoneType. ,意思是对于CellList这个算子, 传入的每一个cell都因该是nn.Cell的子类, 但是得到了None类型。检查网络中初始化CellList的行为第4行, 发现传入了一个None, 因此报错。为了解决这个问题, 只需把这里的None换成一个继承于基类Cell类的对象, 就能实现相同的功能。

3 解决方法

基于上面已知的原因,很容易做出如下修改:

from mindspore import nn, Tensor
import numpy as np


class NoneCell(nn.cell):
    def __init__(self):
        super(NoneCell, self).__init__()

    def construct(self, x):
        return x


class ListNoneExample(nn.Cell):
    def __init__(self):
        super(ListNoneExample, self).__init__()
        self.lst = nn.CellList([nn.ReLU(), NoneCell(), nn.ReLU()])

    def construct(self, x):
        output = []
        for op in self.lst:
            output.append(op(x))
        return output

input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
example = ListNoneExample()
output = example(input)
print(output)

此时执行成功,输出如下:

Output: (Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
 [0.00000000e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
 [-2.74355006e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
 [0.00000000e+000]])) 

4 总结

定位报错问题的步骤:
1、找到报错的用户代码行: self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()]) ;
2、 根据日志报错信息中的关键字,缩小分析问题的范围 each cell should be subclass of Cell, but got NoneType ;
3、需要重点关注变量定义、初始化的正确性。

5 参考文档

CellList算子API接口