使用jit接口将一个函数或者cell用静态图编译加速时,如果输入发生以下变化,将会在每次执行时重新遍历,大大降低执行性能。因此为了避免该情况,可以通过改变脚本的方式使图编译流程识别到输入的动态状态。
Tensor的shape发生改变
当cell或者函数输入的tensor的shape改变时,多次调用该函数或cell将会导致多次编译。如以下代码
import numpy as np
import mindspore as ms
from mindspore import nn, jit
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit
def construct(self, x, a):
return x + a
net = Net()
out_me = []
for i in range(1, 10):
input_x = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a))
其中net的每次调用之间输入的shape都不一样,会导致每次调用net都重新编图。如以下代码用@jit(dynamic=1)
代替@jit
可以避免多次编图,只重新编译两次。
import numpy as np
import mindspore as ms
from mindspore import nn, jit
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit(dynamic=1)
def construct(self, x, a):
return x + a
net = Net()
out_me = []
for i in range(1, 10):
input_x = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(i, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a))
标量值发生改变
当cell或者函数的输入有标量,且标量值发生改变时,多次调用该函数或cell将会导致多次编译。如以下代码
import numpy as np
import mindspore as ms
from mindspore import nn, jit
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit(dynamic=1)
def construct(self, x, a, i):
if i > 5:
return x + a
else:
return x
net = Net()
out_me = []
for i in range(1, 10):
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a, i))
其中net的每次调用之间输入的标量i的值都不一样,会导致每次调用net都重新编图。如以下代码使用mutable接口,用num = mutable(i)
代替i
可以避免多次编图。
import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit
def construct(self, x, a, i):
if i > 5:
return x + a
else:
return x
net = Net()
out_me = []
for i in range(1, 10):
num = mutable(i)
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a, num))
Tuple或List的长度发生改变
当cell或者函数的输入有Tuple或List,且长度发生改变时,多次调用该函数或cell将会导致多次编译。如以下代码
import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit
def construct(self, x, a, in_me):
if len(in_me) > 5:
return x + a
else:
return x
net = Net()
in_me = []
out_me = []
for i in range(1, 10):
in_me.append(i)
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a, in_me))
其中net的每次调用之间输入的list in_me
的长度都不一样,会导致每次调用net都重新编图。如以下代码使用mutable接口,并将dynamic_len参数设置为True,用in_me = mutable([], True)
代替in_me = []
可以避免多次编图。
import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit
def construct(self, x, a, in_me):
if len(in_me) > 5:
return x + a
else:
return x
net = Net()
in_me = mutable([], True)
out_me = []
for i in range(1, 10):
in_me.append(i)
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a, in_me))
网络的输入是tuple[Tensor]、list[Tensor]或Dict[Tensor],即使里面Tensor的shape和dtype没有发生变化
当cell或者函数的输入有Tuple或List,且元素为tensor时,多次调用该函数或cell将会导致多次编译。如以下代码
import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable, Tensor
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit
def construct(self, x, a, in_me):
if len(in_me) > 0:
return x + a
else:
return x
net = Net()
in_me = [Tensor(1), Tensor(2)]
out_me = []
for i in range(1, 10):
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a, in_me))
其中net的输入中有tensor list,会导致每次调用net都重新编图。如以下代码使用mutable接口,用in_me = mutable([Tensor(1), Tensor(2)])
代替in_me = [Tensor(1), Tensor(2)]
可以避免多次编图。
import numpy as np
import mindspore as ms
from mindspore import nn, jit, mutable, Tensor
class Net(nn.Cell):
def __init__(self):
super().__init__()
@jit
def construct(self, x, a, in_me):
if len(in_me) > 0:
return x + a
else:
return x
net = Net()
in_me = mutable([Tensor(1), Tensor(2)])
out_me = []
for i in range(1, 10):
input_x = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = ms.Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
out_me.append(net(input_x, input_a, in_me))