使用昇思MindSpore实现广告Banner图片生成:从理论到实践(中)

上一篇:使用昇思MindSpore实现广告Banner图片生成:从理论到实践(上)

1.1.5 训练策略

我们使用自定义的训练步骤类来管理GAN的训练过程:

# 使用MindSpore的训练步骤类
class TrainOneStepD(nn.Cell):
    """Run one optimisation step of the GAN discriminator.

    The loss is the mean of the module-level ``criterion`` evaluated on a
    batch of real images (against ``real_labels``) and on a batch of freshly
    generated images (against ``fake_labels``).

    Args:
        netD: Discriminator network; called on image batches.
        netG: Generator network; called on noise of shape
            ``(batch_size, latent_dim, 1, 1)``.
        optimizerD: Optimizer whose ``parameters`` are the weights to
            differentiate and update.
        latent_dim: Number of noise channels fed to ``netG``. Defaults to
            100, the value that was previously hard-coded.
    """

    def __init__(self, netD, netG, optimizerD, latent_dim=100):
        super(TrainOneStepD, self).__init__()
        self.netD = netD
        self.netG = netG
        self.optimizerD = optimizerD
        self.latent_dim = latent_dim
        self.weights = self.optimizerD.parameters
        # value_and_grad returns the loss together with the gradients from a
        # single forward pass, so the reported loss is computed on exactly
        # the same noise sample as the gradients that get applied.
        self.grad_fn = ops.value_and_grad(self.forward, None, self.weights)

    def construct(self, real_images, real_labels, fake_labels):
        """Compute the discriminator loss, apply its gradients, return the loss."""
        lossD, grads = self.grad_fn(real_images, real_labels, fake_labels)
        self.optimizerD(grads)
        return lossD

    def forward(self, real_images, real_labels, fake_labels):
        """Return the discriminator loss for one real batch plus one fake batch."""
        # Loss on real images.
        output_real = self.netD(real_images)
        lossD_real = criterion(output_real, real_labels)

        # Generate as many fake images as there are real ones.
        batch_size = real_images.shape[0]
        noise = ops.randn((batch_size, self.latent_dim, 1, 1))
        # Only the optimizer's weights are differentiated in this step, so
        # cut the graph at the generator output to skip its backward pass.
        fake_images = ops.stop_gradient(self.netG(noise))

        # Loss on fake images.
        output_fake = self.netD(fake_images)
        lossD_fake = criterion(output_fake, fake_labels)

        # Total discriminator loss: average of the two halves.
        return (lossD_real + lossD_fake) / 2

class TrainOneStepG(nn.Cell):
    """Run one optimisation step of the GAN generator.

    The generator is trained so that the discriminator's output on generated
    images matches ``real_labels`` under the module-level ``criterion``.

    Args:
        netG: Generator network; called on noise of shape
            ``(batch_size, latent_dim, 1, 1)``.
        netD: Discriminator network; called on the generated images.
        optimizerG: Optimizer whose ``parameters`` are the weights to
            differentiate and update.
        latent_dim: Number of noise channels fed to ``netG``. Defaults to
            100, the value that was previously hard-coded.
    """

    def __init__(self, netG, netD, optimizerG, latent_dim=100):
        super(TrainOneStepG, self).__init__()
        self.netG = netG
        self.netD = netD
        self.optimizerG = optimizerG
        self.latent_dim = latent_dim
        self.weights = self.optimizerG.parameters
        # value_and_grad returns the loss together with the gradients from a
        # single forward pass, so the reported loss is computed on exactly
        # the same noise sample as the gradients that get applied.
        self.grad_fn = ops.value_and_grad(self.forward, None, self.weights)

    def construct(self, real_labels):
        """Compute the generator loss, apply its gradients, return the loss."""
        lossG, grads = self.grad_fn(real_labels)
        self.optimizerG(grads)
        return lossG

    def forward(self, real_labels):
        """Return the generator loss for one freshly sampled noise batch."""
        # The batch size is taken from the first dimension of the labels.
        batch_size = real_labels.shape[0]
        noise = ops.randn((batch_size, self.latent_dim, 1, 1))

        # The generator wants the discriminator to score its fakes as real,
        # hence the comparison against real_labels.
        fake_images = self.netG(noise)
        output = self.netD(fake_images)
        return criterion(output, real_labels)

# Create the per-network training steps.
train_step_D = TrainOneStepD(netD, netG, optimizerD)
train_step_G = TrainOneStepG(netG, netD, optimizerG)

# Training loop
print("开始训练...")
num_epochs = 30

# Sample the preview noise once, before training starts, so that every saved
# snapshot is generated from the same latent vector and the images can be
# compared across epochs.
# NOTE: the train-step classes sample noise with 100 channels by default, so
# latent_dim is expected to match that value.
fixed_noise = ops.randn((1, latent_dim, 1, 1))

for epoch in range(num_epochs):
    epoch_lossD = 0.0
    epoch_lossG = 0.0
    batch_count = 0

    for i, data in enumerate(dataset.create_dict_iterator()):
        real_images = data["image"]
        batch_size = real_images.shape[0]

        # Target labels: ones for real images, zeros for generated ones.
        real_labels = ops.ones((batch_size,), ms.float32)
        fake_labels = ops.zeros((batch_size,), ms.float32)

        # Update the discriminator, then the generator.
        lossD = train_step_D(real_images, real_labels, fake_labels)
        lossG = train_step_G(real_labels)

        # Accumulate plain Python floats rather than NumPy values.
        epoch_lossD += float(lossD.asnumpy())
        epoch_lossG += float(lossG.asnumpy())
        batch_count += 1

    # Report and snapshot only every 5th epoch, and only if the dataset
    # yielded at least one batch (avoids dividing by zero).
    if batch_count == 0 or epoch % 5 != 0:
        continue

    avg_lossD = epoch_lossD / batch_count
    avg_lossG = epoch_lossG / batch_count
    print(f"Epoch [{epoch}/{num_epochs}] Avg Loss D: {avg_lossD:.4f}, Avg Loss G: {avg_lossG:.4f}")

    # Generate a preview image with the generator in inference mode.
    netG.set_train(False)
    generated_image = netG(fixed_noise)

    # Convert the first generated sample from channel-first to channel-last
    # and map it from [-1, 1] to 8-bit pixel values for PIL.
    img = generated_image[0].asnumpy().transpose(1, 2, 0)
    img = (img + 1) / 2.0  # de-normalise to [0, 1]
    img = np.clip(img, 0, 1)
    img = (img * 255).astype(np.uint8)

    banner = Image.fromarray(img)
    banner.save(f"generated_banner_epoch_{epoch}.png")
    print(f"已保存 generated_banner_epoch_{epoch}.png")
    netG.set_train(True)

print("训练完成!")

运行结果:

生成图片:


训练过程与结果: 在30个训练周期后,我们的GAN模型能够生成具有相当质量的Banner图片。以下是训练过程中的关键观察:

损失函数变化: 判别器和生成器的损失逐渐趋于稳定,表明两者的对抗训练正在趋于平衡,模型正在学习

图片质量提升: 随着训练进行,生成的图片从噪声逐渐变得清晰可辨

多样性保持: 模型能够生成多种不同风格的Banner图片

下一篇:使用昇思MindSpore实现广告Banner图片生成:从理论到实践(下)