1 报错描述
1.1 系统环境
Environment(Ascend/GPU/CPU): CPU
Software Environment:
– MindSpore version (source or binary): 1.8.0
– Python version (e.g., Python 3.7.5): 3.7.5
– OS platform and distribution (e.g., Linux Ubuntu 16.04): win10
– CUDA version : NA
1.2 基本信息
1.2.1脚本
如下所示,定义一个基于基类nn.cell的网络, 实现一个简单的功能:展平输入的Tensor数据。
from mindspore import nn
from mindspore import set_context
import mindspore as ms
set_context(mode=ms.PYNATIVE_MODE)
class Net(nn.cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
def construct(self, x):
return self.flatten(x)
def test_compile_cell():
net = Net()
print("network:")
print(net)
if __name__ == "__main__":
test_compile_cell()
1.2.2 报错信息
网络发生如下报错。其实,注释掉def test_compile_cell方法及其调用, 执行脚本依然报错。意味着Net(nn.cell) 的定义就已经发生了错误。
Traceback (most recent call last):
File "test_compiler_cell.py", line 9, in <module>
class Net(nn.cell):
TypeError: module() takes at most 2 arguments (3 given)
2 原因分析及解决方法
此处想要导入类,如上代码所示只是导入了模块:mindspore/python/mindspore/nn/cell.py,Python的模块名与类名是在两个不同的名字空间中,初学者很容易将其弄混淆。
-
python 类 用来描述具有相同的属性和方法的对象的集合。它定义了该集合中每个对象所共有的属性和方法。对象是类的实例
-
python 模块 模块,在Python可理解为对应于一个文件。
根据上面代码,想使用nn.cell作为基类定义网络,但其实际上是一个模块,所以错误。 真正的基类应该是nn.Cell,即mindspore/python/mindspore/nn/cell.py文件中定义的python类对象,它是MindSpore中神经网络的基本构成单元。
修改后的网络定义:
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
def construct(self, x):
return self.flatten(x)
参考 【1】 mindspore.nn.Cell。 【2】 python 模块与类。