使用hypermap+multitypefuncgraph代替控制流

在使用jit编译加速的时候,代码中存在的控制流有时会使计算图的子图数量增加。多子图可能影响网络执行性能,也可能导致很多未知报错,因此减少控制流的使用既可以提升网络性能,也可以减少调试成本。
其中一个减少控制流的方式就是使用hypermap+multitypefuncgraph代替控制流。hypermap文档可见:mindspore.ops.HyperMap | MindSpore 2.6.0 文档 | 昇思MindSpore社区 Multitypefuncgraph文档可见:mindspore.ops.MultitypeFuncGraph | MindSpore 2.6.0 文档 | 昇思MindSpore社区
如以下代码可见:

from mindspore import Tensor, ops, jit
from mindspore import dtype as mstype
nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
                    (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
# square all the tensor in the nested list

square = ops.MultitypeFuncGraph('square')
@square.register("Tensor")
def square_tensor(x):
    return ops.square(x)
@jit
def foo():
    common_map = ops.HyperMap()
    output = common_map(square, nest_tensor_list)
    return output
print(foo())

代码中HyperMap起到的作用是对nest_tensor_list中的每个元素都作为square的输入调用。而MultitypeFuncGraph起到的作用是,对于输入x的类型为Tensor的调用,将x作为输入调用square_tensor.
因此上面的代码等同于如下控制流写法:

from mindspore import Tensor, ops, jit
from mindspore import dtype as mstype
nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
                    (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))

def square_tensor(x):
    return ops.square(x)
@jit
def foo():
    output = []
    for x in nest_tensor_list:
        output.append(square_tensor(x))
    return output
print(foo())

使用hypermap可以达成控制流写法的效果,但是可以规避控制流写法的弊端。