pangu-100b 2k集群线性度问题定位

1. 环境信息

硬件环境(Ascend/GPU/CPU): Ascend 910B
MindSpore版本: 2.2.11
执行模式(PyNative/ Graph):Graph
Python版本: 不限
操作系统平台: linux

2. 问题现象

pangu-100b在不同规模的集群上的性能:64->30s/step, 1k->33s/step, 2k->45s/step,在2k集群上性能劣化,不满足线性度。

3. 问题分析

profile后分析timeline数据,发现有个allgather算子通信很慢,占比很大。
存图后定位到allgather算子是attention_mask产生的,attention_mask的shape[batch_size,sql_length,sql_length],在数据集里生成会导致oom,所以采用入图的方式生成。
attention_mask初始化脚本如下:

def construct(self, input_ids, input_position, attention_mask, lables,loss_mask):
    r"""Forward process of the pangu model
    input_ids: [b, seq_lenght]
    input_position: [b, seq_lenght]
    attention_mask: [b, seq_lenght, seq_lenght]
    labels: [b, lenght]
    loss_mask: [b, seq_lenght]
    """
    attention_mask = self.tril(F.ones((self.batch_size, self.seq_lenght, self.seq_lenght), dtype=matype.float16))

根本原因是tril算子不支持shard,配置了切分策略不生效,所以会插入allgather通信算子,本身shape比较大,导致通信耗时长,拖慢整个训练。

4. 解决方案

采用np方式初始化Tensor后,消除掉了allgather通信算子,问题解决,2k集群性能32s/step,线性度达标。

self.attention_mask= Tensor(np.tril(xxx))