使用昇思MindSpore实现广告Banner图片生成:从理论到实践(上)
1.1.5 训练策略
我们使用自定义的训练步骤类来管理GAN的训练过程:
# 使用MindSpore的训练步骤类
class TrainOneStepD(nn.Cell):
    """Run one optimisation step of the discriminator.

    The cell computes the discriminator loss on a batch of real images and on
    a batch of generated images, differentiates it with respect to the
    parameters held by ``optimizerD`` and applies the update.

    The loss itself is produced by the module-level ``criterion`` callable,
    which is defined elsewhere in the file.

    Args:
        netD: Discriminator network.
        netG: Generator network used to produce the fake batch.
        optimizerD: Optimizer whose ``parameters`` are the weights to update.
        latent_dim: Number of channels of the noise tensor fed to ``netG``.
            Defaults to 100, the value that was previously hard-coded.
    """

    def __init__(self, netD, netG, optimizerD, latent_dim=100):
        super(TrainOneStepD, self).__init__()
        self.netD = netD
        self.netG = netG
        self.optimizerD = optimizerD
        self.latent_dim = latent_dim
        self.weights = self.optimizerD.parameters
        self.grad = ops.GradOperation(get_by_list=True)

    def construct(self, real_images, real_labels, fake_labels):
        """Update the discriminator once and return its loss.

        Args:
            real_images: Batch of real images; its first dimension is used as
                the batch size for the generated batch.
            real_labels: Targets passed to ``criterion`` for the real batch.
            fake_labels: Targets passed to ``criterion`` for the fake batch.

        Returns:
            The discriminator loss (mean of the real and fake losses).
        """
        batch_size = real_images.shape[0]
        # Sample the noise exactly once so that the loss we report and the
        # gradients we apply are computed on the same generated batch.
        noise = ops.randn((batch_size, self.latent_dim, 1, 1))

        lossD = self.forward(real_images, real_labels, fake_labels, noise)
        grads = self.grad(self.forward, self.weights)(
            real_images, real_labels, fake_labels, noise
        )
        self.optimizerD(grads)
        return lossD

    def forward(self, real_images, real_labels, fake_labels, noise=None):
        """Compute the discriminator loss.

        Args:
            real_images: Batch of real images.
            real_labels: Targets passed to ``criterion`` for the real batch.
            fake_labels: Targets passed to ``criterion`` for the fake batch.
            noise: Optional generator input. When omitted, a fresh tensor of
                shape ``(batch_size, latent_dim, 1, 1)`` is sampled.

        Returns:
            ``(loss_real + loss_fake) / 2``.
        """
        if noise is None:
            batch_size = real_images.shape[0]
            noise = ops.randn((batch_size, self.latent_dim, 1, 1))

        # Loss on real images.
        output_real = self.netD(real_images)
        lossD_real = criterion(output_real, real_labels)

        # Loss on generated images. Only the weights taken from optimizerD
        # are differentiated in this step, so the generator output is
        # detached to avoid back-propagating through netG for nothing.
        fake_images = ops.stop_gradient(self.netG(noise))
        output_fake = self.netD(fake_images)
        lossD_fake = criterion(output_fake, fake_labels)

        return (lossD_real + lossD_fake) / 2
