分享

扩散模型 (Diffusion Model) 简要介绍与源码分析

 520jefferson 2023-02-04 发布于北京

前言

近期同事分享了 Diffusion Model,这才发现生成模型的发展已经到了如此惊人的地步,OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像,质量之高直让人惊呼哇塞。后来公众号给我推送了一篇关于 Stability AI 公司的报道,他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源,能够在消费级显卡上实现 Dall-E 2 级别的图像生成,效率提升了 30 倍。

于是找到他们的开源产品体验了一把,在线体验地址在:

https:///spaces/stabilityai/stable-diffusion

开源代码在 Github 上:

https://github.com/CompVis/stable-diffusion

在搜索框中输入 'A dog flying in the sky' (一只狗在天空飞翔,生成效果如下:

图片

Amazing! 当然,不是每一张图片都符合预期,但好在可以生成无数张图片,其中总有效果好的。在震惊之余,不免对 Diffusion Model (扩散模型) 背后的原理感兴趣,就想看看是怎么实现的。

当时同事分享时,PPT 上那一堆堆公式扑面而来,把我给整懵圈了,但还是得撑起下巴,表现出似有所悟、深以为然的样子,在讲到关键处不由暗暗点头以表示理解和赞许。后面花了个周末专门学习了一下 k 公式推导 + 代码分析,感觉终于了解了基本概念,于是记录下来形成此文,不敢说自己完全懂了,毕竟我不做这个方向,但回过头去看 PPT 上的公式就不再发怵了。

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍,然后对源码进行分析。扩散模型的实现有多种形式,本文关注的是 DDPM (denoising diffusion probabilistic models)。在介绍完基本原理后,对作者释放的 Tensorflow 源码进行分析,加深对各种公式的理解。

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

Lilian 的博客,内容非常非常详实,干货十足,而且每篇文章都极其用心,向大佬学习: What are Diffusion Models?

ewrfcas 的知乎,公式推导补充了更多的细节: 由浅入深了解Diffusion Model

Lilian 的博客,介绍变分自动编码器 VAE: From Autoencoder to Beta-VAE, Diffusion Model 需要从分布中随机采样样本,该过程无法求导,需要使用到 VAE 中介绍的重参数技巧

Denoising Diffusion Probabilistic Models 论文:

其 TF 源码位于: https://github.com/hojonathanho/diffusion 

源码介绍以该版本为主

PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 TensorFlow 版本是一致的,Stable Diffusion 参考的是 pytorch 版本的代码

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型,和 VAE (Variational Autoencoder,变分自动编码器)、GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是:扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声,然后在逆向阶段学习从高斯噪声还原为原始图像的过程。

具体来说, 前向阶段在原始图像 上逐步增加噪声,每一步得到的图像 只和上一步的结果 相关, 直至第 步的图像 变为纯高斯噪声。前向阶段图示如下:

图片

而逆向阶段则是不断去除噪声的过程,首先给定高斯噪声 ,通过逐步去噪, 直至最终将原图像 给恢复出来,逆向阶段图示如下:

图片

模型训练完成后, 只要给定高斯随机噪声,就可以生成一张从未见过的图像。下面分别介绍前向阶段和逆向阶段,只列出重要公式。

前向阶段

由于前向过程中图像 只和上一时刻的 有关,该过程可以视为马尔科夫过程, 满足:

图片

其中 为高斯分布的方差超参,并满足 。另外公式 (2) 中为何均值 前乘上系数 的原因将在后面的推导介绍。上述过程的一个美妙性质是我们可以在任意 time step 下通过重参数技巧采样得到 。

重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题。比如要从高斯分布 中采样样本 ,可以通过引入随机变量 ,使得 , 此时 依旧具有随机性, 且服从高斯分布 ,同时 与 (通常由网络生成) 可导。

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 的方法,即生成随机变量 , 然后令 , 以及 ,从而可以得到:

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性,有:

图片

因此:

图片

注意公式 (3-2) 中 ,因此还需乘上 。从公式 (3) 可以看出:

图片

注意由于 且 , 而 , 因此 并且有 ,另外由于 ,因此当 时, 以及 , 此时 。从这里的推导来看,在公式 (2) 中的均值 前乘上系数 会使得 最后收敛到标准高斯分布。

逆向阶段

前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 ,那么通过输入高斯噪声 ,我们将生成一个真实的样本。注意到当 足够小时, 也是高斯分布,具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications。我大致看了一下,哈哈,没太看明白。不过想到这个不是我关注的重点,因此 pass。由于我们无法直接推断 ,因此我们将使用深度学习模型 去拟合分布 , 模型参数为 :

图片

虽然我们无法直接求得 (注意这里是 而不是模型 ),但在知道 的情况下,可以通过贝叶斯公式得到 为:

图片

推导过程如下:

上面推导过程中,通过贝叶斯公式巧妙的将逆向过程转换为前向过程,且最终得到的概率密度函数和高斯概率密度函数的指数部分:

图片

能够对应, 即有:

通过公式 (8) 和公式 (9),我们能得到 (见公式 (7)) 的分布。此外由于公式 (3) 揭示的 和 之间的关系: , 可以得到:

图片

代入公式 (9) 中得到:

图片

补充一下公式 (11) 的详细推导过程:

图片

前面说到, 我们将使用深度学习模型 去拟合逆向过程的分布 , 由公式 (6) 知:

图片

,我们希望训练模型 以预估 

图片

。由于 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 , 即令:

图片

模型训练

前面谈到, 逆向阶段让模型去预估噪声 ,那么应该如何设计 Loss 函数?我们的目标是在真实数据分布下,最大化模型预测分布的对数似然,即优化在 下的 交叉熵:

图片

和 变分自动编码器 VAE 类似,使用 Variational Lower Bound 来优化: :

对公式 (15) 左右两边取期望 ,利用到重积分中的 Fubini 定理可得:

图片

因此最小化 就可以优化公式 (14) 中的目标函数。之后对 做进一步的推导,这部分的详细推导见上面的参考文章,最终的结论是:

图片

最终是优化两个高斯分布 

图片

(详见公式 (7)) 与 

图片

(详见公式(6),此为模型预估的分布)之间的 KL 散度。由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions,从而可以得到:

DDPM 将 Loss 简化为如下形式:

图片

因此 Diffusion 模型的目标函数即是学习高斯噪声    和 (来自模型输出) 之间的 MSE loss。

最终算法

最终 DDPM 的算法流程如下:

图片

训练阶段重复如下步骤:

从数据集中采样 

随机选取 time step

生成高斯噪声

调用模型预估

计算噪声之间的 MSE Loss: 

图片

, 并利用反向传播算法训练模型.

逆向阶段采用如下步骤进行采样:

从高斯分布采样

按照 的顺序进行迭代:

如果 , 令 ; 如果 , 从高斯分布中采样

利用公式 (12) 学习出均值

图片

, 并利用公式 (8) 计算均方差:

通过重参数技巧采样

经过以上过程的迭代, 最终恢复 .

源码分析

DDPM 文章以及代码的相关信息如下:

Denoising Diffusion Probabilistic Models 论文

其 TF 源码位于: https://github.com/hojonathanho/diffusion

源码介绍以该版本为主

PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 TensorFlow 版本是一致的,Stable Diffusion 参考的是 pytorch 版本的代码

本文以分析 TensorFlow 源码为主,Pytorch 版本的代码和 TensorFlow 版本的实现逻辑大体不差的,变量名字啥的都类似,阅读起来不会有啥门槛。Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py,模型本身的分析以该文件为主。

训练阶段

以 CIFAR 数据集为例。

在 run_cifar.py 中进行前向传播计算 Loss:

图片

第 6 行随机选出

第 7 行 training_losses 定义在 GaussianDiffusion2 中,计算噪声间的 MSE Loss

进入 GaussianDiffusion2 中,看到初始化函数中定义了诸多变量,我在注释中使用公式的方式进行了说明:

图片

下面进入到 training_losses 函数中:

图片

第 19 行: self.model_mean_type 默认是 eps,模型学习的是噪声, 因此 target 是第 6 行定义的 noise,即

第  9 行: 调用 self.q_sample 计算 , 即公式 (3)

第 21 行: denoise_fn 是定义在 unet.py 中的 UNet 模型,只需知道它的输入和输出大小相同;结合第 9 行得到的 , 得到模型预估的噪声:

第 23 行: 计算两个噪声之间的 MSE:

图片

,并利用反向传播算法训练模型

上面第 9 行定义的 self.q_sample 详情如下:

图片

第 13 行的 q_sample 已经介绍过,不多说

第 2 行的  _extract 在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch,里面的每个样本都会随机采样一个 time step t,因此需要使用 tf.gather 来将 之类选出来,然后将系数 reshape 为 [B, 1, 1, ....] 的形式,目的是为了利用 broadcasting 机制和 这个 Tensor 相乘。

前向的训练阶段代码实现非常简单,下面看逆向阶段:

逆向阶段

逆向阶段代码定义在 GaussianDiffusion2 中:

图片

第 5 行生成高斯噪声 ,然后对其不断去噪直至恢复原始图像

第 11 行的 self.p_sample 就是公式 (6):

图片

的过程,使用模型来预估 以及

第 12 行的 denoise_fn 在前面说过, 是定义在 unet.py 中的 UNet 模型; img_ 表示 .

第 13 行的 noise_fn 则默认是 tf.random_normal,用于生成高斯噪声

进入 p_sample 函数:

图片

第 7 行调用 self.p_mean_variance 生成 以及 , 其中 通过计算 得到

第 11 行从高斯分布中采样 

第 18 行通过重参数技巧采样 , 其中

进入 self.p_mean_variance 函数:

图片

第 6 行调用模型 denoise_fn,通过输入 , 输出得到噪声

第 19 行 self.model_var_type 默认为 fixedlarge,但我当时看fixedsmall 比较爽,因此 model_variance 和 model_log_variance 分别为 (见公式 8),以及

第 29 行调用 self._predict_xstart_from_eps 函数,利用公式 (10) 得到

第 30 行调用 self.q_posterior_mean_variance 通过公式 (9) 得到:

图片

self._predict_xstart_from_eps 函数详情如下:

图片

该函数计算

self.q_posterior_mean_variance 函数详情如下:

图片

相关说明见注释, 另外发现对于 的计算使用的是公式 (9):

图片

而不是进一步推导后的公式 (11) :

图片

总结

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多