Common if Control Flow Issues in the Compilation Process When Accelerating with jit Compilation

Control Flow Produced by if Statements

Type Join Failed

During the inference stage of frontend compilation, the abstract type (including the type, shape, etc.) of each node is inferred. Common abstract types include AbstractScalar, AbstractTensor, AbstractFunction, AbstractTuple, AbstractList, and so on. In some scenarios, such as multi-branch scenarios, the abstract types of the return values of different branches are joined to infer the abstract type of the result. If the abstract types do not match, or the type/shape is inconsistent, the exception above is thrown.
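
As a contrast with the failure case discussed next, the following minimal sketch (the class name ConsistentNet and the inputs are illustrative, not taken from the original sample) shows an if whose two return values share the same dtype and shape, so the join succeeds and graph compilation proceeds normally.

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import nn, jit

class ConsistentNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()

    @jit
    def construct(self, x, a, b):
        if a > b:
            return self.relu(x)    # shape: (2, 3), dtype: Float32
        else:
            return x * 2           # shape: (2, 3), dtype: Float32 -- same abstract type, join succeeds

net = ConsistentNet()
x = ms.Tensor(np.random.rand(2, 3).astype(np.float32))
print(net(x, ms.Tensor(2, ms.float32), ms.Tensor(6, ms.float32)).shape)    # (2, 3)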

When an error such as Type Join Failed: dtype1 = Float32, dtype2 = Float16 is reported, it means that the data types are inconsistent, so the abstract type join failed. Based on the data types and the line-of-code information in the message, the faulty range can be located quickly. In addition, the error message provides the specific abstract type information and node information. A code sample is as follows:

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import nn, jit

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()
        self.cast = ops.Cast()

    @jit
    def construct(self, x, a, b):
        if a > b:    # the return values of the two if branches have different types
            return self.relu(x)    # shape: (2, 3, 4, 5), dtype:Float32
        else:
            return self.cast(self.relu(x), ms.float16)    # shape: (2, 3, 4, 5), dtype:Float16

input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(2, ms.float32)
input_b = ms.Tensor(6, ms.float32)
net = Net()
out_me = net(input_x, input_a, input_b)

The execution result is as follows:

TypeError: Cannot join the return values of different branches, perhaps you need to make them equal.
Type Join Failed: dtype1 = Float32, dtype2 = Float16.
For more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed

----------------------------------------------------
- Framework Error Message: (For framework developers)
----------------------------------------------------
The abstract type of the return value of the current branch is:
AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float16, Value: ValueAny, Shape: NoShape), value_ptr: 0x15694f0, value: ValueAny),
 and that of the previous branch is:
AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float32, Value: ValueAny, Shape: NoShape), value_ptr: 0x15694f0, value: ValueAny).
The node is @1___main___Net_construct_47:CNode_45{[0]: @1___main___Net_construct_47:CNode_42{[0]: ValueNode<Primitive> Switch, [1]: CNode_31, [2]: ValueNode<FuncGraph> 3_✓__main___Net_construct_43, [3]: ValueNode<FuncGraph> 4_✗__main___Net_construct_44}}, true branch: 3_✓__main___Net_construct_43
In file test.py:15, 12~31
            return self.relu(x)    # shape: (2, 3, 4, 5), dtype:Float32
            ^~~~~~~~~~~~~~~~~~~

, false branch: 4_✗__main___Net_construct_44
In file test.py:17, 12~54
            return self.cast(self.relu(x), ms.float16)    # shape: (2, 3, 4, 5), dtype:Float16
            ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/core/abstract/abstract_value.cc:607 ThrowException

----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------

# In file test.py:12~17, 4~54
    @jit

# In file test.py:14~17, 8~54
        if a > b:    # the return values of the two if branches have different types
        ^

From the error message, we can see that the abstract type of the node for return self.relu(x) is inconsistent with that of the node for return self.cast(self.relu(x), ms.float16), so the abstract type of the control flow cannot be inferred. To accelerate this code with graph compilation, the code must first be modified so that the abstract types of the two branches of the control flow are consistent. For the sample above, one way to fix it is to remove the cast, as sketched below.
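
A minimal sketch of that fix, assuming the Float16 output is not actually required (the only change to the original sample is dropping the cast, so both branches return a Float32 Tensor of shape (2, 3, 4, 5) and the join succeeds):

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import nn, jit

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()

    @jit
    def construct(self, x, a, b):
        if a > b:
            return self.relu(x)    # shape: (2, 3, 4, 5), dtype: Float32
        else:
            return self.relu(x)    # same abstract type as the true branch, so the join succeeds

input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(2, ms.float32)
input_b = ms.Tensor(6, ms.float32)
net = Net()
out_me = net(input_x, input_a, input_b)    # compiles and runs without Type Join Failed

If a Float16 result is actually needed, an alternative is to apply the cast after the if statement, so that both branches still return the same abstract type inside the control flow.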