解决报错:迁移pytorch代码时如何将torch.device映射 usability/api

问题来源:

https://www.hiascend.com/forum/thread-0227105267390861012-1-1.html

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=1.8.1
执行模式(动态图/静态图): GRAPH_MODE
Python版本: Python=3.7.5
操作系统平台: Linux

2 报错信息

2.1 问题描述

迁移pytorch代码

import torch

device = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)
print(type(device))

2.2 报错信息

尝试用ms.set_context(device_target="GPU\CPU\Ascend"),但是运行结果输出为None.

2.3 脚本代码

import torch
import os
import sys

# 有的地方改的不正去,比如这个ms.set_countext
# https://blog.csdn.net/qq_43215538/article/details/126161578
# 有问题,torch里出来的torch type的cpu
# device = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)

device = ms.set_context(device_target=’CPU’)
print(type(device))

3 根因分析

根因:接口使用错误,可用mindspore.get_context完成。
接口说明:mindspore.context | MindSpore 1.7 documentation | MindSpore
原因:set_context应该是没有返回值,用get_context来获取

4 解决方案

import torch  
import os  
import sys  
from mindspore import context  
    
context.set_context(device_target="CPU")  
device = context.get_context("device_target")  
print(device)  
print(type(device))