MindSpore报错: module() takes at most 2 arguments (3 given)

1 报错描述

1.1 系统环境

Environment(Ascend/GPU/CPU): CPU
Software Environment:
– MindSpore version (source or binary): 1.8.0
– Python version (e.g., Python 3.7.5): 3.7.5
– OS platform and distribution (e.g., Linux Ubuntu 16.04): win10 
– CUDA version : NA

1.2 基本信息

1.2.1脚本

如下所示,定义一个基于基类nn.cell的网络, 实现一个简单的功能:展平输入的Tensor数据。

from mindspore import nn
from mindspore import set_context
import mindspore as ms
    
set_context(mode=ms.PYNATIVE_MODE)
class Net(nn.cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
    
    def construct(self, x):
        return self.flatten(x)
    
def test_compile_cell():
    net = Net()
    print("network:")
    print(net)
    
if __name__ == "__main__":
    test_compile_cell()

1.2.2 报错信息

网络发生如下报错。其实,注释掉def test_compile_cell方法及其调用, 执行脚本依然报错。意味着Net(nn.cell) 的定义就已经发生了错误。

Traceback (most recent call last):
  File "test_compiler_cell.py", line 9, in <module>
    class Net(nn.cell):
TypeError: module() takes at most 2 arguments (3 given)

2 原因分析及解决方法

此处想要导入类,如上代码所示只是导入了模块:mindspore/python/mindspore/nn/cell.py,Python的模块名与类名是在两个不同的名字空间中,初学者很容易将其弄混淆。

  • python 类 用来描述具有相同的属性和方法的对象的集合。它定义了该集合中每个对象所共有的属性和方法。对象是类的实例

  • python 模块 模块,在Python可理解为对应于一个文件。

根据上面代码,想使用nn.cell作为基类定义网络,但其实际上是一个模块,所以错误。 真正的基类应该是nn.Cell,即mindspore/python/mindspore/nn/cell.py文件中定义的python类对象,它是MindSpore中神经网络的基本构成单元。

修改后的网络定义:

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
    
    def construct(self, x):
        return self.flatten(x)

参考 【1】 mindspore.nn.Cell。 【2】 python 模块与类。