1. Pix2Pix
在生成对抗网络(GAN)的浩瀚星空中,Pix2Pix(Image-to-Image Translation)是一颗极其耀眼的恒星。它不是生成随机的噪声图像,而是实现了“有条件的变换”——把线稿变上色图、把航拍图变地图、把白天变黑夜。
对于开发者而言,复现Pix2Pix是掌握**条件GAN(cGAN)**的最佳练兵场。本次复现基于MindSpore官方Notebook,在Ascend/GPU环境下均可运行。
2. 复现准备:环境与数据
2.1 数据集获取
我们使用的是经典的外墙(Facades)数据集。在Notebook中,直接调用 download 接口即可获取:
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"
download(url, "./dataset", kind="tar", replace=True)
体验记录: 下载完成后,数据集解压在 ./dataset/dataset_pix2pix 目录下。MindSpore提供了 MindDataset 接口直接读取 .mindrecord 格式文件,这比直接读取图片文件夹效率更高。
# 可视化部分训练数据
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
查看数据时发现,输入图像包含了“线稿图”和“真实图”成对出现,这正是Pix2Pix“有监督训练”的基础。
3. 核心源码解析
在开始训练前,我们先深入理解一下模型架构。
3.1 生成器:U-Net的递归之美
Pix2Pix的生成器采用U-Net结构。MindSpore官方提供了一个非常优雅的递归写法 UNetSkipConnectionBlock。
class UNetSkipConnectionBlock(nn.Cell):
def __init__(self, outer_nc, inner_nc, ... submodule=None ...):
# ...
if outermost:
model = down + [submodule] + up
elif innermost:
model = down + up
else:
model = down + [submodule] + up + [nn.Dropout(p=0.5)]
# ...
解析: 这种“俄罗斯套娃”式的递归定义,使得构建任意深度的U-Net变得异常简洁。特别注意中间层加入了 Dropout,这是为了在生成过程中引入随机性,防止模型生成的图像过于单一。
3.2 判别器:PatchGAN的矩阵判别
判别器 Discriminator 输出的不是一个标量,而是一个矩阵(PatchGAN)。
class Discriminator(nn.Cell):
# ...
def construct(self, x, y):
x_y = ops.concat((x, y), axis=1) # 将条件图和生成图拼接
output = self.features(x_y)
return output
解析: 输入是 x(线稿)和 y(真图或假图)的拼接。输出矩阵的每个点代表原图一块区域的真假。计算Loss时,对这个矩阵求均值。
4. 训练过程实录
4.1 训练策略
-
优化器:Adam (lr=0.0002, beta1=0.5)
-
Loss:L1 Loss (权重100) + GAN Loss (权重1)
-
Epochs:100
-
Batch Size:从Notebook输出看,每个Epoch有25个Step,Dataset size为400,说明Batch Size约为16。
4.2 训练日志分析
启动训练后,我们可以看到实时的Loss变化:
ms per step:532.31 epoch:1/100 step:0/25 Dloss:0.6940 Gloss:38.1245
ms per step:301.06 epoch:1/100 step:6/25 Dloss:1.6741 Gloss:47.7600
......
ms per step:289.41 epoch:100/100 step:24/25 Dloss:0.4199 Gloss:9.2418
复现观察:
-
初期波动:在第1个Epoch,生成器Loss(Gloss)非常高(约38-47),这是因为初始生成的图像与真实图像差距巨大,L1 Loss占主导。
-
收敛趋势:随着训练进行,Gloss逐渐下降。到了第100个Epoch,Gloss稳定在9.0左右。
-
判别器博弈:Dloss始终在0.4-0.7之间波动,这说明判别器和生成器一直在激烈博弈,没有一方彻底压倒另一方(这是GAN训练理想的状态)。
-
性能:单步耗时约290ms,速度相当可观。
4.3 关键代码:函数式微分
MindSpore的训练循环使用了函数式微分 value_and_grad,这一点在复现时印象深刻:
# 定义梯度函数
grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())
def train_step(reala, realb):
loss_dis, d_grads = grad_d(reala, realb)
loss_gan, g_grads = grad_g(reala, realb)
d_opt(d_grads)
g_opt(g_grads)
return loss_dis, loss_gan
这种写法比PyTorch的 loss.backward() 更加透明和可控,特别是在GAN这种需要交替更新参数的场景下,逻辑非常清晰。
5. 推理与结果展示
训练完成后,加载 Generator.ckpt 进行推理。
param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
predict_show = net_generator(data_iter["input_images"])
视觉效果: 对比输入(线稿)和输出(生成图),模型成功还原了建筑物的窗户、门口等细节。虽然在一些极其复杂的纹理上略显模糊,但整体结构和色调已经非常接近真实照片。
6. 结语与展望
通过这次复现,我们不仅验证了MindSpore在图像生成任务上的能力,也深入理解了Pix2Pix的设计精髓。
-
代码简洁:U-Net和PatchGAN的实现非常精简。
-
训练稳定:在Ascend/GPU上均能稳定收敛。
如果你想进一步优化,可以尝试:
-
增加Epoch:100个Epoch可能还不够,尝试增加到200。
-
更换Backbone:将U-Net的Encoder换成ResNet50。
-
高分辨率:目前是256x256,尝试训练512x512的高清模型。
参考资料:
