如何使用MindSpore实现梯度对数据求导retain_graph=True

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: 2.2.0
执行模式(PyNative/ Graph): 不限

2 报错信息

##2.1 问题描述
假设输入数据x和标签y,通过模型反向传播mindspore.gard获取到了模型的梯度g,现在需要将梯度g进行一定的运算(假设为f(g)),最后f(g)需要对最初的数据x进行求导,如果使用Torch的话,在求梯度g的时候只要 retain_graph=True,MindSpore该如何实现呢?

3 解决方案

在第一次求梯度g的时候,不调用MindSpore的 NLLLoss() 即可,然而MindSpore自带的交叉熵损失也自动调用了 NLLLoss(),因此也需要重写。

def _loss_fn(self, logits, labels):  
    logs = ops.log(ops.softmax(logits))  
    loss = self._nll_loss(logs, labels)  
    return loss  
    
def _nll_loss(self, logits, labels):  
    loss = [-logs[i, int(j)] for i, j in enumerate(labels)]  
    loss = ops.stack(loss).mean()  
    return loss