1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/Graph): 不限
Python版本: Python=3.7
操作系统平台: 不限
2 问题描述
将 PyTorch 代码迁移到 MindSpore 时，需要替换其中调用 torch.nn.init 的部分。在 MindSpore 中应该如何替换？
3 根因分析
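torch.nn.init 中的初始化函数基本都可以用 mindspore.common.initializer 中的初始化器替换。torch.nn.init.eye_ 可以用 mindspore.ops.eye 实现（MindSpore 2.0 的 initializer 模块中也提供了 Identity 初始化器）。

部分名称不同：kaiming_uniform_ / kaiming_normal_ 对应 HeUniform / HeNormal，trunc_normal_ 对应 TruncatedNormal。

迁移时需要注意两点：
(1) torch.nn.init.xxx_ 会原地修改一个已有的张量；而 mindspore.common.initializer.initializer(init, shape, dtype) 会返回一个新的 Tensor。如需重新初始化已有的 Parameter，可使用 param.set_data(initializer(...))。
(2) 以字符串（如 "uniform"、"truncatedNormal"）指定初始化方式时，使用的是 MindSpore 初始化器的默认参数。这些默认值与 PyTorch 不一定相同，迁移时建议显式传入参数。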
4 解决方案
下面每组示例依次给出 PyTorch 写法及其输出，以及 MindSpore 写法及其输出。随机初始化器每次运行得到的数值都不同，示例输出只用于说明形状和取值范围，两个框架的具体数值不会一致。
import torch
import torch.nn.init as torch_init
import mindspore as ms
from mindspore.common import initializer as ms_init
from mindspore import ops
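uniform
注意：torch.nn.init.uniform_ 默认从 U(0, 1) 采样。MindSpore 的 Uniform(scale=0.07) 从 U(-scale, scale) 采样，字符串 "uniform" 使用同样的默认值 0.07。因此下面两组写法的默认行为并不等价（对比两者的输出范围）。另外，Uniform 只支持关于 0 对称的区间。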
w = torch.empty(3, 5)
torch_init.uniform_(w)
print(w)
tensor([[0.6273, 0.6722, 0.4320, 0.0803, 0.4569],
[0.8467, 0.2664, 0.9759, 0.4138, 0.4804],
[0.6302, 0.9383, 0.8273, 0.3065, 0.8729]])
w = ms_init.initializer(ms_init.Uniform(), (3, 5))
print(w)
w = ms_init.initializer("uniform", (3, 5))
print(w)
[[ 0.02076868 -0.00353139 -0.02801976 0.01483764 0.06549707]
[ 0.05656932 0.0125913 -0.04658895 0.01214548 -0.06701916]
[-0.01464871 -0.02846382 0.05644834 -0.05948469 -0.00985342]]
[[ 0.00580789 -0.05222291 0.04916679 -0.0379335 -0.00949833]
[-0.04182563 -0.04885998 0.03340004 -0.05639874 0.0465859 ]
[-0.01674308 -0.03311962 0.06953544 0.01696784 0.03579749]]
constant
w = torch.empty(3, 5)
torch_init.constant_(w, 0.3)
print(w)
tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
[0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])
w = ms_init.initializer(ms_init.Constant(0.3), (3, 5))
print(w)
[[0.3 0.3 0.3 0.3 0.3]
[0.3 0.3 0.3 0.3 0.3]
[0.3 0.3 0.3 0.3 0.3]]
ones
w = torch.empty(3, 5)
torch_init.ones_(w)
print(w)
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
w = ms_init.initializer(ms_init.One(), (3, 5))
print(w)
w = ms_init.initializer("one", (3, 5))
print(w)
[[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]]
[[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]]
zeros
w = torch.empty(3, 5)
torch_init.zeros_(w)
print(w)
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
w = ms_init.initializer(ms_init.Zero(), (3, 5))
print(w)
w = ms_init.initializer("zero", (3, 5))
print(w)
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
eye
w = torch.empty(3, 5)
torch_init.eye_(w)
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
w = ops.eye(3, 5)
print(w)
[[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 1. 0. 0.]]
dirac
w = torch.empty(2, 2, 3, 3)
torch_init.dirac_(w, 2)
print(w)
tensor([[[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]],
[[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
w = ms_init.initializer(ms_init.Dirac(2), (2,2,3,3))
print(w)
[[[[0. 0. 0.]
[0. 1. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]
[[[0. 0. 0.]
[0. 1. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]]
xavier_uniform
w = torch.empty(3, 5)
torch_init.xavier_uniform_(w, gain=1)
print(w)
tensor([[-0.4621, 0.0128, 0.6648, 0.8388, -0.7054],
[ 0.1299, 0.7065, 0.3830, 0.0871, 0.1640],
[ 0.8043, 0.7462, -0.0764, 0.5021, 0.7438]])
w = ms_init.initializer(ms_init.XavierUniform(1), (3,5))
print(w)
w = ms_init.initializer("xavier_uniform", (3, 5))
print(w)
[[-0.1365639 -0.6591273 -0.02490093 0.53927845 -0.40755984]
[-0.30184573 0.5353374 0.05441066 0.39372778 0.13548988]
[-0.33337212 -0.7249835 -0.45743167 0.27087837 0.3974703 ]]
[[ 0.66771585 0.8372303 -0.30232698 -0.23809277 -0.24208683]
[ 0.47933128 0.6843046 0.7311251 -0.47484264 0.78137606]
[-0.715034 -0.28117204 0.28003275 -0.523037 -0.7518653 ]]
xavier_normal
w = torch.empty(3, 5)
torch_init.xavier_normal_(w)
tensor([[ 0.3390, -0.1305, -0.7165, 0.7791, -0.7877],
[-0.7793, -0.6927, 0.0883, 0.1080, -1.6134],
[ 0.2015, -1.2122, 1.5879, 0.1446, -0.3547]])
w = ms_init.initializer(ms_init.XavierNormal(1), (3,5))
print(w)
w = ms_init.initializer("xavier_normal", (3, 5))
print(w)
[[ 0.93760556 0.3392253 0.40076658 -0.19003701 0.14194943]
[ 0.10027266 -0.45792234 0.5151587 0.5029552 0.13373359]
[-0.19225736 0.21525055 -0.1321451 0.30269623 0.5750851 ]]
[[ 0.59728 -0.78122866 0.1958348 0.9556407 -0.35464117]
[ 0.5489269 0.07231056 0.45982736 -0.03406669 0.6918162 ]
[ 0.1374928 -1.3428471 0.21672817 0.2914441 0.5336329 ]]
kaiming_uniform
w = torch.empty(3, 5)
torch_init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
tensor([[ 0.2769, -0.8207, -0.6044, -0.0624, 0.5949],
[ 0.8521, -0.9850, 0.5606, -0.8289, -0.2912],
[ 0.4932, -0.7423, -0.5104, -0.3367, -0.1449]])
w = ms_init.initializer(ms_init.HeUniform(mode='fan_in', nonlinearity='relu'), (3,5))
print(w)
[[-0.5309736 0.8559388 -0.92003417 0.58600163 0.7770014 ]
[-0.7070332 0.6938597 0.20635565 0.9363468 -0.04296735]
[-0.42041913 0.2702668 -0.63106734 -0.73498166 0.26201499]]
kaiming_normal
w = torch.empty(3, 5)
torch_init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
tensor([[-0.1372, 0.7091, -1.0794, -0.2982, -0.5171],
[ 0.2817, 0.6475, -0.3793, -0.0194, 0.1257],
[-0.2764, -1.0841, 0.5978, 0.1805, 0.0318]])
w = ms_init.initializer(ms_init.HeNormal(mode='fan_in', nonlinearity='relu'), (3,5))
print(w)
[[ 0.04525672 0.54234827 0.29712144 -0.6286129 0.7247357 ]
[-0.72248465 -0.5347963 0.06399149 0.39091495 -0.04421718]
[-0.22517438 0.12626456 -0.03354136 -0.4759619 -0.5327487 ]]
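trunc_normal
注意：torch.nn.init.trunc_normal_ 默认 mean=0、std=1。MindSpore 中与之对应的是 TruncatedNormal(1)（sigma=1）。字符串 "truncatedNormal" 使用默认 sigma=0.01，从下面第二组输出的数值量级可以看出，两者并不等价。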
w = torch.empty(3, 5)
torch_init.trunc_normal_(w)
tensor([[ 0.0626, 0.5258, 0.7458, -0.1970, 1.2821],
[ 0.2344, 0.5350, 0.4402, 0.2152, 1.0044],
[ 0.3880, -0.1332, 0.6127, 0.9253, 1.5593]])
w = ms_init.initializer(ms_init.TruncatedNormal(1), (3,5))
print(w)
w = ms_init.initializer("truncatedNormal", (3, 5))
print(w)
[[ 0.1207792 -0.21480061 0.09069321 -0.09830502 -0.61501193]
[-0.08540737 0.684896 -0.6055059 0.07456751 -1.7249392 ]
[-0.18215409 0.37354046 0.83819115 -0.25331137 -1.451913 ]]
[[-0.01586118 -0.00392491 0.00117821 -0.01158508 -0.00948633]
[-0.00097138 -0.01572727 -0.00207051 0.00047637 -0.00899001]
[ 0.0042324 0.00122826 -0.00749713 0.01155263 0.00115168]]
orthogonal
w = torch.empty(3, 5)
torch_init.orthogonal_(w)
tensor([[ 0.4940, 0.6654, -0.2924, 0.2797, -0.3866],
[-0.3244, 0.7065, 0.0692, -0.4769, 0.4043],
[-0.6315, -0.0317, -0.6525, -0.0621, -0.4130]])
w = ms_init.initializer(ms_init.Orthogonal(1), (3,5))
print(w)
w = ms_init.initializer("orthogonal", (3, 5))
print(w)
[[ 0.2581128 0.7996161 0.47184658 -0.21337771 0.16069418]
[ 0.09046545 0.12070206 0.23474057 0.43679783 -0.85519093]
[ 0.2695082 0.1381112 -0.53185076 -0.6602976 -0.43523848]]
[[ 0.5105753 0.54908097 0.17701794 -0.10290311 -0.6292047 ]
[-0.71567476 -0.0970166 0.3541471 -0.29048243 -0.5182636 ]
[ 0.3556234 -0.59223086 -0.25964984 -0.64582545 -0.1956683 ]]
sparse
w = torch.empty(3, 5)
torch_init.sparse_(w, sparsity=0.1)
tensor([[ 0.0126, -0.0038, 0.0023, 0.0000, 0.0000],
[ 0.0086, 0.0000, -0.0105, -0.0119, 0.0146],
[ 0.0000, -0.0107, 0.0000, 0.0161, -0.0201]])
w = ms_init.initializer(ms_init.Sparse(sparsity=0.1), (3,5))
print(w)
[[-0.01048618 0. 0.0013259 0.01346168 -0.00077514]
[-0.01921918 -0.00778484 -0.00158085 0.00317146 0. ]
[ 0. 0.01052929 0. 0. 0.00343628]]