MindSpore应用案例:ICT实现图像修复(下)

推理

下面使用训练好的模型在ImageNet验证集上进行推理。由于ImageNet数据集规模过大,本案例仅从ImageNet数据集中挑选几张图片用于推理测试。
训练好的模型权重可以在此处ict_ckpt进行下载,下载完毕后放到ckpts_ICT目录下。权重文件的目录结构如下所示,其中origin是由PyTorch模型转换而来的权重文件,ms_train则是由MindSpore框架训练得到的权重文件:

ckpts_ICT/
├── origin
│   ├── Transformer/
│   │   ├── ImageNet.ckpt
│   │   ├── FFHQ.ckpt
│   │   └── Places2_Nature.ckpt
│   └── Upsample/
│       ├── FFHQ
│       │   ├── InpaintingModel_dis.ckpt
│       │   └── InpaintingModel_gen.ckpt
│       ├── ImageNet
│       │   ├── InpaintingModel_dis.ckpt
│       │   └── InpaintingModel_gen.ckpt
│       └── Places2_Nature
│           ├── InpaintingModel_dis.ckpt
│           └── InpaintingModel_gen.ckpt
├── ms_train
│   ├── Transformer/
│   │   └── ImageNet_best.ckpt
│   └── Upsample/
│       ├── InpaintingModel_dis_best.ckpt
│       └── InpaintingModel_gen_best.ckpt
└── VGG19.ckpt
import os
import shutil

import numpy as np
import mindspore

# Move up one level unless the working directory is already the project root
# (a directory whose name ends with 'ICT' or 'ict'). The two negated checks
# used to be joined with `or`, which is always true, so every run of this
# cell climbed one directory higher.
if not os.getcwd().endswith(('ICT', 'ict')):
    os.chdir("..")

# NOTE(review): parse_args is not defined in this part of the tutorial; it is
# presumably provided by an earlier cell - confirm before running.
opts = parse_args()
input_path = './input'

os.makedirs(input_path, exist_ok=True)

# Copy two sample validation images into the local input folder. A missing
# sample is reported and skipped so the step stays best-effort.
if os.path.exists(opts.image_url):
    class_dir = os.path.join(opts.image_url, 'n01440764')
    for sample_name in ('ILSVRC2012_val_00000293.JPEG', 'ILSVRC2012_val_00002138.JPEG'):
        sample_path = os.path.join(class_dir, sample_name)
        if os.path.isfile(sample_path):
            shutil.copy(sample_path, input_path)
        else:
            print('Sample image not found, skipping: %s' % sample_path)

# From here on, images are read from the local input folder.
opts.image_url = input_path

# Load the k-means colour centres and rescale them with the affine map that
# sends -1 to 0 and 1 to 255, rounding to the nearest integer.
C = np.load('./kmeans_centers.npy')
C = np.rint(127.5 * (C + 1.0))
C = mindspore.Tensor.from_numpy(C)

# Sorted listings are paired positionally later (image i with mask i).
img_list = sorted(os.listdir(opts.image_url))
mask_list = sorted(os.listdir(opts.mask_url))
n_samples = opts.condition_num

First Stage —— Transformer

定义网络以及加载权重

from mindspore.train import Model

# Define the model: the sequence length is one token per pixel of the
# prior_size x prior_size prior, and the vocabulary is one entry per colour
# centre in C.
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=C.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
                  block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
# Change this path to match the checkpoint you want to use; if you do change
# it, prefer an absolute path to avoid unnecessary errors.
opts.ckpt_path = './ckpts_ICT/ms_train/Transformer/ImageNet_best.ckpt'
if os.path.exists(opts.ckpt_path):
    print('Start loading the model parameters from %s' % (opts.ckpt_path))
    checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
    mindspore.load_param_into_net(transformer, checkpoint)
    print('Finished load the model')
else:
    # Without this notice the run would continue with freshly initialised
    # weights and give no hint as to why the results are poor.
    print('Warning: checkpoint %s not found, the Transformer keeps its initial weights' % (opts.ckpt_path))
transformer.set_train(False)
model = Model(transformer)

from PIL import Image

from transformer_utils.util import sample_mask

# Loop-invariant values, computed once.
prior_size = opts.prior_size
seq_length = prior_size * prior_size

for x_name, y_name in zip(img_list, mask_list):
    # Load the image, shrink it to the prior resolution and flatten it to a
    # list of RGB pixels. The context manager closes the file handle.
    print(x_name)
    with Image.open(os.path.join(opts.image_url, x_name)) as image_file:
        x = image_file.convert("RGB").resize((prior_size, prior_size), resample=Image.BILINEAR)
    x = mindspore.Tensor.from_numpy(np.array(x)).view(-1, 3)
    x = P.Cast()(x, mindspore.float32)
    # Quantise: for every pixel pick the index of the colour centre with the
    # smallest squared distance.
    x = ((x[:, None, :] - C[None, :, :]) ** 2).sum(-1).argmin(1)

    # Load the mask at the same resolution and binarise it at 0.5.
    with Image.open(os.path.join(opts.mask_url, y_name)) as mask_file:
        y = mask_file.convert("L").resize((prior_size, prior_size), resample=Image.NEAREST)
    y = (np.array(y) / 255.) > 0.5
    y = mindspore.Tensor.from_numpy(y).view(-1)
    y = P.Cast()(y, mindspore.float32)

    # Replicate image and mask n_samples times so several completions are
    # sampled in a single call; tokens where the mask is 1 are zeroed out.
    x_tensor = P.Stack()([x] * n_samples)
    y_tensor = P.Stack()([y] * n_samples)
    x_tensor = P.Cast()(x_tensor * (1 - y_tensor), mindspore.int32)
    outputs = sample_mask(model, x_tensor, y_tensor, length=seq_length, top_k=opts.top_k)

    # The output keeps the input file name unchanged (the former slicing
    # expression split the name at the first dot and joined it again).
    img_name = x_name
    for i in range(n_samples):
        current_url = os.path.join(opts.save_url, 'condition_%d' % (i + 1))
        os.makedirs(current_url, exist_ok=True)
        # Map the sampled indices back to colours and save the small image.
        current_img = C[outputs[i]].view(prior_size, prior_size, 3).asnumpy().astype(np.uint8)
        Image.fromarray(current_img).save(os.path.join(current_url, img_name))

