MindSpore large model error: Inner Error! EZ9999 [InferShape] The k-axis of a(131072) and b(16384) tensors must be the same.

1 System Environment

Hardware environment (Ascend/GPU/CPU): Ascend 910
MindSpore version: mindspore=2.2.10
Execution mode (PyNative/Graph): Graph
Python version: Python=3.8.15
Operating system platform: linux

2 Error Message

Ascend Error Message:  
EZ9999 Inner Error!  
EZ9999 [InferShape] The k-axis of a(131072) and b(16384) tensors must be the same[FUNC:InferShape][FILE:matrix_calculation_ops.cc][LINE:934]  
        TraceBack (most recent call last):  
        Failed to infer output shape[FUNC:MatMulV2InferShape][FILE:matrix_calculation_ops.cc][LINE:2253]  
        Call InferShapeAndType for node:Default/backbone-CausalLMHydraWithValueHead/v_head2-Dense/MatMul-op5643(MatMulV2) failed[FUNC:Infer][FILE:infershape_pass.cc][LINE:119]  
        process pass InferShapePass on node:Default/backbone-CausalLMHydraWithValueHead/v_head2-Dense/MatMul-op5643 failed, ret:4294967295[FUNC:RunPassesOnNode][FILE:base_pass.cc][LINE:571]  
        [Call][PreRun] Failed, graph_id:7, session_id:0.[FUNC:CompileGraph][FILE:graph_manager.cc][LINE:4111]  
        [Compile][Graph]Compile graph failed, error code:1343225857, session_id:0, graph_id:7.[FUNC:CompileGraph][FILE:ge_api.cc][LINE:1150]  
(Please search "Ascend Error Message" at https://www.mindspore.cn for error code description)
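
As a quick sanity check on the message above, the two k-axis values can be pulled out of the error text and compared. This is only a minimal sketch: `k_axis_ratio` is a hypothetical helper written for this post, not part of MindSpore or CANN, and the regular expression assumes the wording of the EZ9999 line shown here.

```python
import re

# The EZ9999 line from the log above, without the trailing [FUNC]/[FILE] tags.
ERROR_LINE = (
    "EZ9999 [InferShape] The k-axis of a(131072) and b(16384) "
    "tensors must be the same"
)


def k_axis_ratio(message):
    """Extract the k-axis sizes of a and b and return (a_k, b_k, a_k / b_k)."""
    match = re.search(r"a\((\d+)\) and b\((\d+)\)", message)
    if match is None:
        raise ValueError("no k-axis pair found in message")
    a_k, b_k = int(match.group(1)), int(match.group(2))
    return a_k, b_k, a_k / b_k


a_k, b_k, ratio = k_axis_ratio(ERROR_LINE)
print(a_k, b_k, ratio)  # 131072 16384 8.0
```

A ratio that is a whole number is worth comparing against the parallel configuration of the job (the data-parallel and model-parallel sizes), since a match points at the sharding strategy rather than at the model code itself.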

3 Root Cause Analysis

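The error says that InferShape failed, and the two shapes differ by a factor of 8. When InferShape fails in this way, the first thing to suspect is that a sharding strategy is wrong somewhere.
The job runs with dp=1 and mp=8, which matches that factor, so the sharding strategy is the likely culprit. Save the IR graphs to check: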

from mindspore import context
context.set_context(save_graphs=3, save_graphs_path='./graph')

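Inspecting the IR graphs and locating the sharding strategy of the failing operator shows that it is indeed split 8 ways.
The code, however, does not specify a sharding strategy for this layer, so it would be expected to be sharded according to dp=1.
Cause: the input of the v_head2 Dense layer (the failing MatMul) comes from a Mul operator that has an explicitly specified sharding strategy. Because of the execution order, v_head's sharding strategy also ends up being split by mp=8, which triggers the error.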
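
The IR inspection described above can be partly automated. The sketch below scans the dumped `.ir` files for the failing node and reports nearby lines that carry a strategy attribute. It is an assumption-laden helper, not a MindSpore tool: `find_strategy_near_node` is a hypothetical name, and it assumes the dumped text places an `in_strategy` attribute within a few lines of the node's scope name, which may differ between versions.

```python
from pathlib import Path


def find_strategy_near_node(graph_dir, node_keyword,
                            attr_keyword="in_strategy", window=5):
    """Return (file name, line number, text) for lines containing attr_keyword
    that lie within `window` lines of a line containing node_keyword."""
    root = Path(graph_dir)
    if not root.is_dir():
        return []
    hits = []
    seen = set()
    for ir_file in sorted(root.rglob("*.ir")):
        lines = ir_file.read_text(encoding="utf-8", errors="replace").splitlines()
        for idx, line in enumerate(lines):
            if node_keyword not in line:
                continue
            # Look a few lines before and after the node reference.
            start = max(0, idx - window)
            end = min(len(lines), idx + window + 1)
            for near_idx in range(start, end):
                key = (str(ir_file), near_idx)
                if attr_keyword in lines[near_idx] and key not in seen:
                    seen.add(key)
                    hits.append((ir_file.name, near_idx + 1,
                                 lines[near_idx].strip()))
    return hits
```

With the dump directory used in this post, a call such as `find_strategy_near_node('./graph', 'v_head2-Dense/MatMul')` would list candidate lines to read by hand. The result is only a shortcut for finding the right place in the files; the strategy still has to be checked against the surrounding graph.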

4 Solution

Configure the sharding strategy explicitly. Change the code to:

self.v_head0 = Linear(model_config.hidden_size, 2 * model_config.hidden_size, weight_init=TruncatedNormal(0.02), activation="relu", has_bias=True).shard(((1, 1), (1, 1)))  
self.v_head2 = Linear(2 * model_config.hidden_size, 1, weight_init=TruncatedNormal(0.02), has_bias=True).shard(((1, 1), (1, 1)))
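
In a strategy such as `((1, 1), (1, 1))`, each inner tuple belongs to one input of the operator, and each number says into how many slices that dimension is cut, so all ones means nothing is split. The plain-Python sketch below walks through that arithmetic. `local_shape` and `matmul_k_axes` are hypothetical helpers and the shapes are made up for illustration; this is not MindSpore's own shape inference and it does not reproduce the exact numbers in the log.

```python
def local_shape(full_shape, strategy):
    """Per-device slice shape when dimension i is cut into strategy[i] equal parts."""
    if len(full_shape) != len(strategy):
        raise ValueError("strategy must have one entry per dimension")
    local = []
    for dim, parts in zip(full_shape, strategy):
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)


def matmul_k_axes(a_shape, b_shape, strategy):
    """Local k-axis of a (last dim) and of b (last dim, weight stored as (out, in))."""
    a_strategy, b_strategy = strategy
    a_local = local_shape(a_shape, a_strategy)
    b_local = local_shape(b_shape, b_strategy)
    return a_local[-1], b_local[-1]


# Hypothetical shapes, for illustration only.
x_shape = (4, 1024)  # (batch, in_features)
w_shape = (1, 1024)  # (out_features, in_features)

# Nothing split: both k-axes agree.
print(matmul_k_axes(x_shape, w_shape, ((1, 1), (1, 1))))  # (1024, 1024)
# Only the input is cut 8 ways along k: the k-axes disagree by a factor of 8.
print(matmul_k_axes(x_shape, w_shape, ((1, 8), (1, 1))))  # (128, 1024)
```

Pinning both value-head layers to all-ones strategies removes any dependence on a strategy inherited from a neighbouring operator, which is why the explicit `.shard(...)` calls resolve the mismatch described in this post.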