1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限
2 报错信息
2.1 问题描述
将pytorch转换成mindspore代码的时候,调用了pytorch.distributions的Categorical方法,这三行代码该怎么使用mindspore的方法进行呢。
2.2 脚本信息
Torch实现方式如下:
import torch.nn.functional as F
from torch.distributions import Categorical
def sample_cycle_component(self, logits, output_type, random_ratio=0):
action_index_local = Categorical(logits=logits).sample()
prob_matrix = F.softmax(logits, dim=1)
log_prob_matrix = F.log_softmax(logits, dim=1)
3 根因分析
torch.distributions.Categorical => mindspore.nn.probability.distribution.Categorical
torch.nn.functional.softmax => mindspore.ops.softmax
torch.nn.functional.log_softmax => mindspore.ops.log_softmax
4 解决方案
from mindspore import ops
from mindspore.nn.probability.distribution import Categorical
def sample_cycle_component(logits, output_type, random_ratio=0):
action_index_local = Categorical(probs=input).sample()
prob_matrix = ops.softmax(logits, axis=1)
log_prob_matrix = ops.log_softmax(logits, axis=1)