使用MindSpore对vision.SlicePatches的数据集切分和合并

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.9
操作系统平台: 不限

2 报错信息

2.1 问题描述

采用以下方法将数据集进行切分,怎样将data切分出的[“data0”, “data1”, “data2”, “data3”, “data4”, “data5”, “data6”, “data7”]列按照顺序进行合并,合并成一列。

# num_h=1,num_w=8
num_h, num_w = split_size 
slice_patches_op = vision.SlicePatches(num_h, num_w) 
transforms_list = [slice_patches_op] 
# data_cols = ["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"]
data_cols = ['data' + str(x) for x in range(num_h * num_w)] 
data_set = data_set.map(operations=transforms_list, input_columns=["data"], output_columns=data_cols) 

3解决方案

在data0 - data7 是8个不同的数据,其中的一个列就只代表一种数据。如果想把这些数据都加起来可以参考

# num_h=1,num_w=8  
num_h, num_w = split_size   
slice_patches_op = vision.SlicePatches(num_h, num_w)   
transforms_list = [slice_patches_op]   
# data_cols = ["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"]  
data_cols = ['data' + str(x) for x in range(num_h * num_w)]   
data_set = data_set.map(operations=transforms_list, input_columns=["data"], output_columns=data_cols)  
def merge_col(data0, data1, data2, data3, data4, data5, data6, data7):  
    data_X = data0+data1+data2+data3+data4+data5+data6+data7  
    return data_X  
data_set = data_set.map(operations=merge_col, input_columns=data_cols, output_columns=["data_X"])

如果是一张size为40964096的图片,经过切分之后变成了8张4096512的图片,现在想把这个8张4096512的图片整合成一个数据集,之后进行训练,且样本的顺序为每张图片的切分顺序,data0-data7是8个不同的数据集,每个数据集共有10个样本。这里的合并是把这8列数据重新合成一列,这一列一共有810个样本。

data_set = data_set.map(operations=transforms_list, input_columns=["data"], output_columns=data_cols)  
for data in data_set.create_dict_iterator(output_numpy=True):  
    data0,data1,data2,data3,data4,data5,data6,data7 = data  
    save or merge data locally...

接着使用自定义类GeneratorDataset读取进来。