报错ValueError: Input buffer_size is not within the required interval of [2, 2147483647].

1 系统环境

硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限

2 报错信息

2.1 问题描述

运行这里(https://openi.pcl.ac.cn/drizzlezyk/ddpm2/src/branch/master/train.py)的代码和数据,python train.py,报错ValueError: Input buffer_size is not within the required interval of [2, 2147483647].

2.2 报错信息

ValueError: Input buffer_size is not within the required interval of [2, 2147483647].

2.3 脚本代码

import ddpm as ddpm
import argparse
import os
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
import time
from upload import UploadOutput
import moxing as mox


def parse_args(argv=None):
    """Parse command-line options for DDPM training.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse falls back to ``sys.argv`` — existing callers
            that invoke ``parse_args()`` with no arguments are unaffected.

    Returns:
        argparse.Namespace with the known training options. Unknown
        arguments (e.g. extras injected by the training platform) are
        tolerated via ``parse_known_args``.
    """
    parser = argparse.ArgumentParser(description="train ddpm",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--pretrain_path',
                        type=str,
                        default=None,
                        help='the pretrain model path')

    parser.add_argument('--data_url',
                        type=str,
                        default="C:\\Users\\Administrator\\PycharmProjects\\DDPM\\datasets\\test",
                        help='training data file path')

    parser.add_argument('--train_url',
                        default='./results',
                        type=str,
                        help='the path model and fig save path')

    parser.add_argument('--steps',
                        default=20000,
                        type=int,
                        help='training steps')

    parser.add_argument('--save_every',
                        default=5000,
                        type=int,
                        help='save_every')

    parser.add_argument('--num_samples',
                        default=4,
                        type=int,
                        help='num_samples must have a square root, like 4, 9, 16 ...')

    parser.add_argument('--device_target',
                        default="Ascend",
                        type=str,
                        help='device target')
    parser.add_argument('--image_size',
                        default=200,
                        type=int,
                        help='image size')
    # parse_known_args (not parse_args) so platform-injected extras don't abort.
    args, _ = parser.parse_known_args(argv)
    return args


def ObsToEnv(obs_data_url, data_dir):
    """Download training data from OBS to a local cache dir (best effort),
    then write /cache/download_input.txt as a completion marker that other
    ranks poll for (see DownloadFromQizhi in this file).

    Args:
        obs_data_url: Source OBS path of the dataset.
        data_dir: Local destination directory.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        # Best-effort: log and continue so the marker file is still written
        # and waiting ranks are not deadlocked.
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))

    # Context manager guarantees the marker file handle is closed
    # (the original left a bare open()/close() pair).
    with open("/cache/download_input.txt", 'w'):
        pass
    # os.path.exists never raises, so the original try/except around this
    # check was dead code and "download_input failed" was unreachable;
    # report the outcome with a plain if/else instead.
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return


def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy a single checkpoint file from OBS to a local path, best effort.

    Args:
        obs_ckpt_url: Source OBS URL of the checkpoint.
        ckpt_url: Local destination file path.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
    except Exception as e:
        # Failure is logged, never raised — caller proceeds regardless.
        print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(e))
    else:
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    return


def EnvToObs(train_dir, obs_train_url):
    """Upload a local results directory back to OBS, best effort.

    Args:
        train_dir: Local directory holding training outputs.
        obs_train_url: Destination OBS path.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as e:
        # Failure is logged, never raised — training results stay on disk.
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    else:
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    return


def DownloadFromQizhi(obs_data_url, data_dir):
    """Download training data once per node for single- or multi-card runs.

    Reads RANK_SIZE / RANK_ID / ASCEND_DEVICE_ID from the environment and
    the module-level ``args_opt`` (set in __main__) for the device target.
    In the multi-card case, only the first rank of each 8-card node performs
    the download; the other ranks spin until the marker file written by
    ObsToEnv appears.
    """
    # NOTE(review): raises TypeError if RANK_SIZE is unset — assumes the
    # Qizhi platform always exports it; confirm for local runs.
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        ObsToEnv(obs_data_url, data_dir)
        # context.set_context(device_target=args_opt.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args_opt.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                            parallel_mode=ParallelMode.DATA_PARALLEL,
                                            gradients_mean=True,
                                            parameter_broadcast=True)
        # NOTE(review): init() is called after set_auto_parallel_context
        # here; confirm this ordering is intended for this platform.
        init()

        local_rank = int(os.getenv('RANK_ID'))
        # One download per 8-card node; rank 0 of the node does the work.
        if local_rank % 8 == 0:
            ObsToEnv(obs_data_url, data_dir)

        # All ranks block until ObsToEnv's completion marker exists.
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return


