大模型动态图训练内存优化

内存优化

训练过程中内存开销一般可分为模型参数、优化器状态、前向计算结果和临时变量存储几个维度。内存包括Host侧内存及Device侧显存,因为显存资源较为稀缺,一般开发者更多关注在显存优化。

用户编程checklist

  • 自定义算子反向数据按需保留:自定义算子易引入内存生命周期过长问题。用户实现自定义算子时,可调用used_bprop_inputs仅保存反向需要用到的数据。如下代码所示,其反向仅需要dout,即index=2,则需要在__init__函数中添加标识,此时在动态图正向执行完仅会保留dout数据。
class ReduceScatterToSequenceParallelRegion(nn.Cell):
    "Reduce scatter the input from the model parallel region."

    def __init__(self, need_to_swapaxes):
super(ReduceScatterToSequenceParallelRegion, self).__init__()
        self.world_size = get_tensor_model_parallel_world_size()
self.need_to_swapaxes = need_to_swapaxes
        if self.world_size > 1:
            self.tp_group = get_tensor_model_parallel_group()
self.used_bprop_inputs = [2]#bprop仅需要保留dout,则仅将dout的index传入标记

    def construct(self, input_):
        if self.world_size == 1:
            return ops.stop_gradient(input_)
        if self.need_to_swapaxes:
            input_ = input_.swapaxes(0, 1)
        output = comm_func.reduce_scatter_tensor(input_, group=self.tp_group)[0]
        if self.need_to_swapaxes:
            output = output.swapaxes(0, 1)
        return output

    # pylint: disable=W0613, C0111
    def bprop(self, *args):
        dout = args[-1]
        if self.world_size == 1:
            return dout
        if self.need_to_swapaxes:
            dout = dout.swapaxes(0, 1)
        output = comm_func.all_gather_into_tensor(dout, group=self.tp_group)[0]
        if self.need_to_swapaxes:
            output = output.swapaxes(0, 1)

        return (output,)
  • 打开虚拟内存:建议打开虚拟内存,减少大块Block的申请导致的显存峰值过高,开启方式如下。值得注意的是,MS_ALLOC_CONF为kv格式,多个配置项需要一次性设置,以逗号分隔,避免多次设置被刷新。
export MS_ALLOC_CONF="enable_vmm:True"
  • 手动开启GC:建议在每隔几百个step时,手动开启垃圾回收机制,它会检测不再使用的对象,并释放它们所占用的内存空间,如下所示。
import gc
gc.collect()

显存问题分析流程

如上述checklist排查和显存配置优化均完成后仍有显存不足的风险,则可根据下列流程分析具体细节。

  1. 根据模型计算理论显存值,
  2. 打开流同步
ms.set_context(pynative_synchronize=True)
  1. 打开显存抓取配置
export MS_ALLOC_CONF="memory_tracker:True"
  1. 执行获取显存分析报告(二选一)
    • 缩小层数实际运行(推荐)
    • Dryrun模拟:export MS_SIMULATION_LEVEL=1(注:需要配套MindSpore 2.5版本)
  2. 根据显存分析报告,重点分析未及时释放的显存