Second Stage —— Upsample

利用Transformer产生的图像先验信息进行图像上采样,将其恢复为高分辨率图片。

def stitch_images(inputs, *outputs, img_per_row=2):
    """Tile a batch of images and their counterparts into one PIL image.

    Each sample occupies one group of ``len(outputs) + 1`` tiles placed side
    by side: first the tile from ``inputs``, then the tile with the same
    index from every batch in ``outputs``. ``img_per_row`` groups are laid
    out per canvas row, separated horizontally by a small gap.

    Args:
        inputs: Indexable batch whose items support ``[:, :, 0]`` indexing
            and ``.asnumpy()`` (presumably height x width x channel tensors
            - confirm against the caller).
        *outputs: Further batches indexed in parallel with ``inputs``.
        img_per_row: Number of sample groups per canvas row.

    Returns:
        A new RGB ``PIL.Image`` containing all tiles.
    """
    gap = 5
    columns = len(outputs) + 1

    # An array shape is (rows, cols) = (height, width), whereas PIL sizes and
    # paste offsets are (width, height).
    height, width = inputs[0][:, :, 0].shape
    # Round up so a partially filled last row still fits on the canvas.
    n_rows = -(-len(inputs) // img_per_row)
    canvas = Image.new('RGB',
                       (width * img_per_row * columns + gap * (img_per_row - 1), height * n_rows))
    images = [inputs, *outputs]

    for ix in range(len(inputs)):
        row, col = divmod(ix, img_per_row)
        xoffset = col * width * columns + col * gap
        yoffset = row * height

        for cat, batch in enumerate(images):
            tile = Image.fromarray(batch[ix].asnumpy().astype(np.uint8).squeeze())
            canvas.paste(tile, (xoffset + cat * width, yoffset))

    return canvas


from upsample_utils.util import postprocess, imsave

opts.mask_type = 3
opts.mode = 2
opts.input = './input/'
opts.kmeans = './kmeans_centers.npy'
# Make sure the second stage reads exactly what the first stage wrote.
opts.prior = opts.save_url

generator = Generator()
generator.set_train(False)
# Change this path to match the checkpoint you want to use; if you do change
# it, prefer an absolute path to avoid unnecessary errors.
opts.ckpt_path = './ckpts_ICT/ms_train/Upsample/InpaintingModel_gen_best.ckpt'
if os.path.exists(opts.ckpt_path):
    print('Start loading the model parameters from %s' % (opts.ckpt_path))
    checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
    mindspore.load_param_into_net(generator, checkpoint)
    print('Finished load the model')
else:
    # Without this notice the run would continue with freshly initialised
    # weights and give no hint as to why the results are poor.
    print('Warning: checkpoint %s not found, the generator keeps its initial weights' % (opts.ckpt_path))

psnr_func = PSNR(255.0)

# NOTE(review): the first stage reads masks from opts.mask_url while this call
# uses opts.mask - confirm both options point at the same masks. The keyword
# `mask_filst` is spelled as the call site has it; verify against load_dataset.
test_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
                            image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
                            kmeans=opts.kmeans, condition_num=opts.condition_num,
                            augment=False, training=False)

index = 0
psnr = AverageMeter()
mae = AverageMeter()
test_batch_size = 1
test_dataset = test_dataset.batch(test_batch_size)

# Build the operators once instead of on every iteration.
reduce_sum = P.ReduceSum()
absolute = P.Abs()

for sample in test_dataset.create_dict_iterator():
    name = sample['name'].asnumpy()[0]
    images = sample['images']
    edges = sample['edges']
    masks = sample['masks']
    # Network-independent view of the corrupted input: masked pixels are
    # replaced by 1, the rest keep the original image values.
    inputs = (images * (1 - masks)) + masks
    index += test_batch_size
    outputs = generator(images, edges, masks)
    # Keep the generated pixels only inside the mask; elsewhere reuse the
    # original image.
    outputs_merged = (outputs * masks) + (images * (1 - masks))
    psnr.update(psnr_func(postprocess(images), postprocess(outputs_merged)), 1)
    # Sum of absolute differences divided by the sum of the reference image.
    mae.update((reduce_sum(absolute(images - outputs_merged)) / reduce_sum(images)), 1)
    result_merge = stitch_images(
        postprocess(images),
        postprocess(inputs),
        postprocess(outputs_merged),
        img_per_row=1
    )
    result_merge.show()
    output = postprocess(outputs_merged)[0]
    # Strip the real extension, whatever its length: slicing off the last
    # four characters left a trailing dot for '.JPEG' inputs.
    stem = os.path.splitext(name)[0]
    path = os.path.join(opts.save_url, stem + "_%d" % (index % opts.condition_num) + '.png')
    imsave(output, path)
print('PSNR: {}, MAE: {}'.format(psnr.avg, mae.avg))