MindSpore如何实现pytoch中的detach()方法

1 系统环境

硬件环境(Ascend/GPU/CPU): CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.9
操作系统平台: 不限

2 报错信息

2.1 问题描述

将原pytoch代码网络的forward中对某参数使用了detach方法,MindSpore该如何实现detach()方法。

2.2 脚本信息

PyTorch

Prototypes = self.Prototypes.detach()

3 根因分析

torch 中detach的用途

Returns a new Tensor, detached from the current graph.   
The result will never require gradient.     
.. note::   
  Returned Tensor uses the same data tensor as the original one.   
  In-place modifications on either of them will be seen, and may trigger   
  errors in correctness checks. 

返回一个新的tensor,并且这个tensor是从当前的计算图中分离出来的。但是返回的tensor和原来的tensor是共享内存空间的。

4 解决方案

mindspore 没有detach这个概念。

如果非要在mindspore中找到类似的代替,只能对参数先进行克隆,然后将requires_grad设置为False。或者是直接将参数本身requires_grad设置为False。

但是克隆的话Mindspore和Torch还是有差别。克隆的参数和原来的参数不共享内存空间。原有的参数修改后,克隆的参数不会一起修改。

import torch  
from mindspore import Tensor,Parameter  
    
a = torch.tensor([1.0, 2.0, 3.0], requires_grad = True)  
a = a.detach() # 会将requires_grad 属性设置为False  
a=a*2  
print(a.requires_grad)  
print(a)  
b = Parameter(Tensor([1.0,2.0,3.0]))  
c = b.clone()  
b=b*2  
c.requires_grad = False  
print(b.value())  
print(c.value())  
print(c.requires_grad)
False  
tensor([2., 4., 6.])  
[2. 4. 6.]  
[1. 2. 3.]  
False