使用MindSpore的ops中的矩阵相乘算子进行int8的相乘运算时报错

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.2.10
执行模式(动态图):GRAPH/PYNATIVE
Python版本: Python=3.7.5
操作系统平台: linux

2 报错信息

使用mindspore2.2的ops中的矩阵相乘算子进行int8的相乘运算时,出现以下报错

TypeError: For primitive[MatMul], the input argument must be a type of {Tensor[BFloat16], Tensor[Float16], Tensor[Float32], Tensor[Int32], Tensor[Int64], Tensor[UInt8]}, but got Tensor[Int8].

2.1 问题描述

ascend上matmul接口不支持int8 , 使用int8计算时报错

2.2 脚本代码

import mindspore as ms 
import mindspore.ops as ops 
import numpy as np 

x = ms.Tensor(np.ones([1,3]).astype(np.int8)) 
y = ms.Tensor(np.ones([3,1]).astype(np.int8)) 
op = ops.matmul 
z = op(x,y) 
print(z)

3 根因分析

ascend上ops.matmul 不支持int8, 计算时会报错,目前2.2 版本ascend上只支持这些类型{Tensor[BFloat16], Tensor[Float16], Tensor[Float32], Tensor[Int32], Tensor[Int64], Tensor[UInt8]}

4 解决方案

修改输入dtype为支持的数据类型,Tensor[BFloat16], Tensor[Float16], Tensor[Float32], Tensor[Int32], Tensor[Int64], Tensor[UInt8] 这些都可以。

下面是以int32为例:

import mindspore as ms 
import mindspore.ops as ops 
import numpy as np 

x = ms.Tensor(np.ones([1,3]).astype(np.int32)) 
y = ms.Tensor(np.ones([3,1]).astype(np.int32)) 
op = ops.matmul 
z = op(x,y) 
print(z) 
[[3]]