MindSpore拆分dataset输入给多输入模型

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.9.7
操作系统平台: 不限

2 报错信息

2.1 问题描述

当前模型是由四个小模型的输出横向拼接成一个大模型,需要把一个dataset在输入模型时横向拆分成4部分数据,分别输入给4个小模型,用mindspore该如何实现?
在tensorflow用的是如下代码实现上述功能的,即使用map把dataset拆分成4类数据分别输入4个小模型,把4个小模型的输出用keras.layers.Concatenate函数合并成一行:

dataset_map_func(*args):
    """
    把dataset切片成4类数据,分别输入4个model。
    feature_slice是一个dict,存储的是slice切片类型。
    """
    feature = dict( ('input_'+c, tf.transpose(args[feature_slice[c]])) for c in ['1','2','3','4'])
    return feature, label

x = keras.layers.Concatenate(axis=1)([model_1.output, model_2.output, model_3.output, model_4.output])
model = keras.Model(inputs=[model_1.input, model_2.input, model_3.input, model_4.input], outputs=outputs)
dataset = dataset.map(dataset_map_func)

3 解决方案

可以参考如下的逻辑实现

dataset = GeneratorDataset(source=loader, column_names=["x1", "x2", "x3", "x4", "label"])  
class Net(nn.Cell):  
    def __init__(self):  
        super(Net, self).__init__()  
        self.net1 = Net1()  
        self.net2 = Net2()  
        self.net3 = Net3()  
        self.net4 = Net4()  
    def construct(self, x1, x2, x3, x4):  
        y1 = self.net1(x)  
        y2 = self.net1(x)  
        y3 = self.net1(x)  
        y4 = self.net1(x)  
        return y1,y2,y3,y4  
dataset_train = create_dataset()  
network = Net()  
opt = nn.Momentum()  
loss_fn = nn.SoftmaxCrossEntropyWithLogits()  
model = Model(network)  
model .train(args.epoch_size, dataset_train)