1 系统环境
硬件环境(Ascend/GPU/CPU): GPU
MindSpore版本: mindspore=2.0.0rc1
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7.5
操作系统平台: 不限
2 报错信息
2.1 问题描述
在 `__init__` 中按照如下方法使用 shard 方法:
self.w_q = nn.Dense(dim, n_heads * self.head_dim, has_bias=False)
self.w_q.shard(in_strategy=((parallel_config.data_parallel, 1, 1),), out_strategy=(None,),
parameter_plan={"self.w_q.weight": (parallel_config.model_parallel, 1)}, device="GPU")
2.2 报错信息
在 construct 中调用 self.w_q 时报错:
RuntimeError: The pointer[comm_lib_instance_] is null.
3 根因分析
仔细看报错,是在使用 shard 时报错的。对比官网 cell.shard 的文档,用法本身并没有错误;但 shard 只有在分布式场景下才会使能,而 `if __name__ == "__main__":` 里面却并没有调用 init() 初始化通信。
4 解决方案
加上 init(),并使用 mpirun 启动,即可执行成功。代码更新如下:
import mindspore
from mindspore import nn, ops
from mindspore import numpy as mnp
from mindspore.communication import init
# from model import Test
class ParallelConfig:
    """Holds the parallel-split degrees used when sharding layers.

    Attributes:
        data_parallel: number of data-parallel splits (default 1).
        model_parallel: number of model-parallel splits (default 1).
    """

    def __init__(self, data_parallel=1, model_parallel=1):
        # Store the requested split degrees as-is; no validation here.
        self.data_parallel = data_parallel
        self.model_parallel = model_parallel
import mindspore
from mindspore import nn, ops
from mindspore import numpy as mnp
class Test(nn.Cell):
    """Minimal attention-projection cell demonstrating Cell.shard.

    Applies a single bias-free Dense projection (w_q). The input of the
    projection is sharded along the batch dimension (data_parallel) and
    its weight along the output dimension (model_parallel).

    NOTE(review): Cell.shard requires distributed communication to be
    initialised (mindspore.communication.init) before construct runs;
    otherwise it fails with "The pointer[comm_lib_instance_] is null."

    Args:
        dim: hidden size of the input/output features.
        n_heads: number of attention heads; head_dim = dim // n_heads.
        parallel_config: ParallelConfig with data_parallel and
            model_parallel degrees; defaults to the serial config.
        n_layers, vocab_size, multiple_of, norm_eps, max_batch_size,
        max_seq_len: kept for interface compatibility; unused here.
    """

    def __init__(self, dim=512, n_layers=8, n_heads=8, vocab_size=-1, multiple_of=256, norm_eps=1e-5,
                 max_batch_size=32, max_seq_len=2048, parallel_config=None, ):
        super().__init__()
        # Fall back to the serial configuration instead of crashing with
        # AttributeError when no parallel_config is supplied.
        if parallel_config is None:
            parallel_config = ParallelConfig()
        self.n_local_heads = n_heads // 2
        self.head_dim = dim // n_heads
        self.w_q = nn.Dense(dim, n_heads * self.head_dim, has_bias=False)
        # in_strategy: split the (batch, seq, dim) input along batch by
        # data_parallel. parameter_plan: split the weight's first
        # (output) dimension by model_parallel.
        self.w_q.shard(in_strategy=((parallel_config.data_parallel, 1, 1), ), out_strategy=(None, ),
                       parameter_plan={"self.w_q.weight": (parallel_config.model_parallel, 1)},
                       device="GPU")

    def construct(self, _x):
        """Project _x of shape (batch, seq, dim) to query features."""
        # The original unpacked _x.shape into unused locals; dropped.
        return self.w_q(_x)
if __name__ == "__main__":
    # Run in PyNative mode (the report says the failure mode is
    # mode-independent; PyNative is convenient for debugging).
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE)
    # Cell.shard is only enabled under distributed auto-parallel:
    # configure 2 devices and let sharding strategies propagate.
    mindspore.set_auto_parallel_context(device_num=2)
    mindspore.set_auto_parallel_context(parallel_mode="auto_parallel")
    mindspore.set_auto_parallel_context(search_mode="sharding_propagation")
    # THE FIX: init() creates the communication group. Without it,
    # calling a sharded cell raises
    # "RuntimeError: The pointer[comm_lib_instance_] is null."
    # Launch with: mpirun -n 2 --allow-run-as-root python test.py
    init()
    batch_size = 2
    dim1 = 8
    n_heads = 4
    # Dummy (batch, seq, dim) input of ones.
    input = mnp.ones((batch_size, 16, dim1))
    # Default config: data_parallel=1, model_parallel=1.
    Config = ParallelConfig()
    attn = Test(dim=dim1, n_heads=n_heads, parallel_config=Config)
    output = attn(input)
    print(output)
启动命令为mpirun -n 2 --allow-run-as-root python test.py能够输出output