class TrainOneStepG(nn.Cell):
    """Run one optimisation step of the generator.

    The generator is trained so that the discriminator scores its output
    against the "real" targets. The loss is produced by the module-level
    ``criterion`` callable, which is defined elsewhere in the file.

    Args:
        netG: Generator network.
        netD: Discriminator network used to score the generated batch.
        optimizerG: Optimizer whose ``parameters`` are the weights to update.
        latent_dim: Number of channels of the noise tensor fed to ``netG``.
            Defaults to 100, the value that was previously hard-coded.
    """

    def __init__(self, netG, netD, optimizerG, latent_dim=100):
        super(TrainOneStepG, self).__init__()
        self.netG = netG
        self.netD = netD
        self.optimizerG = optimizerG
        self.latent_dim = latent_dim
        self.weights = self.optimizerG.parameters
        self.grad = ops.GradOperation(get_by_list=True)

    def construct(self, real_labels):
        """Update the generator once and return its loss.

        Args:
            real_labels: Targets passed to ``criterion``; the first dimension
                is used as the batch size of the generated batch.

        Returns:
            The generator loss.
        """
        batch_size = real_labels.shape[0]
        # Sample the noise exactly once so that the loss we report and the
        # gradients we apply are computed on the same generated batch.
        noise = ops.randn((batch_size, self.latent_dim, 1, 1))

        lossG = self.forward(real_labels, noise)
        grads = self.grad(self.forward, self.weights)(real_labels, noise)
        self.optimizerG(grads)
        return lossG

    def forward(self, real_labels, noise=None):
        """Compute the generator loss.

        Args:
            real_labels: Targets passed to ``criterion``.
            noise: Optional generator input. When omitted, a fresh tensor of
                shape ``(batch_size, latent_dim, 1, 1)`` is sampled.

        Returns:
            ``criterion(netD(netG(noise)), real_labels)``.
        """
        if noise is None:
            batch_size = real_labels.shape[0]
            noise = ops.randn((batch_size, self.latent_dim, 1, 1))

        # The generator wants the discriminator to classify fakes as real.
        fake_images = self.netG(noise)
        output = self.netD(fake_images)
        return criterion(output, real_labels)
# Build the single-step trainers for the discriminator and the generator.
train_step_D = TrainOneStepD(netD, netG, optimizerD)
train_step_G = TrainOneStepG(netG, netD, optimizerG)

# Training loop
print("开始训练...")
num_epochs = 30

# Sample the preview noise once, before training starts, so that the images
# saved at different epochs come from the same latent vector and can be
# compared directly.
fixed_noise = ops.randn((1, latent_dim, 1, 1))

for epoch in range(num_epochs):
    epoch_lossD = 0
    epoch_lossG = 0
    batch_count = 0

    for data in dataset.create_dict_iterator():
        real_images = data["image"]
        batch_size = real_images.shape[0]

        # Targets: ones for real samples, zeros for generated samples.
        real_labels = ops.ones((batch_size,), ms.float32)
        fake_labels = ops.zeros((batch_size,), ms.float32)

        # Discriminator step, then generator step.
        lossD = train_step_D(real_images, real_labels, fake_labels)
        lossG = train_step_G(real_labels)

        epoch_lossD += lossD.asnumpy()
        epoch_lossG += lossG.asnumpy()
        batch_count += 1

    # Nothing to average or report if the dataset yielded no batch.
    if batch_count == 0:
        continue

    avg_lossD = epoch_lossD / batch_count
    avg_lossG = epoch_lossG / batch_count

    # Report every 5 epochs, and also after the last epoch so the final
    # state of the generator is always logged and saved.
    is_last_epoch = epoch == num_epochs - 1
    if epoch % 5 == 0 or is_last_epoch:
        print(f"Epoch [{epoch}/{num_epochs}] Avg Loss D: {avg_lossD:.4f}, Avg Loss G: {avg_lossG:.4f}")

        # Generate a preview image with the generator switched out of
        # training mode, then switch it back afterwards.
        netG.set_train(False)
        generated_image = netG(fixed_noise)

        # Convert the first sample to an H x W x C uint8 array for PIL.
        img = generated_image[0].asnumpy().transpose(1, 2, 0)
        img = (img + 1) / 2.0  # map [-1, 1] back to [0, 1]
        img = np.clip(img, 0, 1)
        img = (img * 255).astype(np.uint8)

        banner = Image.fromarray(img)
        banner.save(f"generated_banner_epoch_{epoch}.png")
        print(f"已保存 generated_banner_epoch_{epoch}.png")
        netG.set_train(True)

print("训练完成!")
运行结果:
生成图片:
训练过程与结果: 在30个训练周期后,我们的GAN模型能够生成具有相当质量的Banner图片。以下是训练过程中的关键观察:
损失函数变化: 判别器和生成器的损失逐渐收敛,表明模型正在学习
图片质量提升: 随着训练进行,生成的图片从噪声逐渐变得清晰可辨
多样性保持: 模型能够生成多种不同风格的Banner图片
使用昇思MindSpore实现广告Banner图片生成:从理论到实践(下)

