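1. System environment
Hardware environment (Ascend/GPU/CPU): Ascend/GPU/CPU
Software environment: MindSpore version: 2.0.0
Execution mode (static graph/dynamic graph): any
Python version: 3.9
Operating system platform: Linux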
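2. Error information
2.1 Problem description
The dataset is split with the method shown below. How can the columns ["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"] that were split out of "data" be merged, in order, back into a single column?
2.2 Script information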
# num_h=1,num_w=8
num_h, num_w = split_size
slice_patches_op = vision.SlicePatches(num_h, num_w)
transforms_list = [slice_patches_op]
# data_cols = ["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"]
data_cols = ['data' + str(x) for x in range(num_h * num_w)]
data_set = data_set.map(operations=transforms_list, input_columns=["data"], output_columns=data_cols)
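3. Root cause analysis
SlicePatches slices a Tensor into multiple patches in the horizontal and vertical directions. The number of output Tensors equals num_height * num_width, and the map call writes each patch to its own output column.
Original dataset columns: ["data"]
Columns after splitting: ["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"]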
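As a rough NumPy analogy of the split described above (not MindSpore's actual implementation), the sketch below cuts a 2-D array into num_h * num_w equal patches. It assumes the height and width divide evenly and that patches are ordered row by row; the helper name slice_patches_like is hypothetical.

```python
import numpy as np

def slice_patches_like(image, num_h, num_w):
    """Split a 2-D array into num_h * num_w equal patches (row-major order assumed)."""
    patches = []
    # Cut into num_h horizontal bands first, then cut each band into num_w pieces.
    for band in np.split(image, num_h, axis=0):
        patches.extend(np.split(band, num_w, axis=1))
    return patches

image = np.arange(100 * 96).reshape(100, 96)
patches = slice_patches_like(image, num_h=1, num_w=8)
print(len(patches), patches[0].shape)  # 8 (100, 12)
```

With num_h=1 and num_w=8, each patch is a 12-column-wide vertical strip of the input, which matches the eight "dataN" columns listed above.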
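4. Solution
Use np.vstack to concatenate the patch columns of each dataset row in index order. np.vstack stacks along the first (height) axis, so eight (100, 12) patches become one (800, 12) array.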
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# Split each image into 1 x 8 patches along the width
num_h = 1
num_w = 8
# num_h, num_w = split_size
slice_patches_op = vision.SlicePatches(num_h, num_w)
transforms_list = [slice_patches_op]
# data_cols = ["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"]
data_cols = ['data' + str(x) for x in range(num_h * num_w)]

# Build a dataset holding a single random sample of shape (100, 96)
data = np.random.randint(0, 255, size=(1, 100, 96)).astype(np.uint8)
data_set = ds.NumpySlicesDataset(data, ["data"])
for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(len(item), item["data"].shape, item["data"].dtype)

# Apply the split: the "data" column is replaced by data0 ... data7
data_set = data_set.map(operations=transforms_list, input_columns=["data"], output_columns=data_cols)
print(data_set)

# Stack the patch columns in index order; only the first row is processed
for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
    x = item["data0"]
    for i in range(1, num_h * num_w):
        x = np.vstack((x, item[f"data{i}"]))
    break
print(x.shape)
The result is as follows:
1 (100, 96) uint8
<mindspore.dataset.engine.datasets.MapDataset object at 0x0000029D81616198>
(800, 12)
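The (800, 12) shape shows that the patches were stacked along the height axis. If the goal is instead to restore the original (100, 96) layout, patches from a width-wise split would have to be joined along the width axis. The following is a minimal NumPy-only sketch of both options. It uses a stand-in dictionary rather than real dataset output, and the helper name merge_patches is hypothetical.

```python
import numpy as np

def merge_patches(item, num_patches, axis=0):
    """Join item["data0"] ... item["data{n-1}"] in index order along `axis`."""
    return np.concatenate([item[f"data{i}"] for i in range(num_patches)], axis=axis)

# Stand-in for one row of the mapped dataset: 8 strips cut from a (100, 96) array.
original = np.random.randint(0, 255, size=(100, 96)).astype(np.uint8)
item = {f"data{i}": p for i, p in enumerate(np.split(original, 8, axis=1))}

stacked = merge_patches(item, 8, axis=0)   # equivalent to np.vstack
restored = merge_patches(item, 8, axis=1)  # equivalent to np.hstack

print(stacked.shape)                       # (800, 12)
print(restored.shape)                      # (100, 96)
print(np.array_equal(restored, original))  # True
```

Which axis to use depends on what "one column" should mean downstream: axis=0 reproduces the result above, while axis=1 undoes a split made with num_h=1.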