如何使用MindSpore替换PyTorch的torch.nn.init

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限

2 问题描述

将 PyTorch 代码迁移到 MindSpore 的过程中,需要改写调用 torch.nn.init 的相关部分。在 MindSpore 中应该使用哪些接口来替换?

3 根因分析

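torch.nn.init 中的初始化方法大多可以在 mindspore.common.initializer 中找到对应的实现;例外是 torch.nn.init.eye_,它对应的是 mindspore.ops.eye。需要注意的是,torch.nn.init 中的函数是对已有 Tensor 做原地(in-place)修改,而 mindspore.common.initializer.initializer 是根据给定的 shape 返回一个新的 Tensor。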

4 解决方案

import torch
import torch.nn.init as torch_init
import mindspore as ms
from mindspore.common import initializer as ms_init
from mindspore import ops

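uniform

注意:torch.nn.init.uniform_ 默认从 U(0, 1) 中采样,而 MindSpore 的 Uniform 默认 scale=0.07,即从 U(-0.07, 0.07) 中采样,因此下面两段输出的取值范围不同。Uniform(scale) 生成的是关于 0 对称的区间 U(-scale, scale)。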

w = torch.empty(3, 5)
torch_init.uniform_(w)
print(w)

tensor([[0.6273, 0.6722, 0.4320, 0.0803, 0.4569],
        [0.8467, 0.2664, 0.9759, 0.4138, 0.4804],
        [0.6302, 0.9383, 0.8273, 0.3065, 0.8729]])

w = ms_init.initializer(ms_init.Uniform(), (3, 5))
print(w)
w = ms_init.initializer("uniform", (3, 5))
print(w)

[[ 0.02076868 -0.00353139 -0.02801976  0.01483764  0.06549707]
    [ 0.05656932  0.0125913  -0.04658895  0.01214548 -0.06701916]
    [-0.01464871 -0.02846382  0.05644834 -0.05948469 -0.00985342]]
[[ 0.00580789 -0.05222291  0.04916679 -0.0379335  -0.00949833]
    [-0.04182563 -0.04885998  0.03340004 -0.05639874  0.0465859 ]
    [-0.01674308 -0.03311962  0.06953544  0.01696784  0.03579749]]

constant

w = torch.empty(3, 5)
torch_init.constant_(w, 0.3)
print(w)

tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
        [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
        [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])

w = ms_init.initializer(ms_init.Constant(0.3), (3, 5))
print(w)

[[0.3 0.3 0.3 0.3 0.3]
    [0.3 0.3 0.3 0.3 0.3]
    [0.3 0.3 0.3 0.3 0.3]]

ones

w = torch.empty(3, 5)
torch_init.ones_(w)
print(w)

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

w = ms_init.initializer(ms_init.One(), (3, 5))
print(w)
w = ms_init.initializer("one", (3, 5))
print(w)

[[1. 1. 1. 1. 1.]
    [1. 1. 1. 1. 1.]
    [1. 1. 1. 1. 1.]]
[[1. 1. 1. 1. 1.]
    [1. 1. 1. 1. 1.]
    [1. 1. 1. 1. 1.]]

zeros

w = torch.empty(3, 5)
torch_init.zeros_(w)
print(w)

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

w = ms_init.initializer(ms_init.Zero(), (3, 5))
print(w)
w = ms_init.initializer("zero", (3, 5))
print(w)

[[0. 0. 0. 0. 0.]
    [0. 0. 0. 0. 0.]
    [0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
    [0. 0. 0. 0. 0.]
    [0. 0. 0. 0. 0.]]

eye

w = torch.empty(3, 5)
torch_init.eye_(w)

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])

w = ops.eye(3, 5)
print(w)

[[1. 0. 0. 0. 0.]
    [0. 1. 0. 0. 0.]
    [0. 0. 1. 0. 0.]]

dirac

w = torch.empty(2, 2, 3, 3)
torch_init.dirac_(w, 2)
print(w)

tensor([[[[0., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],

            [[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]]],


        [[[0., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],

            [[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]]]])

w = ms_init.initializer(ms_init.Dirac(2), (2,2,3,3))
print(w)

[[[[0. 0. 0.]
    [0. 1. 0.]
    [0. 0. 0.]]

    [[0. 0. 0.]
    [0. 0. 0.]
    [0. 0. 0.]]]


    [[[0. 0. 0.]
    [0. 1. 0.]
    [0. 0. 0.]]

    [[0. 0. 0.]
    [0. 0. 0.]
    [0. 0. 0.]]]]

xavier_uniform

w = torch.empty(3, 5)
torch_init.xavier_uniform_(w, gain=1)
print(w)

tensor([[-0.4621,  0.0128,  0.6648,  0.8388, -0.7054],
        [ 0.1299,  0.7065,  0.3830,  0.0871,  0.1640],
        [ 0.8043,  0.7462, -0.0764,  0.5021,  0.7438]])

w = ms_init.initializer(ms_init.XavierUniform(1), (3,5))
print(w)
w = ms_init.initializer("xavier_uniform", (3, 5))
print(w)

[[-0.1365639  -0.6591273  -0.02490093  0.53927845 -0.40755984]
    [-0.30184573  0.5353374   0.05441066  0.39372778  0.13548988]
    [-0.33337212 -0.7249835  -0.45743167  0.27087837  0.3974703 ]]
[[ 0.66771585  0.8372303  -0.30232698 -0.23809277 -0.24208683]
    [ 0.47933128  0.6843046   0.7311251  -0.47484264  0.78137606]
    [-0.715034   -0.28117204  0.28003275 -0.523037   -0.7518653 ]]

xavier_normal

w = torch.empty(3, 5)
torch_init.xavier_normal_(w)

tensor([[ 0.3390, -0.1305, -0.7165,  0.7791, -0.7877],
        [-0.7793, -0.6927,  0.0883,  0.1080, -1.6134],
        [ 0.2015, -1.2122,  1.5879,  0.1446, -0.3547]])

w = ms_init.initializer(ms_init.XavierNormal(1), (3,5))
print(w)
w = ms_init.initializer("xavier_normal", (3, 5))
print(w)

[[ 0.93760556  0.3392253   0.40076658 -0.19003701  0.14194943]
    [ 0.10027266 -0.45792234  0.5151587   0.5029552   0.13373359]
    [-0.19225736  0.21525055 -0.1321451   0.30269623  0.5750851 ]]
[[ 0.59728    -0.78122866  0.1958348   0.9556407  -0.35464117]
    [ 0.5489269   0.07231056  0.45982736 -0.03406669  0.6918162 ]
    [ 0.1374928  -1.3428471   0.21672817  0.2914441   0.5336329 ]]

kaiming_uniform

w = torch.empty(3, 5)
torch_init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

tensor([[ 0.2769, -0.8207, -0.6044, -0.0624,  0.5949],
        [ 0.8521, -0.9850,  0.5606, -0.8289, -0.2912],
        [ 0.4932, -0.7423, -0.5104, -0.3367, -0.1449]])

w = ms_init.initializer(ms_init.HeUniform(mode='fan_in', nonlinearity='relu'), (3,5))
print(w)

[[-0.5309736   0.8559388  -0.92003417  0.58600163  0.7770014 ]
    [-0.7070332   0.6938597   0.20635565  0.9363468  -0.04296735]
    [-0.42041913  0.2702668  -0.63106734 -0.73498166  0.26201499]]

kaiming_normal

w = torch.empty(3, 5)
torch_init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

tensor([[-0.1372,  0.7091, -1.0794, -0.2982, -0.5171],
        [ 0.2817,  0.6475, -0.3793, -0.0194,  0.1257],
        [-0.2764, -1.0841,  0.5978,  0.1805,  0.0318]])

w = ms_init.initializer(ms_init.HeNormal(mode='fan_in', nonlinearity='relu'), (3,5))
print(w)

[[ 0.04525672  0.54234827  0.29712144 -0.6286129   0.7247357 ]
    [-0.72248465 -0.5347963   0.06399149  0.39091495 -0.04421718]
    [-0.22517438  0.12626456 -0.03354136 -0.4759619  -0.5327487 ]]

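trunc_normal

注意:torch.nn.init.trunc_normal_ 默认 std=1,而 MindSpore 的 TruncatedNormal 默认 sigma=0.01。下面的 TruncatedNormal(1) 显式指定了 sigma=1,与 PyTorch 的默认值一致;而字符串别名 "truncatedNormal" 使用默认的 sigma=0.01,所以第二段输出的数值量级约小 100 倍。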

w = torch.empty(3, 5)
torch_init.trunc_normal_(w)

tensor([[ 0.0626,  0.5258,  0.7458, -0.1970,  1.2821],
        [ 0.2344,  0.5350,  0.4402,  0.2152,  1.0044],
        [ 0.3880, -0.1332,  0.6127,  0.9253,  1.5593]])

w = ms_init.initializer(ms_init.TruncatedNormal(1), (3,5))
print(w)
w = ms_init.initializer("truncatedNormal", (3, 5))
print(w)

[[ 0.1207792  -0.21480061  0.09069321 -0.09830502 -0.61501193]
    [-0.08540737  0.684896   -0.6055059   0.07456751 -1.7249392 ]
    [-0.18215409  0.37354046  0.83819115 -0.25331137 -1.451913  ]]
[[-0.01586118 -0.00392491  0.00117821 -0.01158508 -0.00948633]
    [-0.00097138 -0.01572727 -0.00207051  0.00047637 -0.00899001]
    [ 0.0042324   0.00122826 -0.00749713  0.01155263  0.00115168]]

orthogonal

w = torch.empty(3, 5)
torch_init.orthogonal_(w)

tensor([[ 0.4940,  0.6654, -0.2924,  0.2797, -0.3866],
        [-0.3244,  0.7065,  0.0692, -0.4769,  0.4043],
        [-0.6315, -0.0317, -0.6525, -0.0621, -0.4130]])

w = ms_init.initializer(ms_init.Orthogonal(1), (3,5))
print(w)
w = ms_init.initializer("orthogonal", (3, 5))
print(w)

[[ 0.2581128   0.7996161   0.47184658 -0.21337771  0.16069418]
    [ 0.09046545  0.12070206  0.23474057  0.43679783 -0.85519093]
    [ 0.2695082   0.1381112  -0.53185076 -0.6602976  -0.43523848]]
[[ 0.5105753   0.54908097  0.17701794 -0.10290311 -0.6292047 ]
    [-0.71567476 -0.0970166   0.3541471  -0.29048243 -0.5182636 ]
    [ 0.3556234  -0.59223086 -0.25964984 -0.64582545 -0.1956683 ]]

sparse

w = torch.empty(3, 5)
torch_init.sparse_(w, sparsity=0.1)

tensor([[ 0.0126, -0.0038,  0.0023,  0.0000,  0.0000],
        [ 0.0086,  0.0000, -0.0105, -0.0119,  0.0146],
        [ 0.0000, -0.0107,  0.0000,  0.0161, -0.0201]])

w = ms_init.initializer(ms_init.Sparse(sparsity=0.1), (3,5))
print(w)

[[-0.01048618  0.          0.0013259   0.01346168 -0.00077514]
    [-0.01921918 -0.00778484 -0.00158085  0.00317146  0.        ]
    [ 0.          0.01052929  0.          0.          0.00343628]]