jit编译加速功能开启时,如何避免多次重新编译。

使用jit接口将一个函数或者cell用静态图编译加速时,如果输入发生以下变化,将会在每次执行时重新遍历,大大降低执行性能。因此为了避免该情况,可以通过改变脚本的方式使图编译流程识别到输入的动态状态。

Tensor的shape发生改变

当cell或者函数输入的tensor的shape改变时,多次调用该函数或cell将会导致多次编译。如以下代码

import numpy as np
import mindspore as ms
from mindspore import nn, jit

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit 
    def construct(self, x, a):
        return x + a

net = Net()
out_me = []
for i in range(1, 10):
    input_x = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a))

其中net的每次调用之间输入的shape都不一样,会导致每次调用net都重新编图。如以下代码用@jit(dynamic=1)代替@jit可以避免多次编图,只重新编译两次。

import numpy as np
import mindspore as ms
from mindspore import nn, jit

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit(dynamic=1)  
    def construct(self, x, a):
        return x + a

net = Net()
out_me = []
for i in range(1, 10):
    input_x = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a))

标量值发生改变

当cell或者函数的输入有标量,且标量值发生改变时,多次调用该函数或cell将会导致多次编译。如以下代码

import numpy as np
import mindspore as ms
from mindspore import nn, jit
class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit(dynamic=1)  
    def construct(self, x, a, i):
        if i > 5:
            return x + a
        else:
            return x

net = Net()
out_me = []
for i in range(1, 10):
    input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a, i))

其中net的每次调用之间输入的标量i的值都不一样,会导致每次调用net都重新编图。如以下代码使用mutable接口,用num = mutable(i)代替i可以避免多次编图。

import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit  
    def construct(self, x, a, i):
        if i > 5:
            return x + a
        else:
            return x

net = Net()
out_me = []
for i in range(1, 10):
    num = mutable(i)
    input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a, num))

Tuple或List的长度发生改变

当cell或者函数的输入有Tuple或List,且长度发生改变时,多次调用该函数或cell将会导致多次编译。如以下代码

import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit  
    def construct(self, x, a, in_me):
        if len(in_me) > 5:
            return x + a
        else:
            return x

net = Net()
in_me = []
out_me = []
for i in range(1, 10):
    in_me.append(i)
    input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a, in_me))

其中net的每次调用之间输入的list in_me 的长度都不一样,会导致每次调用net都重新编图。如以下代码使用mutable接口,并将dynamic_len参数设置为True,用in_me = mutable([], True)代替in_me = []可以避免多次编图。

import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit  
    def construct(self, x, a, in_me):
        if len(in_me) > 5:
            return x + a
        else:
            return x

net = Net()
in_me = mutable([], True)
out_me = []
for i in range(1, 10):
    in_me.append(i)
    input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a, in_me))

网络的输入是tuple[Tensor]、list[Tensor]或Dict[Tensor],即使里面Tensor的shape和dtype没有发生变化

当cell或者函数的输入有Tuple或List,且元素为tensor时,多次调用该函数或cell将会导致多次编译。如以下代码

import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable, Tensor

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit  
    def construct(self, x, a, in_me):
        if len(in_me) > 0:
            return x + a
        else:
            return x

net = Net()
in_me = [Tensor(1), Tensor(2)]
out_me = []
for i in range(1, 10):
    input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a, in_me))

其中net的输入中有tensor list,会导致每次调用net都重新编图。如以下代码使用mutable接口,用in_me = mutable([Tensor(1), Tensor(2)])代替in_me = [Tensor(1), Tensor(2)]可以避免多次编图。

import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable, Tensor

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    @jit  
    def construct(self, x, a, in_me):
        if len(in_me) > 0:
            return x + a
        else:
            return x

net = Net()
in_me = mutable([Tensor(1), Tensor(2)])
out_me = []
for i in range(1, 10):
    input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
    out_me.append(net(input_x, input_a, in_me))