1 Error Description
1.1 System Environment
Hardware Environment (Ascend/GPU/CPU): Ascend
Software Environment:
-- MindSpore version (source or binary): 1.8.0
-- Python version (e.g., Python 3.7.5): 3.7.6
-- OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic
-- GCC/Compiler version (if compiled from source):
1.2 Basic Information
1.2.1 Script
The training script builds a single-operator network around CellList, which serves as a list container for cells. The script is as follows:
from mindspore import nn, Tensor
import numpy as np


class ListNoneExample(nn.Cell):
    def __init__(self):
        super(ListNoneExample, self).__init__()
        self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])

    def construct(self, x):
        output = []
        for op in self.lst:
            output.append(op(x))
        return output


input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
example = ListNoneExample()
output = example(input)
print(output)
1.2.2 Error
The error message is as follows:
Traceback (most recent call last):
  File "C:/Users/user1/PycharmProjects/q2_map/new/I3OGVW.py", line 31, in <module>
    example = ListNoneExample()
  File "C:/Users/user1/PycharmProjects/q2_map/new/I3OGVW.py", line 19, in __init__
    self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
  File "C:\Users\user1\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 310, in __init__
    self.extend(args[0])
  File "C:\Users\user1\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 405, in extend
    if _valid_cell(cell, cls_name):
  File "C:\Users\user1\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 39, in _valid_cell
    raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')
TypeError: For 'CellList', each cell should be subclass of Cell, but got NoneType.
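2 Cause Analysis
The TypeError in the error message reads: For 'CellList', each cell should be subclass of Cell, but got NoneType. In other words, every element passed to the CellList container must be an instance of a subclass of nn.Cell, but a None was received instead. Checking the line that initializes the CellList (the fourth line of the class definition) shows that the list passed in contains a None, which triggers the error. To fix this, replace the None with an object of a class that inherits from the base class Cell, for example a cell that returns its input unchanged, so the network keeps the same behavior.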
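The check behind this error can be approximated in plain Python. The sketch below is not MindSpore's source code: `Cell`, `ReLU`, and `check_cells` are hypothetical stand-ins so that it runs without MindSpore installed, and only the error message text is taken from the traceback above.

```python
# Stand-in classes (assumption): they only mimic the inheritance relationship
# between mindspore.nn.Cell and a concrete layer such as nn.ReLU.
class Cell:
    """Hypothetical stand-in for mindspore.nn.Cell."""


class ReLU(Cell):
    """Hypothetical stand-in for mindspore.nn.ReLU."""


def check_cells(cells, container_name="CellList"):
    """Raise TypeError on the first entry that is not a Cell instance."""
    for cell in cells:
        if not isinstance(cell, Cell):
            raise TypeError(
                f"For '{container_name}', each cell should be subclass of Cell, "
                f"but got {type(cell).__name__}."
            )
    return True


# The list from the failing script: the None in the middle is rejected.
try:
    check_cells([ReLU(), None, ReLU()])
except TypeError as err:
    message = str(err)
    print(message)
```

Because None is an instance of NoneType rather than of any Cell subclass, the second entry fails the isinstance test, which matches the type name reported in the log.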
3 Solution
Based on the cause identified above, the script can be modified as follows:
from mindspore import nn, Tensor
import numpy as np


class NoneCell(nn.Cell):
    """Placeholder cell that returns its input unchanged."""

    def __init__(self):
        super(NoneCell, self).__init__()

    def construct(self, x):
        return x


class ListNoneExample(nn.Cell):
    def __init__(self):
        super(ListNoneExample, self).__init__()
        # NoneCell() replaces the None that CellList rejected.
        self.lst = nn.CellList([nn.ReLU(), NoneCell(), nn.ReLU()])

    def construct(self, x):
        output = []
        for op in self.lst:
            output.append(op(x))
        return output


input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
example = ListNoneExample()
output = example(input)
print(output)
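The script now runs successfully. A sample output is shown below (the input is random, so the exact values differ between runs):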
Output: (Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[-2.74355006e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]))
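4 Summary
Steps for locating this kind of error:
1. Find the line of user code that raised the error: self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()]);
2. Use the keywords in the error log to narrow down the problem: each cell should be subclass of Cell, but got NoneType;
3. Pay particular attention to whether variables are defined and initialized correctly.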
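The third step can be partly automated by inspecting the list before it is handed to the container. The helper below is a minimal sketch: `find_invalid_entries` is a hypothetical name, not a MindSpore API, and `Cell` and `ReLU` are stand-in classes so that the example runs without MindSpore installed.

```python
def find_invalid_entries(entries, base_cls):
    """Return (index, type name) for every entry that is not a base_cls instance."""
    return [
        (index, type(entry).__name__)
        for index, entry in enumerate(entries)
        if not isinstance(entry, base_cls)
    ]


# Stand-in classes (assumption) used only to demonstrate the helper.
class Cell:
    """Hypothetical stand-in for mindspore.nn.Cell."""


class ReLU(Cell):
    """Hypothetical stand-in for mindspore.nn.ReLU."""


candidates = [ReLU(), None, ReLU()]
invalid = find_invalid_entries(candidates, Cell)
print(invalid)
```

In a real script, nn.Cell would be passed as base_cls. Unlike the exception, which stops at the first bad entry, the helper reports the position of every invalid entry in one pass.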