if语句产生的控制流
Type Join Failed
在前端编译的推理阶段,会对节点的抽象类型(包含 type 、shape 等)进行推导,常见抽象类型包括 AbstractScalar 、AbstractTensor 、AbstractFunction 、AbstractTuple 、AbstractList 等。在一些场景比如多分支场景,会对不同分支返回值的抽象类型进行 join 合并,推导出返回结果的抽象类型。如果抽象类型不匹配,或者 type /shape 不一致,则会抛出以上异常。
当出现类似Type Join Failed: dtype1 = Float32, dtype2 = Float16 的报错时,说明数据类型不一致,导致抽象类型合并失败。根据提供的数据类型和代码行信息,可以快速定位出错范围。此外,报错信息中提供了具体的抽象类型信息、节点信息。代码样例如下:
import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import nn, jit
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.relu = ops.ReLU()
self.cast = ops.Cast()
@jit
def construct(self, x, a, b):
if a > b: # if的两个分支返回值的type不一致
return self.relu(x) # shape: (2, 3, 4, 5), dtype:Float32
else:
return self.cast(self.relu(x), ms.float16) # shape: (2, 3, 4, 5), dtype:Float16
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(2, ms.float32)
input_b = ms.Tensor(6, ms.float32)
net = Net()
out_me = net(input_x, input_a, input_b)
执行结果如下:
TypeError: Cannot join the return values of different branches, perhaps you need to make them equal.
Type Join Failed: dtype1 = Float32, dtype2 = Float16.
For more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed
----------------------------------------------------
- Framework Error Message: (For framework developers)
----------------------------------------------------
The abstract type of the return value of the current branch is:
AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float16, Value: ValueAny, Shape: NoShape), value_ptr: 0x15694f0, value: ValueAny),
and that of the previous branch is:
AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float32, Value: ValueAny, Shape: NoShape), value_ptr: 0x15694f0, value: ValueAny).
The node is @1___main___Net_construct_47:CNode_45{[0]: @1___main___Net_construct_47:CNode_42{[0]: ValueNode<Primitive> Switch, [1]: CNode_31, [2]: ValueNode<FuncGraph> 3_✓__main___Net_construct_43, [3]: ValueNode<FuncGraph> 4_✗__main___Net_construct_44}}, true branch: 3_✓__main___Net_construct_43
In file test.py:15, 12~31
return self.relu(x) # shape: (2, 3, 4, 5), dtype:Float32
^~~~~~~~~~~~~~~~~~~
, false branch: 4_✗__main___Net_construct_44
In file test.py:17, 12~54
return self.cast(self.relu(x), ms.float16) # shape: (2, 3, 4, 5), dtype:Float16
^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/core/abstract/abstract_value.cc:607 ThrowException
----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------
# In file test.py:12~17, 4~54
@jit
# In file test.py:14~17, 8~54
if a > b: # if的两个分支返回值的type不一致
^
通过报错信息可以得知,return self.relu(x)的节点抽象类型和return self.cast(self.relu(x), ms.float16)的节点抽象类型不一致,故无法推导出控制流的抽象类型。如想将该段代码用图编译加速的方式提升性能,需要先修改代码,使控制流的两个分支的抽象保持一致。对于上方用例,修改方式可以是去掉cast。