def UploadToQizhi(train_dir, obs_train_url):
    """Upload training outputs to OBS; at most one upload per 8-card node.

    Args:
        train_dir: Local directory holding training outputs.
        obs_train_url: Destination OBS path.

    Fix: ``int(os.getenv(...))`` raised TypeError when RANK_SIZE / RANK_ID
    were unset (any run outside the platform); default to a single-device,
    rank-0 configuration instead.
    """
    device_num = int(os.getenv('RANK_SIZE', '1'))
    local_rank = int(os.getenv('RANK_ID', '0'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1:
        # Only the first rank of each 8-card node uploads.
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
    return


def train_ddpm():
    """Train DDPM on data mirrored from OBS and upload results back.

    Relies on the module-level ``args_opt`` namespace created in __main__.

    Fixes: ``num_samples`` was hard-coded to 4, silently ignoring the
    --num_samples CLI option — it now honors ``args_opt.num_samples``;
    directory creation uses ``exist_ok=True`` instead of a broad
    try/except that could also mask real OSErrors.
    """
    steps = args_opt.steps
    image_size = args_opt.image_size

    # Fixed cache locations used on the training node.
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    ckpt_url = '/cache/checkpoint.ckpt'
    # exist_ok makes creation idempotent (replaces the old broad except).
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)

    # ObsUrlToEnv(args_opt.ckpt_url, ckpt_url)  # optional: fetch a pretrained ckpt
    DownloadFromQizhi(args_opt.data_url, data_dir)

    print("List /cache/data:  ", os.listdir(data_dir))
    model = ddpm.Unet(
        dim=image_size,
        out_dim=3,
        dim_mults=(1, 2, 4, 8)
    )

    diffusion = ddpm.GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=20,  # number of time steps
        sampling_timesteps=10,
        loss_type='l2'  # L1 or L2
    )

    trainer = ddpm.Trainer(
        diffusion,
        os.path.join(data_dir, 'test'),
        train_batch_size=1,
        train_lr=8e-5,
        train_num_steps=steps,  # total training steps
        gradient_accumulate_every=1,  # gradient accumulation steps
        ema_decay=0.995,  # exponential moving average decay
        save_and_sample_every=args_opt.save_every,  # image sampling interval
        num_samples=args_opt.num_samples,  # was hard-coded to 4
        results_folder=train_dir,
        distributed=False
    )
    if args_opt.pretrain_path:
        trainer.load(args_opt.pretrain_path)
    trainer.train()
    UploadToQizhi(train_dir, args_opt.train_url)


if __name__ == '__main__':
    args_opt = parse_args()
    train_ddpm()

3 根因分析

创建数据集时,shuffle() 接口的 buffer_size 默认与数据集的样本数相关;当数据集中的样本数量过少(不足 2 个)时,buffer_size 超出 [2, 2147483647] 的有效区间,参数校验失败并抛出上述 ValueError。

4 解决方案

手动设置 buffer_size(例如固定为 1000),使其与数据集样本数量解耦,即可避免该校验失败:

def create_dataset(folder, image_size, extensions=None, augment_horizontal_flip=False,
                    batch_size=32, shuffle=True, num_workers=cpu_count()):
    """Build the image training pipeline with an explicit shuffle buffer.

    Setting buffer_size=1000 manually decouples the shuffle buffer from the
    dataset's sample count, avoiding the "Input buffer_size is not within the
    required interval of [2, 2147483647]" error raised when the dataset-derived
    buffer size falls outside that interval.

    NOTE(review): the ``num_workers=cpu_count()`` default is evaluated once at
    definition time, not per call — confirm that is acceptable here.
    """
    # Fall back to common image extensions when none are supplied.
    extensions = ['.jpg', '.jpeg', '.png', '.tiff'] if not extensions else extensions
    # Source-level shuffle is disabled; shuffling is done explicitly below.
    dataset = ImageFolderDataset(folder, num_parallel_workers=num_workers, shuffle=False,
                                    extensions=extensions, decode=True)

    transformers = [
        # CenterCrop(image_size*2),
        transformer.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
        transformer.Resize([image_size, image_size], transformer.Inter.BILINEAR),
        transformer.ToTensor()
    ]

    # Keep only the image column; labels are not used for DDPM training.
    dataset = dataset.project('image')
    dataset = dataset.map(transformers, 'image')
    if shuffle:
        # Fixed buffer size, independent of how many samples the dataset has.
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    return dataset