1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.3
执行模式(PyNative/ Graph):Graph
Python版本: Python=3.9
操作系统平台: Linux
2.报错信息
2.1问题描述
执行分布式训练任务,加载ckpt出现如下报错
对于'load_param_into_net',参数'net'中的wq.weight的shape应与参数'parameter_dict'中的wq.weight的shape相同。但参数'net'中wq.weight的shape为(512, 8192),而参数'parameter_dict'中wq.weight的shape为(8192, 8192)。请检查加载的checkpoint是否正确,或者'net'和'parameter_dict'中的batch size等是否一致。
3 根因分析
a. 错误出现在加载checkpoint时失败
b. 报错信息为:网络中wq.weight参数的shape为(512, 8192),但是checkpoint文件中q.weight参数的shape为(8192, 8192),shape不一致导致加载失败。
c. 排查是网络的shape还是checkpoint文件的shape不符合预期,经过排查发现,网络中的shape是切分后的,是符合预期的。
d. 发现checkpoint中保存的参数shape是全量的,并不是切分好的。
e. 进一步排查发现,save_checkpoint方法默认会把所有parameter合并保存成全量的,导致和网络中shape不一致。
4 解决方案
在调用save_checkpoint的时候,把integrated_save参数手动设置为false,使所有节点保存切分后的参数。
重新加载保存后的checkpoint文件,问题解决