基于MindSpore Layout推导各卡上的Tensor分片(图解法)

前情提要:

1. Layout简介

关于Layout的基础知识可以参照以下资料:

1.1 Layout

Layout是mindspore中用于统一描述Tensor排布的类。

其主要有3个关键成员变量:

  • ①device_matrix:设备矩阵,用于描述集群中的卡如何排布。
  • ②alias_name:别名,用于指定device_matrix中各个设备维度的名字。
  • ③tensor_map:张量映射,用于描述Tensor的各个维度如何切分。

如下代码样例是为一个Tensor配置切分策略的通用模板:

# 设备矩阵,假设一共有8张卡
device_matrix = (2, 4)
# 别名,为了方便起见直接按从左到右的维度顺序命名
alias_name = ("axis_0", "axis_1")
# 张量映射,假设Tensor.shape=(2, 4)
tensor_map = ("axis_0", "axis_1")
​
layout = Layout(device_matrix, alias_name)
shard_strategy = layout(tensor_map)

下面我们依据这个例子逐一介绍这3个变量的含义.

1.2 device_matrix

设备矩阵,用于描述集群中的卡如何排布,为了便于理解此处以2维设备矩阵为例。

# 设备矩阵
device_matrix = (2, 4)

假设一共有8张卡,根据device_matrix = (2, 4),这8张卡会排布成2*4的矩阵。

1.3 alias_name

别名,用于指定device_matrix中各个设备维度的名字。

# 别名,为了方便起见直接按从左到右的维度顺序命名
alias_name = ("axis_0", "axis_1")

根据alias_name = ("axis_0", "axis_1"),我们可以为设备矩阵中的各个维度标明别名。

1.4 tensor_map

张量映射,用于描述Tensor的各个维度如何切分。

# 张量映射,假设Tensor.shape=(2, 4)
tensor_map = ("axis_0", "axis_1")

其中len(tensor_map) == len(Tensor.shape),对于一个tensor_map,假设tensor_map[i] = j:

  • 如果j = "None",表示张量的第i维度不切分。
  • 如果j = "axis_?",表示沿着设备矩阵的"axis_?"维度切分Tensor的第i维度。

对于样例中的tensor_map = ("axis_0", "axis_1"),构造一个shape = (2, 4)的Tensor,其依据device_matrix和tensor_map的分片过程如下:

1.5 Tensor分片结果

最终基于如下Layout配置,得到的分片结果如图所示:

# 设备矩阵,假设一共有8张卡
device_matrix = (2, 4)
# 别名,为了方便起见直接按从左到右的维度顺序命名
alias_name = ("axis_0", "axis_1")
# 张量映射,假设Tensor.shape=(2, 4)
tensor_map = ("axis_0", "axis_1")

至此,Layout中的3个关键成员变量介绍完毕,同时也以一个最简单的例子演示了一遍Tensor分片的推导过程,下面开始介绍完整的基于Layout推导Tensor分片的方法。

2. 推导方法

基于Layout推导Tensor分片的方法总结起来一共2步:

  • ①依据tensor_map从左到右切分Tensor: 按从左到右的顺序,沿着tensor_map中指定的device_matrix维度,依次切分Tensor的指定维度。
  • ②依据device_matrix从右到左复制Tensor: 按从右到左的顺序,沿着①中device_matrix未使用的维度,依次复制Tensor分片。

本章以3维设备矩阵和2维Tensor举例,讲解推导方法,更多更复杂的切分样例请参照第3节中的推导样例。

# 设备矩阵,假设一共有8张卡
device_matrix = (2, 2, 2)
# 别名,依旧按从左到右的维度顺序命名
alias_name = ("axis_0", "axis_1", "axis_2")
# 张量映射,假设Tensor.shape=(4, 4)
tensor_map = ("axis_2", "axis_0")

2.1 依据tensor_map从左到右切分Tensor

首先依据device_matrix = (2, 2, 2)将设备排成2 * 2 * 2的3维分布:

此处我们可以得到在不同的device_matrix维度上,各个rank之间的对应关系:

  • 对于device_matrix的"axis_0"维度: rank0↔rank4 rank1↔rank5 rank2↔rank6 rank3↔rank7
  • 对于device_matrix的"axis_1"维度: rank0↔rank2 rank1↔rank3 rank4↔rank6 rank5↔rank7
  • 对于device_matrix的"axis_2"维度: rank0↔rank1 rank2↔rank3 rank4↔rank5 rank6↔rank7

接着,初始化一个shape=(4, 4)的Tensor,其tensor_map = ("axis_2", "axis_0")

首先将Tensor放置在rank0上:

下面依据tensor_map从左到右切分Tensor

此时原本rank0上的Tensor已经切分完毕,但是可以发现rank2, rank3, rank6, rank7上没有Tensor分片,而这正是因为device_matrix的"axis_1"维度并未参与切分,因此需要在该维度上进行复制。

2.2 依据device_matrix从右到左复制Tensor

正如前文所述,由于device_matrix的"axis_1"维度并未参与切分,因此rank2, rank3, rank6, rank7上没有Tensor分片,因此我们需要依据device_matrix从右到左复制Tensor(本例中只剩"axis_1"维度未参与切分,更多维度未参与切分的样例请参照第3节。):

至此我们便完成了基于以下Layout可视化推导Tensor分片的流程。

# 设备矩阵,假设一共有8张卡
device_matrix = (2, 2, 2)
# 别名,依旧按从左到右的维度顺序命名
alias_name = ("axis_0", "axis_1", "axis_2")
# 张量映射,假设Tensor.shape=(4, 4)
tensor_map = ("axis_2", "axis_0")

并且其实可以发现如果device_matrix中的每个维度都参与了切分的话,实际上并不需要进行复制,只需要第一步“依据tensor_map从左到右切分Tensor”即可。

为了验证结果的正确性,我们使用原始的列表推导法进行校验(Wiki: 依据dev_matrix和tensor_map推导各卡上的Tensor分片):

可见列表推导法的结果与可视化推导的结果一致,证明了此“基于Layout可视化推导Tensor分片”方法的正确性,更多样例请参照第3节。

3. 推导样例

本节将列举几种常见Layout配置样例,并按照第2节的方法进行推导。

样例 Tensor.shape device_matrix alias_name tensor_map 是否切满 按tensor_map从左到右切分 按device_matrix从右到左复制
3.1 (4, 4) (2, 4) (“axis_0”, “axis_1”) (“axis_1”, “axis_0”) ①"axis_1"切tensor第0维;②"axis_0"切tensor第1维 切满无需复制
3.2 (4, 4) (4, 2) (“axis_0”, “axis_1”) (“axis_1”, “None”) ①"axis_1"切tensor第0维;②"None"不切 ①"axis_0"复制
3.3 (4, 4) (2, 1, 4) (“axis_0”, “axis_1”, “axis_2”) (“axis_0”, “axis_2”) ①"axis_0"切tensor第0维;②"axis_2"切tensor第1维; 切满无需复制
3.4 (4, 4) (2, 2, 2) (“axis_0”, “axis_1”, “axis_2”) (“None”, “axis_1”) ①"None"不切;②"axis_1"切tensor第2维 ①"axis_2"复制;②"axis_0"复制
3.5 (2, 2, 4) (2, 2, 2) (“axis_0”, “axis_1”, “axis_2”) (“axis_1”, “axis_2”, “axis_0”) ①"axis_1"切tensor第0维;②"axis_2"切tensor第1维;③"axis_0"切tensor第2维 切满无需复制
3.6 (2, 2, 4) (2, 2, 2) (“axis_0”, “axis_1”, “axis_2”) (“None”, “axis_0”, “None”) ①"None"不切;②"axis_0"切tensor第1维;③"None"不切 ①"axis_2"复制;②"axis_1"复制
3.7 (2, 2, 2, 2) (2, 2, 2, 2) (“axis_0”, “axis_1”, “axis_2”, “axis_3”) (“None”, “axis_2”, “axis_1”, “None”) ①"None"不切;②"axis_2"切tensor第1维;③"axis_1"切tensor第2维;④"None"不切 ①"axis_3"复制;②"axis_0"复制

