1. 环境信息
硬件环境(Ascend/GPU/CPU): Ascend 910B
MindSpore版本: 2.2.11
执行模式(PyNative/ Graph):Graph
Python版本: 不限
操作系统平台: linux
2. 问题现象
pangu-100b在不同规模的集群上的性能:64->30s/step, 1k->33s/step, 2k->45s/step,在2k集群上性能劣化,不满足线性度。
3. 问题分析
profile后分析timeline数据,发现有个allgather算子通信很慢,占比很大。
存图后定位到allgather算子是attention_mask产生的,attention_mask的shape[batch_size,sql_length,sql_length],在数据集里生成会导致oom,所以采用入图的方式生成。
attention_mask初始化脚本如下:
def construct(self, input_ids, input_position, attention_mask, lables,loss_mask):
r"""Forward process of the pangu model
input_ids: [b, seq_lenght]
input_position: [b, seq_lenght]
attention_mask: [b, seq_lenght, seq_lenght]
labels: [b, lenght]
loss_mask: [b, seq_lenght]
"""
attention_mask = self.tril(F.ones((self.batch_size, self.seq_lenght, self.seq_lenght), dtype=matype.float16))
根本原因是tril算子不支持shard,配置了切分策略不生效,所以会插入allgather通信算子,本身shape比较大,导致通信耗时长,拖慢整个训练。
4. 解决方案
采用np方式初始化Tensor后,消除掉了allgather通信算子,问题解决,2k集群性能32s/step,线性度达标。
self.attention_mask= Tensor(np.tril(xxx))