开发者说 | 基于昇思MindSpore实现Palette扩散模型

作者:Adream

来源:昇思论坛

昇思MindSpore2024年技术帖分享大会圆满结束!全年收获80+高质量技术帖, 2025年全新升级,推出“2025年昇思干货小卖部,你投我就收!”,活动继续每月征集技术帖。本期技术文章由社区开发者Adrem输出并投稿。如果您对活动感兴趣,欢迎在昇思论坛投稿。

# 01

概述

Palette 是一种基于扩散模型的通用图像转换模型,可用于图像着色,图像修复,图像补全和 JPEG 图像恢复等多种转换任务,它通过输入的条件x来构建分布p(y|x),并通过图像级联的方式引入条件图像。

Palette 的主要特点包括:

  • 统一框架:Palette 模型采用了统一的框架,通过条件扩散模型解决多个图像到图像翻译任务,而无需针对每个任务进行特定的超参数调整、架构定制或辅助损失函数的设计。
  • 样本多样性:通过采用 L2 损失函数,Palette 模型在多个任务上均展现出了较高的样本多样性。这对于需要生成多种可能结果的图像到图像翻译任务尤为重要。
  • 条件 UNet 主干:采用 256×256 分辨率的条件 UNet,通过跳跃连接融合多尺度特征,增强跨层信息传递。

# 02

主要步骤

1、 预训练基础扩散模型

Palette 的预训练阶段建立在扩散模型的核心框架上,通过学习图像的噪声分布来捕获数据的潜在结构。

  • 扩散过程:逐步向图像添加高斯噪声,将原始图像 x转换为噪声样本 x_T。
  • 去噪过程:训练 UNet 模型学习反向过程,从噪声中恢复原始图像。
  • 关键技术:使用时间嵌入编码当前噪声强度,帮助模型理解去噪阶段。

2、 任务条件化适配

针对不同图像转换任务,Palette 通过以下方式进行条件化:

  • 任务特定提示:为每个任务设计专用的输入提示(如掩码区域、低分辨率图像、灰度图等)。
  • 条件嵌入:将任务提示通过额外的网络层转换为条件嵌入,与 UNet 的中间特征融合。
  • 多任务学习:通过共享主干网络、任务特定分支的方式,使模型同时掌握多个任务。

3、自适应噪声调度

Palette 引入自适应噪声调度提升不同任务的性能:

  • 任务感知噪声强度:根据任务难度动态调整噪声水平,例如修复任务使用更高噪声强度。
  • 渐进式细化:在去噪后期增加额外的细化步骤,提升细节质量。

4、后处理与质量优化

生成结果经过以下处理提升最终质量:

  • 细化网络

局部增强:对生成的图像进行局部细化,例如增强边缘、平滑过渡区域。

对抗训练:引入判别器网络,提升生成图像的真实性。

  • 一致性检查

**区域匹配:**确保生成区域与周围环境自然融合,避免人工痕迹。

**语义一致性:**在多对象场景中,保持对象间的语义关系(如比例、遮挡)。

  • 质量控制

置信度估计:预测每个像素的生成置信度,对低置信度区域进行重采样。

混合输出:结合多个采样结果,通过加权平均生成最终图像。

# 03

模型结构

从论文中可以看出,Palette 模型的核心架构就是采用了条件扩散模型作为基础框架,并在此基础上做成了一些改进。

  • **UNet 主干:**使用的是 256×256 的条件 UNet,这是一种对称结构。它包含上下采样路径和瓶颈层。同时在上下采样路径中,都包含了每个残差块,它们使用了跳跃连接来整合不同尺度的特征信息,达到捕捉图像的多尺度细节的效果。
  • **条件嵌入:**模型融合了多种条件信息来改变生成图像的结果,这其中包括了任务类型嵌入、时间步嵌入和条件图像特征。任务类型嵌入是用于区分图像的转换,例如着色、填充等;而时间步嵌入则是用于帮助模型理解生成过程中的时间信息;条件图像特征则是通过对输入的条件图像进行特征提取,将其作为额外的条件信息输入到模型中,使模型能够根据给定的条件图像生成相应的输出。
  • **多尺度特征融合:**在模型生成图像过程中,有一个步骤为跳跃连接,它是将下采样过程中不同维度的图像特征与上采样过程中的对应维度图像特征进行融合,从而丰富图像在不同尺度上的表现。
  • L2 损失函数:Palette 模型的损失函数是基于去噪目标的,即给定一个训练输出图像y,生成其噪声版本\tilde{y},并训练神经网络f_θ在条件x和噪声水平指标\gamma下对\tilde{y}进行去噪。而模型所使用的损失函数一般为 L2 范数,其数学公式为:

# 04

MindSpo re代码实现

import mindspore
from mindspore import nn, ops
from mindspore.common.initializer import Normal

class Swish(nn.Cell):
    def construct(self, x):
        return x * ops.sigmoid(x)

