自定义Callback重载函数调用顺序错误及解决

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=1.10.1
执行模式(PyNative/ Graph):不限
Python版本: Python=3.9.7
操作系统平台: 不限

2 报错信息

2.1 问题描述

自定义Callback的函数调用顺序违背方法的名称。现有调用顺序为on_train_epoch_begin → on_train_step_* → on_eval_epoch_begin → on_eval_step_* → on_eval_epoch_end → on_train_epoch_end。
这导致train_epoch内含了eval_epoch。这样的运行流程导致设计callback的回显内容时增加了不必要的麻烦,并且逻辑上也不直观。
当前结构:

on_train_epoch_begin
    on_train_step_
    on_eval_epoch_begin
        on_eval_step_
    on-eval_epoch_end
on_train_epoch_end

期望结构:

on_train_epoch_begin
    on_train_step_
on_train_epoch_end
on_eval_epoch_begin
    on_eval_step_
on-eval_epoch_end

3 根因分析

分析当前结构和期望结构发现, 当前结构是变训练边推理,期望是先训练后推理,查看当前代码是用的是model.fit,根据官网介绍model.fit就是边训练边推理

4 解决方案

要达到先训练后推理,可将model.fit换成model.train和model.eval,代码如下

epoch = 4 
for i in range(epoch): 
    model.train(i+1, train_ds, callbacks=[callback1], dataset_sink_mode=False) 
    model.eval(val_ds, callbacks=[callback1], dataset_sink_mode=False) 

可以达到先训练后推理