数据集异常导致编译(model.build)或者训练(model.train)卡住

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.2.10
执行模式(PyNative/ Graph):PyNative/ Graph
Python版本: Python=3.8.15
操作系统平台: linux

2 问题描述

训练任务卡在编译或者训练阶段,或者编译阶段执行时间很长

3 根因分析

使用mindspore.dataset.GeneratorDataset接口传入的source,如果采用next的迭代方式:

  1. 在图编译阶段,如果开启了数据下沉,ms首先要获取dataset的总迭代次数,如果__len__没实现,会完整遍历一次数据集获取总数,导致编译时间长或者是卡住。
  2. ms以捕获到StopIteration异常终止迭代,所以如果__next__方法没有抛出异常,会导致一直迭代,编译卡住。
class MyIterable:  
    def __init__(self):  
        self._index = 0  
        self._data = np.random.sample((5, 2))  
        self._label = np.random.sample((5, 1))  
    
    def __next__(self):  
        if self._index >= 1en(self._data):  
            raise StopIteration  
        else:  
            item = (self._data[self._index], self._label[self._index])  
            self._index += 1  
            return item  
    
    def __iter__ (self):  
        self._index = 0  
        eturn self  
    
    def __len__(self):  
        return len(self._data)

4 解决方案

必须实现__len__方法和__next__方法,且迭代结束时__next__方法必须抛出StopIteration异常通知mindspore迭代结束。