1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.0.0
执行模式(PyNative/ Graph):不限
Python版本: Python=3.7
操作系统平台: 不限
2 报错信息
2.1 问题描述
运行这里(https://openi.pcl.ac.cn/drizzlezyk/ddpm2/src/branch/master/train.py)的代码和数据,python train.py,报错ValueError: Input buffer_size is not within the required interval of [2, 2147483647].
2.2 报错信息
ValueError: Input buffer_size is not within the required interval of [2, 2147483647].
2.3 脚本代码
import ddpm as ddpm
import argparse
import os
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
import time
from upload import UploadOutput
import moxing as mox
def parse_args():
    """Parse the command-line options for DDPM training.

    Unknown options are tolerated (``parse_known_args``) so the script can
    run on platforms that inject extra flags.
    """
    parser = argparse.ArgumentParser(
        description="train ddpm",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--pretrain_path', type=str, default=None,
                        help='the pretrain model path')
    parser.add_argument('--data_url', type=str,
                        default="C:\\Users\\Administrator\\PycharmProjects\\DDPM\\datasets\\test",
                        help='training data file path')
    parser.add_argument('--train_url', type=str, default='./results',
                        help='the path model and fig save path')
    parser.add_argument('--steps', type=int, default=20000,
                        help='training steps')
    parser.add_argument('--save_every', type=int, default=5000,
                        help='save_every')
    parser.add_argument('--num_samples', type=int, default=4,
                        help='num_samples must have a square root, like 4, 9, 16 ...')
    parser.add_argument('--device_target', type=str, default="Ascend",
                        help='device target')
    parser.add_argument('--image_size', type=int, default=200,
                        help='image size')
    known, _ = parser.parse_known_args()
    return known
def ObsToEnv(obs_data_url, data_dir):
    """Download the OBS dataset at *obs_data_url* into the local *data_dir*.

    Best-effort: a failed download is logged, not raised.  A marker file
    ``/cache/download_input.txt`` is then created so that other ranks
    waiting in DownloadFromQizhi can proceed.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
    # Fix: use a context manager instead of bare open()/close() so the
    # handle is released even if the write path misbehaves.
    with open("/cache/download_input.txt", 'w'):
        pass
    try:
        if os.path.exists("/cache/download_input.txt"):
            print("download_input succeed")
    except Exception:
        print("download_input failed")
    return
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy a single checkpoint file from OBS to the local path *ckpt_url*.

    Best-effort: failures are printed, never raised.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as err:
        message = 'moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(err)
        print(message)
    return
def EnvToObs(train_dir, obs_train_url):
    """Upload the local *train_dir* tree back to OBS at *obs_train_url*.

    Best-effort: failures are printed, never raised.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as err:
        message = 'moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(err)
        print(message)
    return
def DownloadFromQizhi(obs_data_url, data_dir):
    """Download the dataset from OBS, handling single- and multi-card runs.

    Single card: download directly.  Multi card: configure data-parallel
    context, let only one rank per node (local_rank % 8 == 0) download,
    and make the other ranks wait for the marker file written by ObsToEnv.
    """
    # Fix: default to single-card when RANK_SIZE is unset (e.g. local
    # debugging) instead of crashing with int(None).
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        ObsToEnv(obs_data_url, data_dir)
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args_opt.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()
        local_rank = int(os.getenv('RANK_ID'))
        # Only one downloader per 8-card node; everyone else polls the marker.
        if local_rank % 8 == 0:
            ObsToEnv(obs_data_url, data_dir)
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
def UploadToQizhi(train_dir, obs_train_url):
    """Upload training outputs to OBS; only one rank per node uploads.

    Mirrors DownloadFromQizhi: rank 0 of each 8-card node performs the
    upload in the multi-card case.
    """
    # Fix: default the env vars so a local (non-cluster) run does not
    # crash with int(None).
    device_num = int(os.getenv('RANK_SIZE', '1'))
    local_rank = int(os.getenv('RANK_ID', '0'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1:
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
    return
def train_ddpm():
    """Build the U-Net + Gaussian diffusion model and run DDPM training.

    Reads options from the module-level ``args_opt``, downloads the
    dataset from OBS into /cache, trains, then uploads results to OBS.
    """
    steps = args_opt.steps
    image_size = args_opt.image_size
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    # exist_ok replaces the race-prone exists()+makedirs() with try/except.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)
    DownloadFromQizhi(args_opt.data_url, data_dir)
    print("List /cache/data: ", os.listdir(data_dir))
    model = ddpm.Unet(
        dim=image_size,
        out_dim=3,
        dim_mults=(1, 2, 4, 8)
    )
    diffusion = ddpm.GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=20,            # number of time steps
        sampling_timesteps=10,
        loss_type='l2'           # L1 or L2
    )
    trainer = ddpm.Trainer(
        diffusion,
        os.path.join(data_dir, 'test'),
        train_batch_size=1,
        train_lr=8e-5,
        train_num_steps=steps,                      # total training steps
        gradient_accumulate_every=1,                # gradient accumulation steps
        ema_decay=0.995,                            # exponential moving average decay
        save_and_sample_every=args_opt.save_every,  # image sampling and step
        # Fix: honor the --num_samples CLI flag (was hard-coded to 4,
        # silently ignoring the parser option; default is still 4).
        num_samples=args_opt.num_samples,
        results_folder=train_dir,
        distributed=False
    )
    if args_opt.pretrain_path:
        trainer.load(args_opt.pretrain_path)
    trainer.train()
    UploadToQizhi(train_dir, args_opt.train_url)
if __name__ == '__main__':
    # args_opt is intentionally module-level: the helper functions read it.
    args_opt = parse_args()
    train_ddpm()
3 根因分析
创建数据集时,当数据集中的文件过大或者过小时,shuffle()接口检测失败。
4 解决方案
手动设置buffer_size,buffer_size和数据集样本大小解耦
def create_dataset(folder, image_size, extensions=None, augment_horizontal_flip=False,
                   batch_size=32, shuffle=True, num_workers=cpu_count(),
                   shuffle_buffer_size=1000):
    """Build an image dataset pipeline for DDPM training.

    Loads images from *folder*, applies flip/resize/to-tensor transforms,
    optionally shuffles, and batches.

    Fix for "Input buffer_size is not within the required interval
    [2, 2147483647]": shuffling is done with an explicit, fixed
    ``shuffle_buffer_size`` (generalized into a parameter, default 1000)
    instead of letting the buffer be derived from the dataset size, so it
    stays valid for very small or very large datasets.
    """
    extensions = ['.jpg', '.jpeg', '.png', '.tiff'] if not extensions else extensions
    # Source-level shuffle is disabled; we shuffle explicitly below.
    dataset = ImageFolderDataset(folder, num_parallel_workers=num_workers, shuffle=False,
                                 extensions=extensions, decode=True)
    transformers = [
        transformer.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
        transformer.Resize([image_size, image_size], transformer.Inter.BILINEAR),
        transformer.ToTensor()
    ]
    dataset = dataset.project('image')
    dataset = dataset.map(transformers, 'image')
    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    return dataset