MindSpore's get_auto_parallel_context("device_num") reports the wrong device count

1 System environment

Hardware environment (Ascend/GPU/CPU): Ascend

MindSpore version: mindspore==2.4.1, mindformers==1.3.0

Execution mode (PyNative/Graph): any

Python version: Python 3.9

Operating system: Linux

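2 Error information

2.1 Problem description

According to the msrun launch tutorial, if device_target is not set, it is automatically assigned to the backend hardware device that matches the installed MindSpore package. However, the commands below report only 1 NPU card, and when MSRUN runs a parallel job only one NPU node produces logs.

2.2 Script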

import mindspore as ms

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
# Read the device_num value stored in the auto-parallel context
device_num = ms.context.get_auto_parallel_context("device_num")
print("Number of NPU cards recognized by MindSpore:", device_num)

2.3 Error output

(MindSpore) [root ringmo] $npu-smi info
+------------------------------------------------------------------------------------------------+
| npu-smi 23.0.rc2.0               Version: 24.1.rc2.1                                           |
+---------------------------+---------------+----------------------------------------------------+
| NPU   Name                | Health        | Power(W)    Temp(C)           Hugepages-Usage(page)|
| Chip                      | Bus-Id        | AICore(%)   Memory-Usage(MB)  HBM-Usage(MB)        |
+===========================+===============+====================================================+
| 3     910                 | OK            | 97.9        45                0    / 0             |
| 0                         | 0000:02:00.0  | 0           0    / 0          3673 / 65536         |
+===========================+===============+====================================================+
| 5     910                 | OK            | 94.5        45                0    / 0             |
| 0                         | 0000:41:00.0  | 0           0    / 0          3340 / 65536         |
+===========================+===============+====================================================+
+---------------------------+---------------+----------------------------------------------------+
| NPU     Chip              | Process id    | Process name             | Process memory(MB)      |
+===========================+===============+====================================================+
| No running process found in NPU 3                                                              |
+===========================+===============+====================================================+
| No running process found in NPU 5                                                              |
+===========================+===============+====================================================+
(MindSpore) [root ringmo] $python
Python 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 04:25:25)
[GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import mindspore as ms
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
>>> device_num = ms.context.get_auto_parallel_context("device_num")
>>> print("Number of NPU cards recognized by MindSpore:", device_num)
Number of NPU cards recognized by MindSpore: 1

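3 Root cause analysis

This way of obtaining the card count appears to be incorrect: get_auto_parallel_context("device_num") only reads an attribute value stored in the context.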
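To illustrate why the call prints 1 even though npu-smi lists two NPUs, here is a toy model in plain Python. It is only a sketch: ToyParallelContext and its methods are hypothetical names invented for this illustration and are not MindSpore's implementation. The default of 1 is taken from the session output above.

```python
# Toy model only: NOT MindSpore's implementation. All names are hypothetical.
class ToyParallelContext:
    """Stores parallel settings as plain attributes; never inspects hardware."""

    def __init__(self):
        # Assumed default, matching the value printed in the session above.
        self._attrs = {"device_num": 1}

    def set(self, **kwargs):
        # Overwrite stored attributes with whatever the caller provides.
        self._attrs.update(kwargs)

    def get(self, key):
        # Return the stored value; nothing is detected or counted here.
        return self._attrs[key]


ctx = ToyParallelContext()
before = ctx.get("device_num")  # stored default, regardless of installed NPUs
ctx.set(device_num=2)           # e.g. the two NPUs listed by npu-smi
after = ctx.get("device_num")   # now returns the value that was stored
print(before, after)
```

In this model the getter can only return what was stored earlier, which matches the behavior described in this section: the value has to be written before it can be read back.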

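4 Solution

First, assign the value with set_auto_parallel_context; otherwise get_auto_parallel_context("device_num") keeps returning its fixed default.

To obtain the actual number of cards through the API, use get_group_size. To make get_auto_parallel_context("device_num") return the correct card count, pass the result of get_group_size to set_auto_parallel_context, as in the following script: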

import mindspore as ms
from mindspore.communication import init, get_group_size

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
init()  # initialize communication before querying the group size
group_size = get_group_size()  # number of devices in the communication group
print("group_size_is:", group_size)

# Store the actual device count in the auto-parallel context
ms.set_auto_parallel_context(device_num=group_size)
device_num = ms.context.get_auto_parallel_context("device_num")
print("Number of NPU cards:", device_num)