Classifier-Free Guidance
论文地址:Classifier-Free Diffusion Guidance
代码地址:https://github.com/teapearce/conditional_diffusion_mnist
一、概述
Classifier-Free Guidance(无分类器引导,简称 CFG) 是谷歌于 2022 年提出的扩散模型优化技术,旨在增强生成样本的质量与条件契合度。该方法通过联合训练无条件与有条件生成模型,避免了传统Classifier Guidance中显式分类器的依赖,无需额外训练噪声图像分类器。推理阶段,通过线性组合两种生成模式的预测结果,实现灵活的条件引导,显著提升文本生成图像、图像修复等任务的性能。值得强调的是,CFG 是一种推理优化技术,不改变模型训练目标。
Classifier-Free Guidance 的主要特点包括:
- 无分类器依赖:无需额外训练和维护分类器,减少了模型复杂性和计算资源消耗。
- 联合训练策略:通过随机丢弃条件信息,使同一网络同时学习无条件与有条件生成;
- 灵活的引导强度:引入引导强度超参数 $w$,可灵活调节生成结果对条件的依赖程度。
二、主要步骤
CFG的实现分为训练阶段与推理阶段,通过巧妙的参数共享与结果融合达成高效引导。
-
训练阶段
- 无条件样本生成:以概率 $p_\theta$ 随机置空条件输入 $z$,训练模型生成无约束样本。
$$
\max_{\theta} \mathbb{E}{x \sim p\theta(z)} \left[ \log p_\theta(z) \right]
$$其中,$p_\theta(z)$ 是无条件生成分布。- 有条件样本生成:同时使用完整条件 $z$ 训练模型生成符合条件的样本。
$$
\max_{\theta} \mathbb{E}{x \sim p\theta(z)} \left[ \log p_\theta(z|c) - \log p_\theta(z) \right]
$$其中,$p_\theta(z|c)$ 是有条件生成分布,$p_\theta(z)$ 是无条件生成分布。- 共享参数:无条件与有条件模型共享大部分参数,仅通过条件标记区分输入。
$$
\max_{\theta} \mathbb{E}{x \sim p\theta(z)} \left[ \log p_\theta(z|c) - \log p_\theta(z) \right] + \lambda \mathbb{E}{x \sim p\theta(z)} \left[ \log p_\theta(z) \right]
$$其中,$\lambda$ 是权重参数,用于控制条件匹配度和无条件生成分布的熵之间的平衡。
-
推理阶段
- 并行生成:同时生成无条件预测 $p_\theta(z)$ 和有条件预测 $p_\theta(z|c)$。
- 线性插值:通过超参数 $w$ 对两者进行线性组合:
$$
\text{最终输出} = p_\theta(z) + w \cdot \left[ p_\theta(z|c) - p_\theta(z) \right]
$$ - 动态调整:$w$ 控制条件依赖强度,$w=0$ 时为纯无条件生成,$w \to \infty$ 时强制匹配条件。
三、关键点
-
无需分类器设计
- 传统方法需额外训练分类器 $p(c|z)$,而CFG通过联合训练无条件生成模型和有条件生成模型,实现了条件信息的隐式建模。
-
梯度组合策略
-
提出了基于引导强度参数$w$的梯度组合公式,实现了无条件生成梯度与有条件生成梯度的动态融合,本质上是对预测结果的加权差值放大。公式为:
$$
\text{最终输出} = p_\theta(z) + w \cdot \left[ p_\theta(z|c) - p_\theta(z) \right]
$$- $w=1$:标准条件生成(如文本对齐图像)。
- $w>1$:增强条件契合度(如更鲜艳的颜色)。
- $w<1$:提升多样性(如抽象艺术风格)。
-
-
联合训练范式
- 随机条件丢弃:训练时以50%概率随机屏蔽条件输入,强制模型学习数据分布的共性与条件特化;
- 参数共享:无条件与有条件生成共享U-Net网络参数,仅通过条件标记区分输入,避免冗余训练。
四、模型结构
Classifier-Free Guidance 通常基于DDPM、DDIM等扩散模型,核心设计围绕共享架构与双路径训练展开:
-
共享网络架构
- 模型共享UNet结构,通过条件标记区分输入,负责预测噪声或数据分布梯度。
- 使用Transformer或Embedding层将条件信息(如文本、标签)编码为向量,与时间步嵌入融合后输入网络。
-
双路径训练机制
- 在条件训练路径中,输入条件编码与噪声图像,再训练模型预测条件去噪目标。
- 在无条件训练路径中,会以一定概率丢弃条件信息,训练模型预测无条件去噪目标。此设计使模型同时学习条件依赖与无条件生成能力。
-
线性插值策略
- 在推理阶段,通过超参数w对无条件预测与有条件预测进行线性组合,实现逼真性与多样性的权衡引导。
- 当w=0时,仅生成无条件预测;当w=1时,仅生成有条件预测;当w>1时,增强条件匹配度;当w<1时,提升生成多样性。
五、代码实现
import mindspore as ms
from mindspore import nn, ops, Tensor
import numpy as np
# 设置随机种子确保可重复性
ms.set_seed(42)
# 定义U-Net扩散模型
class UNet(nn.Cell):
def __init__(self, in_channels=3, out_channels=3, channels=128, time_dim=256):
super().__init__()
# 时间步嵌入
self.time_mlp = nn.SequentialCell(
nn.Dense(time_dim, time_dim * 2),
nn.SiLU(),
nn.Dense(time_dim * 2, time_dim)
)
# 编码器
self.enc1 = nn.SequentialCell(
nn.Conv2d(in_channels + time_dim, channels, 3, padding=1),
nn.GroupNorm(4, channels),
nn.SiLU()
)
self.enc2 = nn.SequentialCell(
nn.Conv2d(channels, channels * 2, 3, padding=1, stride=2),
nn.GroupNorm(8, channels * 2),
nn.SiLU()
)
# 中间层
self.mid = nn.SequentialCell(
nn.Conv2d(channels * 2, channels * 2, 3, padding=1),
nn.GroupNorm(8, channels * 2),
nn.SiLU()
)
# 解码器
self.dec1 = nn.SequentialCell(
nn.Conv2dTranspose(channels * 2, channels, 3, padding=1, stride=2),
nn.GroupNorm(4, channels),
nn.SiLU()
)
self.dec2 = nn.SequentialCell(
nn.Conv2d(channels + time_dim, channels, 3, padding=1),
nn.GroupNorm(4, channels),
nn.SiLU()
)
self.final_conv = nn.Conv2d(channels, out_channels, 3, padding=1)
def construct(self, x, t):
# 时间步嵌入
t_emb = self.time_mlp(t)
t_emb = ops.tile(t_emb.unsqueeze(1).unsqueeze(2), (1, x.shape[1], x.shape[2], 1)).transpose(0, 3, 1, 2)
# 编码
x = ops.concat([x, t_emb], axis=1)
x = self.enc1(x)
x = self.enc2(x)
# 中间层
x = self.mid(x)
# 解码
x = self.dec1(x)
x = ops.concat([x, t_emb], axis=1)
x = self.dec2(x)
# 输出噪声预测
return self.final_conv(x)
# 定义Classifier-Free Guidance训练器
class CFGDiffusion(nn.Cell):
def __init__(self, model, guidance_scale=5.0):
super().__init__()
self.model = model
self.guidance_scale = guidance_scale
self.loss_fn = nn.MSELoss()
self.beta_schedule = self._linear_beta_schedule(timesteps=1000)
self.alpha_cumprod = np.cumprod(1 - self.beta_schedule)
def _linear_beta_schedule(self, timesteps):
beta_start = 1e-4
beta_end = 0.02
return np.linspace(beta_start, beta_end, timesteps)
def construct(self, x0, t, noise=None, y=None):
# 添加噪声
if noise is None:
noise = ops.randn_like(x0)
# 前向过程
sqrt_alpha_cumprod = ops.sqrt(Tensor(self.alpha_cumprod[t], ms.float32))
sqrt_one_minus_alpha_cumprod = ops.sqrt(1 - Tensor(self.alpha_cumprod[t], ms.float32))
xt = sqrt_alpha_cumprod * x0 + sqrt_one_minus_alpha_cumprod * noise
# 条件丢弃(50%概率)
if y is not None and ops.rand(1) > 0.5:
y = None
# 预测噪声
if y is None:
pred_noise = self.model(xt, t)
else:
pred_noise_cond = self.model(xt, t)
pred_noise_uncond = self.model(xt, t) # 无条件路径复用模型
pred_noise = pred_noise_uncond + self.guidance_scale * (pred_noise_cond - pred_noise_uncond)
# 计算损失
loss = self.loss_fn(pred_noise, noise)
return loss