MindSpore报错:wq.weight in the argument 'net' should have the same shape as wq.weight in the argument 'parameter_dict'.

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend

MindSpore版本: mindspore=2.3

执行模式(PyNative/ Graph):Graph

Python版本: Python=3.9

操作系统平台: Linux

2.报错信息

2.1问题描述

执行分布式训练任务,加载ckpt出现如下报错

对于'load_param_into_net',参数'net'中的wq.weight的shape应与参数'parameter_dict'中的wq.weight的shape相同。但参数'net'中wq.weight的shape为(512, 8192),而参数'parameter_dict'中wq.weight的shape为(8192, 8192)。请检查加载的checkpoint是否正确,或者'net'和'parameter_dict'中的batch size等是否一致。

3 根因分析

a. 错误出现在加载checkpoint时失败

b. 报错信息为:网络中wq.weight参数的shape为(512, 8192),但是checkpoint文件中q.weight参数的shape为(8192, 8192),shape不一致导致加载失败。

c. 排查是网络的shape还是checkpoint文件的shape不符合预期,经过排查发现,网络中的shape是切分后的,是符合预期的。

d. 发现checkpoint中保存的参数shape是全量的,并不是切分好的。

e. 进一步排查发现,save_checkpoint方法默认会把所有parameter合并保存成全量的,导致和网络中shape不一致。

4 解决方案

在调用save_checkpoint的时候,把integrated_save参数手动设置为false,使所有节点保存切分后的参数。

重新加载保存后的checkpoint文件,问题解决