3.1 2维Tensor+2维device_matrix+8卡切满

tensor.shape = (4, 4)
device_matrix = (2, 4)
alias_name = ("axis_0", "axis_1")
tensor_map = ("axis_1", "axis_0")

3.1.1 依据tensor_map从左到右切分Tensor

3.1.2 依据device_matrix从右到左复制Tensor

device_matrix的所有维度都参与了切分,并且设备切满了,因此无需复制。 至此分片推导完毕。

3.2 2维Tensor+2维device_matrix+8卡不切满

tensor.shape = (4, 4)
device_matrix = (4, 2)
alias_name = ("axis_0", "axis_1")
tensor_map = ("axis_1", "None")

3.2.1 依据tensor_map从左到右切分Tensor

3.2.2 依据device_matrix从右到左复制Tensor

至此分片推导完毕。

3.3 2维Tensor+3维device_matrix+8卡切满

tensor.shape = (4, 4)
device_matrix = (2, 1, 4)
alias_name = ("axis_0", "axis_1", "axis_2")
tensor_map = ("axis_0", "axis_2")

3.3.1 依据tensor_map从左到右切分Tensor

3.3.2 依据device_matrix从右到左复制Tensor

device_matrix的"axis_1"维度未参与切分,但是"axis_1"==1,并且设备8卡切满,因此无需复制。 至此分片推导完毕。

3.4 2维Tensor+3维device_matrix+8卡不切满

tensor.shape = (4, 4)
device_matrix = (2, 2, 2)
alias_name = ("axis_0", "axis_1", "axis_2")
tensor_map = ("None", "axis_1")

3.4.1 依据tensor_map从左到右切分Tensor

3.4.2 依据device_matrix从右到左复制Tensor

此样例中剩余2个维度"axis_0"和"axis_2"未参与切分,按照从右到左的顺序"axis_2"->"axis_0"进行复制。

至此分片推导完毕。

3.5 3维Tensor+3维device_matrix+8卡切满

tensor.shape = (2, 2, 4)
device_matrix = (2, 2, 2)
alias_name = ("axis_0", "axis_1", "axis_2")
tensor_map = ("axis_1", "axis_0", "axis_2")

3.5.1 依据tensor_map从左到右切分Tensor

3.5.2 依据device_matrix从右到左复制Tensor

device_matrix的所有维度都参与了切分,并且设备切满了,因此无需复制。 至此分片推导完毕。

3.6 3维Tensor+3维device_matrix+8卡不切满

tensor.shape = (2, 2, 4)
device_matrix = (2, 2, 2)
alias_name = ("axis_0", "axis_1", "axis_2")
tensor_map = ("None", "axis_0", "None")

3.6.1 依据tensor_map从左到右切分Tensor

3.6.2 依据device_matrix从右到左复制Tensor

此样例中剩余2个维度"axis_1"和"axis_2"未参与切分,按照从右到左的顺序"axis_2"->"axis_1"进行复制。

至此分片推导完毕。

3.7 3维Tensor+4维device_matrix+8卡不切满

tensor.shape = (2, 2, 2, 2)
device_matrix = (2, 2, 2, 2)
alias_name = ("axis_0", "axis_1", "axis_2", "axis_3")
tensor_map = ("None", "axis_2", "axis_1", "None")

3.7.1 依据tensor_map从左到右切分Tensor

3.7.2 依据device_matrix从右到左复制Tensor

此样例中剩余2个维度"axis_0"和"axis_3"未参与切分,按照从右到左的顺序"axis_3"->"axis_0"进行复制。

至此分片推导完毕。