MindSpore 静态图网络编译性能优化--使用Select算子优化编译性能

1. MindSpore静态图网络编译性能优化的前因

静态图的特点是将计算图的构建和实际计算分开, 也就是编译和运行分开。
在构建阶段,根据完整的计算流程对原始的计算图进行优化和调整,编译得到更省内存和计算量更少的计算图。
由于编译之后图的结构不再改变,所以称之为 “静态图”。
网络执行过程的时间主要由编译耗时与运行耗时两个组成,
在推理的时候,编译耗时远大于运行耗时, 因此减少编译耗时势在必得
且优化编译性能对于提升网络在实际应用时的部署效果有着极为重要的意义

2. MindSpore静态图网络编译性能优化的方法

  1. 使用HyperMap优化编译性能
  2. 使用Select算子优化编译性能
  3. 使用编译缓存优化编译性能
  4. 使用vmap优化编译性能

3. 使用Select算子优化编译性能

在mindspore中,if若是进入控制流,那么每个if都会产生额外的子图,
在静态图模式下,子图数量越多,编译耗时越久
而if在网络编写的时候,又是一个常用的语句
在mindspore中,可以用Select等价替换if语句来优化编译性能。
虽然Select的编译耗时比if控制流少,但是Select会同时执行true分支和false分支,
所以Select的运行耗时比if控制流多。
当分支中算子数量较少, 建议使用Select算子。 若是较多,则推荐使用if控制流

4. Select 用法介绍

ops.select API
mindspore.ops.select(_cond_ ,  _x_ ,  _y_)

根据条件判断Tensor中的元素的值,来决定输出中的相应元素是从 x (如果元素值为True)
还是从 y (如果元素值为False)中选择。
该算法可以被定义为:


参数:

  • cond (Tensor[bool]) - 条件Tensor,决定选择哪一个元素,shape是 (x1,x2,…,xN,…,xR)。
  • x (Union[Tensor, int, float]) - 第一个被选择的Tensor或者数字。 如果x是一个Tensor,那么shape是或者可以被广播为 (x1,x2,…,xN,…,xR)。 如果x是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与y相同的shape。x和y中至少要有一个Tensor。
  • y (Union[Tensor, int, float]) - 第二个被选择的Tensor或者数字。 如果y是一个Tensor,那么shape是或者可以被广播为 (x1,x2,…,xN,…,xR)。 如果y是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与x相同的shape。x和y中至少要有一个Tensor。
    返回:
  • Tensor,与 cond 的shape相同。
    支持的平台有:Ascend GPU CPU
    示例代码如下
import time  
from mindspore import ops  
import mindspore as ms  
  
@ms.jit  
def if_net(x, y):  
    out = 0  
    for _ in range(100):  
        if x < y:  
            x = x - y  
        else:  
            x = x + y  
        out = out + x  
    return out  
  
start_time = time.time()  
out = if_net(ms.Tensor([0]), ms.Tensor([1]))  
end_time = time.time()  
print("if net cost time:", end_time - start_time)  
  
@ms.jit  
def select_net(x, y):  
    out = x  
    for _ in range(100):  
        cond = x < y  
        x = ops.select(cond, x - y, x + y)  
        out = out + x  
    return out  
  
start_time = time.time()  
out = select_net(ms.Tensor([0]), ms.Tensor([1]))  
end_time = time.time()  
print("select net cost time:", end_time - start_time)  

cke_1789.png
上面的示例中, 算子较少,执行完成时间相比较而言
使用Select的运行耗时比较少,所以使用select能优化性能