MindSpore 网络精度自动比对功能中的 protobuf 版本冲突问题分析

1 系统环境

硬件环境(Ascend/GPU/CPU): GPU CUDA 11.7

MindSpore版本: mindspore=2.0.0

执行模式(PyNative/ Graph):Graph

Python版本: Python=3.8.13

操作系统平台: 不限

2 报错信息

2.1 问题描述

在 MindSpore 的网络精度自动比对功能中,图模式下调用 mindinsightx.compare.proto.me.mindinsight_anf_ir_pb2 模块中的 ModelProto 消息类时报错:`TypeError: Descriptors cannot not be created directly`(其中 "cannot not" 是 protobuf 原始报错信息中的写法,并非转录错误)。

同时,在使用 onnx 包中的 onnx.ModelProto() 存储模型信息时,也会出现报错 `onnx 1.14.1 requires protobuf>=3.20.2`。

2.2 脚本代码(代码格式)

    def _get_model_proto(self):
        """Create a fresh MindSpore IR ``ModelProto`` message to parse a graph file into."""
        proto_cls = me_ir_pb2.ModelProto
        return proto_cls()

    def load(self, path):
        """
        Loads graph file.

        The file is first parsed as a binary ``ModelProto``. If that raises
        ``message.DecodeError``, it is re-read and parsed as protobuf text
        format. Graph nodes and parameters are then registered, and finally
        edges are added from each node's inputs.

        Args:
            path (str): Graph file path to be loaded.

        Raises:
            ValueError: If the file cannot be parsed as either binary or text
                protobuf.
        """
        self.real_path = os.path.realpath(path)
        self._proto = self._get_model_proto()

        try:
            with open(self.real_path, 'rb') as f:
                self._proto.ParseFromString(f.read())
        except message.DecodeError:
            # A failed binary parse may leave partially decoded fields behind,
            # and text_format.Merge merges into existing content, so start the
            # text parse from an empty message.
            self._proto.Clear()
            try:
                with open(self.real_path, 'r', encoding='utf-8') as f:
                    text_format.Merge(f.read(), self._proto)
            except Exception as err:
                raise ValueError(f'Cannot load MindSpore graph file {self.real_path}.') from err

        id_node_map = {}
        graph_def = self._proto.graph

        # Register operator nodes first, then parameters. Entries without an
        # id or a name are skipped.
        sources = (
            (self._parse_proto_node, graph_def.node),
            (self._parse_proto_parameter, graph_def.parameters),
        )
        for parse, defs in sources:
            for node_def in defs:
                node = parse(node_def)
                if not node.node_id or not node.name:
                    continue
                self.add_node(node)
                id_node_map[node.node_id] = node

        for node_def in graph_def.node:
            node = self.get_node_by_name(node_def.full_name)
            if node is None:
                continue

            for input_def in node_def.input:
                # NOTE(review): inputs are resolved by matching input_def.name
                # against node_id; presumably _parse_proto_node derives
                # node_id from the proto node's name. Confirm.
                input_node = id_node_map.get(input_def.name)
                if input_node is None:
                    continue
                self.add_edge(input_node, node)

    def load_pbtxt(self, path):
        """
        Loads an ONNX ``ModelProto`` stored in protobuf text format.

        Each graph node with a non-empty name becomes a ``MeNode``. Its id is
        the node's 1-based position in the graph, and its scope is
        ``'Default'``. Edges are then added from every input whose name
        (text before the first ``':'``) matches a loaded node.

        Args:
            path (str): Path of the text-format model file.

        Raises:
            ValueError: If the file cannot be read or parsed.
        """
        # Output names may carry a ":<n>" or ":-<n>" suffix. The numbers are
        # collected into ``node.shape``, forward (non-negative) ones first.
        # NOTE(review): these are suffix indices, not tensor dimensions; the
        # attribute name ``shape`` is kept for compatibility. Confirm intent.
        forward_re = re.compile(r':(\d+)')
        backward_re = re.compile(r':-(\d+)')

        self.real_path = os.path.realpath(path)
        self._proto = onnx.ModelProto()
        try:
            with open(self.real_path, 'r', encoding='utf-8') as f:
                text_format.Parse(f.read(), self._proto)
        except Exception as err:
            raise ValueError(f'Cannot load MindSpore graph file {self.real_path}.') from err

        node_map = {}
        graph_def = self._proto.graph

        # Ids are positional: a skipped (nameless) node still consumes an id.
        for node_id, node_def in enumerate(graph_def.node, start=1):
            if not node_def.name:
                continue
            node = MeNode(name=node_def.name,
                          node_id=node_id,
                          node_type=node_def.op_type,
                          scope='Default')
            if not node.node_id:
                continue

            forward_shapes = []
            backward_shapes = []
            for output in node_def.output:
                forward_match = forward_re.search(output)
                if forward_match:
                    forward_shapes.append(int(forward_match.group(1)))
                backward_match = backward_re.search(output)
                if backward_match:
                    backward_shapes.append(int(backward_match.group(1)))

            node.shape = forward_shapes + backward_shapes
            self.add_node(node)
            node_map[node.name] = node

        for node_def in graph_def.node:
            node = self.get_node_by_name(node_def.name)
            if node is None:
                continue

            for input_def in node_def.input:
                # Strip any ":<n>" output-index suffix to get the producer name.
                input_node_name = input_def.partition(':')[0]
                input_node = node_map.get(input_node_name)
                if input_node is None:
                    continue
                self.add_edge(input_node, node)

3 根因分析

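根本原因是 protobuf 版本冲突。mindinsightx 中的 `*_pb2.py` 文件由较旧版本的 protoc 生成,在 protobuf 4.21.0 及以上版本中导入时,会触发上述 `Descriptors cannot not be created directly` 错误,因此 mindinsight 一侧需要 protobuf<=3.20.1。而 onnx 1.14.1 声明的依赖是 protobuf>=3.20.2。两者的版本要求互不重叠,无法同时被严格满足。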

4 解决方案

pip uninstall protobuf

pip install protobuf==3.20.1

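通过安装 protobuf==3.20.1,两个问题在实际运行中都得到了解决。需要注意:严格来说,3.20.1 并不满足 onnx 1.14.1 声明的 protobuf>=3.20.2,pip 仍可能给出上述依赖冲突提示,只是在本场景下不影响 onnx.ModelProto() 的使用。如果 mindinsightx 并未硬性要求 <=3.20.1,也可以尝试 protobuf==3.20.3。该版本满足 onnx 的版本声明,同时仍属于 3.20.x,不会触发 Descriptors 报错。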

最终,比对脚本成功运行。