MindSpore训练时报错:TypeError: For ‘MatMul’, the input data must be float16, float32, uint16 but got int32

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: MindSpore=2.6.0
执行模式(PyNative/ Graph): 不限
Python版本: Python=3.8
操作系统平台: linux

2 报错信息

2.1 问题描述

在用MindSpore训练一个简单的全连接网络时遇到如下报错:

TypeError: For 'MatMul', the input data must be float16, float32, uint16 but got int32

  1. 为什么会出现这个类型错误?

  2. 应该如何正确设置Tensor的数据类型?

2.2 脚本信息

import mindspore.nn as nn
from mindspore import Tensor

# 网络定义
class SimpleNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(10, 5) # 输入10维,输出5维
    def construct(self, x):
        return self.fc(x)

# 数据准备
x = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # 输入数据
net = SimpleNet()
output = net(x) # 这里报错

3 根因分析

查看api mindspore.nn.Dense | MindSpore 2.7.0 文档 | 昇思MindSpore社区

classmindspore.nn.Dense(in_channels, out_channels, weight_init=None, bias_init=None, has_bias=True, activation=None, dtype=mstype.float32)

  • weight_init (Union[Tensor, str, Initializer, numbers.Number],可选) - 权重参数的初始化方法。数据类型与 x 相同。str的值引用自函数 mindspore.common.initializer.initializer()。默认值: None ,权重使用HeUniform初始化。

  • has_bias (bool,可选) - 是否使用偏置向量 bias 。默认值: True

  • dtype (mindspore.dtype,可选) - Parameter的数据类型。默认值: mstype.float32 。 当 weight_init 是Tensor时,Parameter的数据类型与 weight_init 的数据类型一致,其他情况Parameter的数据类型跟 dtype 一致, bias_init 同理。

可以看出该算子有默认的权重和偏置, 另外dtype默认float32.

文档说明如下:

适用于输入的密集连接层。公式如下:

outputs=activation(X∗kernel+bias),

其中 X 是输入Tensor, activation 是激活函数, kernel 是一个权重矩阵,其数据类型与 X 相同, bias 是一个偏置向量,其数据类型与 X 相同(仅当 has_bias 为 True 时)。

代码里面x = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),不设置dtype的情况下根据输入,类型会被设置为int64.因此会报错.

4 解决方案

x设置类型为float32. x = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],dtype=ms.float32)

代码如下:

import mindspore.nn as nn
from mindspore import Tensor
import mindspore as ms
# 网络定义
class SimpleNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(10, 5) # 输入10维,输出5维
    def construct(self, x):
        return self.fc(x)

# 数据准备
x = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],dtype=ms.float32) # 输入数据
net = SimpleNet()
output = net(x) # 这里报错