Recording the computational graph with SummaryRecord fails with: Failed to get proto for graph.

1 System environment

Hardware environment (Ascend/GPU/CPU): Ascend/GPU
MindSpore version: 2.2.0
Execution mode (PyNative/Graph): any

2 Error information

2.1 Problem description

Recording the computational graph with SummaryRecord reports the following error:

Failed to get proto for graph.

2.2 Script

class AutoEncoderResNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.logdir = "./log/AutoEncoder-ResNet"
        if not os.path.exists(self.logdir):
            os.makedirs(self.logdir)
        num = len(os.listdir(self.logdir))
        self.logdir = os.path.join(self.logdir,"summary_"+str(num))
        self.encoder = ResNet(ResidualBlockBase,[3,4,6,3],512,512)
        self.decoder = ResNetInv(ResidualBlockBaseInv,[3,4,6,3],512,512)
        self.sigmoid = nn.Sigmoid()
    def construct(self, x):
        return self.decode(self.encode(x),False)
    def train(self,batch=32,epoch=50):
        
        from data.CIFAR10 import getCIFAR10
        dataset = getCIFAR10()
        dataset = dataset.batch(batch)
        print("start training")
        print("step per epoch:",len(dataset))
        opt = nn.Adam(self.trainable_params(),learning_rate=1e-4)
        loss_fn = nn.MSELoss()
        def forward(image):
            y_hat = self(image)
            loss = loss_fn(y_hat,image)
            return loss,y_hat
        grad_fn = ms.value_and_grad(forward,None,opt.parameters,has_aux=True)
        def train_step(image):
            (loss, y_hat),grad = grad_fn(image)
            opt(grad)
            return loss,y_hat
        import tqdm

        weight_path = "./weights/AutoEncoder-ResNet"
        if not os.path.exists(weight_path):
            os.makedirs(weight_path)
        from mindspore import SummaryRecord
        with SummaryRecord(log_dir=self.logdir,network=self) as summary_writer:
            summary_writer.set_mode("train")
        ......

2.3 Error log

[ERROR] ME(21058:139964592719680,MainProcess):2024-03-03-14:51:26.699.488 [mindspore/train/summary/summary_record.py:384] Failed to get proto for graph.
[ERROR] ME(21058:139964592719680,MainProcess):2024-03-03-14:51:28.649.745 [mindspore/train/summary/summary_record.py:384] Failed to get proto for graph.

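3 Root cause analysis

The script runs in PyNative mode. In this mode SummaryRecord cannot obtain the proto of the computational graph, so it logs "Failed to get proto for graph."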

4 Solution

Set the execution mode to graph mode before building and running the model:

ms.set_context(mode=ms.GRAPH_MODE)

In addition, set jit_level to O0.
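
The two steps above can be sketched together as follows. The post does not say which API sets jit_level, so this sketch assumes `ms.JitConfig` applied through `Cell.set_jit_config`; other MindSpore versions may expose the setting differently. The helper name `configure_graph_mode` is hypothetical and is not part of the original script.

```python
def configure_graph_mode(network, jit_level="O0"):
    """Hypothetical helper: switch to graph mode and apply a JIT level to `network`."""
    # Imported inside the function so the helper can be defined
    # even on a machine where MindSpore is not installed.
    import mindspore as ms

    # Step 1 (from the solution above): run in graph mode instead of PyNative.
    ms.set_context(mode=ms.GRAPH_MODE)

    # Step 2 (assumption): set jit_level via JitConfig on the cell itself.
    network.set_jit_config(ms.JitConfig(jit_level=jit_level))
    return network
```

With the script from section 2.2, this would be called once on the model, for example `configure_graph_mode(AutoEncoderResNet())`, before `SummaryRecord(log_dir=..., network=...)` is opened, so that the graph is compiled in graph mode by the time the summary is written.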