解决mindspore.dataset.Dataset.split切分数据集时randomize=True时分割出的数据不够随机问题

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.9.7
操作系统平台: 不限

2 报错信息

2.1 问题描述

使用mindspore.dataset.Dataset.split切分数据集时,randomize=True切分出来的数据不够随机,数据集标签为0~23
使用NumpySlicesDataset函数加载数据(shuffle=False),split中randmoize=True切分出来效果如下图

使用NumpySlicesDataset函数加载数据(shuffle=True),split中randmoize=Flase切分出来效果如图

2.2 脚本信息

dataset = ds.NumpySlicesDataset((X,Y),['data','lable'],shuffle=False)  
dataset = dataset.map(operations=ds.trainforms.TypeCast(ms.int32),input_columns="lable")  
batch_size = 128  
train_dataset,test_dataset = dataset.split([0.8,0.2],randomize=False)  
train_dataset = train_dataset.batch(batch_size=batch_size)  
test_dataset = test_dataset.batch(batch_size=128)  
for data,label in test_dataset.create_tuple_iterator():  
    print(label)

3 解决方案

创建数据集的时候打开shuffle操作,然后对数据集进行相应的切分。

dataset = ds.NumpySlicesDataset((X,Y),['data','lable'],shuffle=True)  
dataset = dataset.map(operations=ds.trainforms.TypeCast(ms.int32),input_columns="lable")  
batch_size = 128  
train_dataset,test_dataset = dataset.split([0.8,0.2],randomize=False)  
train_dataset = train_dataset.batch(batch_size=batch_size)  
test_dataset = test_dataset.batch(batch_size=128)  
for data,label in test_dataset.create_tuple_iterator():  
    print(label)