class TimeEmbedding(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = 9.210340371976184 / (self.half_dim - 1)
        self.linear_1 = nn.Dense(dim, dim * 4)
        self.linear_2 = nn.Dense(dim * 4, dim * 4)
        self.swish = Swish()
   
    def construct(self, time):
        emb = time * ops.exp(ops.arange(self.half_dim, dtype=mindspore.float32) * -self.emb)
        emb = ops.concat([ops.sin(emb), ops.cos(emb)], axis=1)
        emb = self.linear_1(emb)
        emb = self.swish(emb)
        emb = self.linear_2(emb)
        return emb

class ResidualBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, time_channels, dropout):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, pad_mode='pad')
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.swish = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, pad_mode='pad')
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.time_emb_proj = nn.Dense(time_channels, out_channels)
        self.dropout = nn.Dropout(1 - dropout)
       
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
   
    def construct(self, x, t):
        h = self.swish(self.norm1(x))
        h = self.conv1(h)
       
        time_emb = self.swish(self.time_emb_proj(t))
        time_emb = time_emb.expand_dims(-1).expand_dims(-1)
        h = h + time_emb
       
        h = self.swish(self.norm2(h))
        h = self.dropout(h)
        h = self.conv2(h)
       
        return h + self.shortcut(x)

class AttentionBlock(nn.Cell):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)
        self.scale = channels ** -0.5
 
    def construct(self, x):
        batch_size, channels, height, width = x.shape
        h = self.norm(x)
       
        q = self.q(h)
        k = self.k(h)
        v = self.v(h)
       
        q = q.reshape(batch_size, channels, -1)
        k = k.reshape(batch_size, channels, -1)
        v = v.reshape(batch_size, channels, -1)
        
        attn = ops.bmm(q.transpose(0, 2, 1), k) * self.scale
        attn = ops.softmax(attn, axis=-1)
       
        h = ops.bmm(v, attn.transpose(0, 2, 1))
        h = h.reshape(batch_size, channels, height, width)
        h = self.proj_out(h)
       
        return x + h

class UNetDown(nn.Cell):
    def __init__(self, in_channels, out_channels, time_channels, has_attn, dropout):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels, dropout)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()
    
    def construct(self, x, t):
        x = self.res(x, t)
        x = self.attn(x)
        return x

class UNetUp(nn.Cell):
    def __init__(self, in_channels, out_channels, time_channels, has_attn, dropout):
        super().__init__()
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels, dropout)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()
   
    def construct(self, x, skip, t):
        x = ops.cat([x, skip], axis=1)
        x = self.res(x, t)
        x = self.attn(x)
        return x

class MidBlock(nn.Cell):
    def __init__(self, channels, time_channels, dropout):
        super().__init__()
        self.res1 = ResidualBlock(channels, channels, time_channels, dropout)
        self.attn = AttentionBlock(channels)
        self.res2 = ResidualBlock(channels, channels, time_channels, dropout)
    
    def construct(self, x, t):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x

class Upsample(nn.Cell):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, pad_mode='pad')
   
    def construct(self, x):
        x = ops.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv(x)
        return x

class Downsample(nn.Cell):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, pad_mode='pad')
  
    def construct(self, x):
        return self.conv(x)

class Palette(nn.Cell):
    def __init__(self, img_channels=3, base_channels=128, channel_mults=(1, 2, 4, 8),
                 num_res_blocks=2, time_emb_dim=512, dropout=0.1, num_classes=None):
        super().__init__()
        self.img_channels = img_channels
        self.base_channels = base_channels
        self.num_classes = num_classes
       
        # Time embedding
        self.time_embed = TimeEmbedding(time_emb_dim)
       
        # Class embedding (if applicable)
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_emb_dim)
       
        # Initial convolution
        self.init_conv = nn.Conv2d(img_channels, base_channels, kernel_size=3, padding=1, pad_mode='pad')
       
        # Downsample path
        down_channels = [base_channels]
        curr_channels = base_channels
        self.down = nn.CellList()
        
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.down.append(UNetDown(curr_channels, out_channels, time_emb_dim, i > 1, dropout))
                curr_channels = out_channels
                down_channels.append(curr_channels)
           
            if i != len(channel_mults) - 1:
                self.down.append(Downsample(curr_channels))
                down_channels.append(curr_channels)
        
        # Middle block
        self.mid = MidBlock(curr_channels, time_emb_dim, dropout)
       
        # Upsample path
        self.up = nn.CellList()
        up_channels = down_channels[::-1]
        
        for i, mult in enumerate(reversed(channel_mults)):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks + 1):
                self.up.append(UNetUp(curr_channels, out_channels, time_emb_dim, i > 1, dropout))
                curr_channels = out_channels
           
            if i != len(channel_mults) - 1:
                self.up.append(Upsample(curr_channels))
       
        # Output block
        self.norm_out = nn.GroupNorm(32, curr_channels)
        self.swish_out = Swish()
        self.final_conv = nn.Conv2d(curr_channels, img_channels, kernel_size=3, padding=1, pad_mode='pad')
   
    def construct(self, x, time, y=None):
        # Time embedding
        t = self.time_embed(time)
       
        # Class embedding (if applicable)
        if y is not None:
            t = t + self.label_emb(y)
        
        # Initial convolution
        x_orig = x
        x = self.init_conv(x)
        h = [x]
       
        # Downsample path
        for module in self.down:
            if isinstance(module, UNetDown):
                x = module(x, t)
                h.append(x)
            else:
                x = module(x)
                h.append(x)
       
        # Middle block
        x = self.mid(x, t)
       
        # Upsample path
        for module in self.up:
            if isinstance(module, UNetUp):
                skip = h.pop()
                x = module(x, skip, t)
            else:
                x = module(x)
       
        # Output block
        x = self.swish_out(self.norm_out(x))
        x = self.final_conv(x)
        return x

参考链接

[1] 论文地址:https://arxiv.org/pdf/2111.05826