Error when using MindSpore's Embedding

1. System environment

Hardware environment (Ascend/GPU/CPU): GPU/CPU
MindSpore version: mindspore=2.2.0
Execution mode (PyNative/Graph): any
Python version: Python=3.8
Operating system: any

2. Error report

2.1 Problem description

While porting PyTorch code to MindSpore, the following code came up at the embedding layer:

  import mindspore
  from mindspore import Tensor, nn
  import numpy as np
  net = nn.Embedding(110, 8)
  x = Tensor(np.array([[107] * 50]), mindspore.int64)
  print(x)
  # Maps the input wordIDs to word embedding.
  output = net(x)
  result = output.shape
  print(net)
  print(result)

Here input_data is of type int64, which torch's embedding handles without problems. For MindSpore, the code was changed to:

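# Fragment from inside a model class; assumes `import mindspore as ms` and `from mindspore import nn`.
net = nn.Embedding(vocab_size=self.hidden_size, embedding_size=self.embedding_size, dtype=ms.int64)
input_data = net(input_data)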

The resulting net still turned out to have dtype float32, which led to a type error.

2.2 Error message

Traceback (most recent call last):
  File "main.py", line 206, in <module>
    main()
  File "main.py", line 200, in main
    ppo_ori.feature_search()
  File "/root/Catch-master/ppo_ori.py", line 96, in feature_search
    sample_history = self.controller.sample()
  File "/root/Catch-master/controller_alpha.py", line 446, in sample
    res_h_c_t_list, ops_logits, otp_logits = self.construct(
  File "/root/Catch-master/controller_alpha.py", line 218, in construct
    input_data = self.embedding(input_data)
  File "/root/miniconda3/lib/python3.8/site-packages/mindspore/nn/cell.py", line 705, in __call__
    raise err
  File "/root/miniconda3/lib/python3.8/site-packages/mindspore/nn/cell.py", line 701, in __call__
    output = self._run_construct(args, kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/mindspore/nn/cell.py", line 482, in _run_construct
    output = self.construct(*cast_inputs, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/mindspore/nn/layer/embedding.py", line 137, in constru
    out_shape = self.get_shp(ids) + (self.embedding_size,)
  File "/root/miniconda3/lib/python3.8/site-packages/mindspore/ops/operations/array_ops.py", line 701, in.
    return x.shape
  File "/root/miniconda3/lib/python3.8/site-packages/mindspore/common/_stub_tensor.py", line 85, in shape
    self.stub_shape = self.stub.get_shape()
TypeError: Invalid dtype

3. Root cause analysis

Check the API documentation:

class mindspore.nn.Embedding(vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32, padding_idx=None)

  • vocab_size (int) - Size of the dictionary.
  • embedding_size (int) - Size of each embedding vector.
  • use_one_hot (bool) - Whether to use the one-hot form. Default: False
  • embedding_table (Union[Tensor, str, Initializer, numbers.Number]) - Initialization method for embedding_table. When a string is given, see mindspore.common.initializer for the accepted values. Default: "normal"
  • dtype (mindspore.dtype) - Data type of x. Default: mstype.float32

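Inputs:

  • x (Tensor) - Tensor of shape (batch_size, x_length). Its elements must be integers no larger than vocab_size; otherwise the corresponding embedding vector is zero. The data type can be int32 or int64.

Outputs:

Tensor of shape (batch_size, x_length, embedding_size).

From the documentation, dtype defaults to mstype.float32, while the elements of the input x are integers of type int32 or int64.
This looks contradictory: if the input is an integer type, how can the default still be float32?
The output type is float32.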
4. Solution

Change dtype to int64:

import mindspore
from mindspore import Tensor, nn
import numpy as np
net = nn.Embedding(110, 8, dtype=mindspore.int64)
x = Tensor(np.array([[107] * 50]), mindspore.int64)
print(x)
# Maps the input wordIDs to word embedding.
output = net(x)
result = output.shape
print(net)
print(result)