如何使用MindSpore实现Torch的logsumexp函数

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限

2 报错信息

2.1 问题描述

如何使用MindSpore实现Torch的logsumexp函数

2.2 脚本信息

Torch实现方式如下:

self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)

3 根因分析

torch.logsumexp => mindspore.nn.logsumexp

4 解决方案

import mindspore as ms  
from mindspore import ops, nn  
import torch  
logits_torch = torch.randn((1, 2))  
logits_ms = ms.Tensor(logits_torch.numpy())  
logits_torch = logits_torch - logits_torch.logsumexp(dim=-1, keepdim=True)  
logits_ms = logits_ms - logits_ms.logsumexp(axis=-1, keepdims=True)  
print(logits_torch)  
print(logits_ms)  
    
tensor([[-1.7354, -0.1940]])  
[[-1.7353749  -0.19399017]]