1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限
2 报错信息
2.1 问题描述
如何使用MindSpore实现Torch的logsumexp函数
2.2 脚本信息
Torch实现方式如下:
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
3 根因分析
torch.logsumexp => mindspore.nn.logsumexp
4 解决方案
import mindspore as ms
from mindspore import ops, nn
import torch
logits_torch = torch.randn((1, 2))
logits_ms = ms.Tensor(logits_torch.numpy())
logits_torch = logits_torch - logits_torch.logsumexp(dim=-1, keepdim=True)
logits_ms = logits_ms - logits_ms.logsumexp(axis=-1, keepdims=True)
print(logits_torch)
print(logits_ms)
tensor([[-1.7354, -0.1940]])
[[-1.7353749 -0.19399017]]