如何使用MindSpore替换torch.distributions的Categorical函数

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限

2 报错信息

2.1 问题描述

将pytorch转换成mindspore代码的时候,调用了pytorch.distributions的Categorical方法,这三行代码该怎么使用mindspore的方法进行呢。

2.2 脚本信息

Torch实现方式如下:

import torch.nn.functional as F  
from torch.distributions import Categorical  
def sample_cycle_component(self, logits, output_type, random_ratio=0):  
        action_index_local = Categorical(logits=logits).sample()  
        prob_matrix = F.softmax(logits, dim=1)  
        log_prob_matrix = F.log_softmax(logits, dim=1)

3 根因分析

torch.distributions.Categorical => mindspore.nn.probability.distribution.Categorical

torch.nn.functional.softmax => mindspore.ops.softmax

torch.nn.functional.log_softmax => mindspore.ops.log_softmax

4 解决方案

from mindspore import ops  
from mindspore.nn.probability.distribution import Categorical  
def sample_cycle_component(logits, output_type, random_ratio=0):  
    action_index_local = Categorical(probs=input).sample()  
    prob_matrix = ops.softmax(logits, axis=1)  
    log_prob_matrix = ops.log_softmax(logits, axis=1)