1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.2.10
执行模式(PyNative/ Graph):PyNative/ Graph
Python版本: Python=3.8.15
操作系统平台: linux
2 问题描述
训练任务卡在编译或者训练阶段,或者编译阶段执行时间很长
3 根因分析
使用mindspore.dataset.GeneratorDataset接口传入的source,如果采用next的迭代方式:
- 在图编译阶段,如果开启了数据下沉,ms首先要获取dataset的总迭代次数,如果__len__没实现,会完整遍历一次数据集获取总数,导致编译时间长或者是卡住。
- ms以捕获到StopIteration异常终止迭代,所以如果__next__方法没有抛出异常,会导致一直迭代,编译卡住。
class MyIterable:
def __init__(self):
self._index = 0
self._data = np.random.sample((5, 2))
self._label = np.random.sample((5, 1))
def __next__(self):
if self._index >= 1en(self._data):
raise StopIteration
else:
item = (self._data[self._index], self._label[self._index])
self._index += 1
return item
def __iter__ (self):
self._index = 0
eturn self
def __len__(self):
return len(self._data)
4 解决方案
必须实现__len__方法和__next__方法,且迭代结束时__next__方法必须抛出StopIteration异常通知mindspore迭代结束。