Using the shard interface raises a null-pointer error: RuntimeError: The pointer [comm_lib_instance_] is null.

1 System Environment

Hardware environment (Ascend/GPU/CPU): GPU
MindSpore version: mindspore=2.0.0rc1
Execution mode (PyNative/Graph): either
Python version: Python=3.7.5
Operating system: any

2 Error Message

2.1 Problem Description

In the Cell's __init__, the shard method is used as follows:

self.w_q = nn.Dense(dim, n_heads * self.head_dim, has_bias=False)
self.w_q.shard(in_strategy=((parallel_config.data_parallel, 1, 1),), out_strategy=(None,),
               parameter_plan={"self.w_q.weight": (parallel_config.model_parallel, 1)}, device="GPU")

2.2 Error Message

The following error is raised when self.w_q is called in construct:

RuntimeError: The pointer[comm_lib_instance_] is null.

3 Root Cause Analysis

Looking closely at the error, it is raised when shard is used. Compared with the official Cell.shard documentation, the call itself is correct; however, shard only takes effect in distributed mode, and the if __name__ == "__main__": block never calls init() to initialize the communication library.
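
For reference, a minimal sketch of the distributed setup that Cell.shard relies on is shown below; the context values mirror the fixed script in section 4 and are illustrative rather than mandatory:

import mindspore
from mindspore.communication import init

# The communication library must be initialized before any sharded cell runs.
mindspore.set_context(mode=mindspore.PYNATIVE_MODE)
mindspore.set_auto_parallel_context(device_num=2, parallel_mode="auto_parallel",
                                    search_mode="sharding_propagation")
init()  # creates the communication instance that shard() needs; skipping it triggers the null-pointer error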

4 Solution

Add init() and launch the script with mpirun; it then runs successfully. The updated code is shown below:

import mindspore
from mindspore import nn, ops
from mindspore import numpy as mnp
from mindspore.communication import init


class ParallelConfig():
    def __init__(self, data_parallel=1, model_parallel=1):
        self.data_parallel = data_parallel
        self.model_parallel = model_parallel


class Test(nn.Cell):
    def __init__(self, dim=512, n_layers=8, n_heads=8, vocab_size=-1, multiple_of=256, norm_eps=1e-5,
                max_batch_size=32, max_seq_len=2048, parallel_config=None, ):
        super().__init__()
        self.n_local_heads = n_heads // 2
        self.head_dim = dim // n_heads
        self.w_q = nn.Dense(dim, n_heads * self.head_dim, has_bias=False)
        self.w_q.shard(in_strategy=((parallel_config.data_parallel, 1, 1), ), out_strategy=(None, ),
                        parameter_plan={"self.w_q.weight": (parallel_config.model_parallel, 1)},
                        device="GPU")

    def construct(self, _x):
        bsz, seqlen, _ = _x.shape
        x_q = self.w_q(_x)

        return x_q


if __name__ == "__main__":
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE)
    mindspore.set_auto_parallel_context(device_num=2)
    mindspore.set_auto_parallel_context(parallel_mode="auto_parallel")
    mindspore.set_auto_parallel_context(search_mode="sharding_propagation")
    init()
    batch_size = 2
    dim1 = 8
    n_heads = 4
    input = mnp.ones((batch_size, 16, dim1))
    Config = ParallelConfig()
    attn = Test(dim=dim1, n_heads=n_heads, parallel_config=Config)
    output = attn(input)
    print(output)

The launch command is mpirun -n 2 --allow-run-as-root python test.py, which runs successfully and prints the output.
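
To confirm that both processes launched by mpirun joined the same communication group, an optional check (not part of the original script) is to print the rank and group size right after init():

from mindspore.communication import get_rank, get_group_size

# Expect "rank 0 of 2" and "rank 1 of 2" when launched with mpirun -n 2.
print(f"rank {get_rank()} of {get_group_size()}")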