References:
What are Diffusion Models?
Generative Modeling by Estimating Gradients of the Data Distribution
Understanding Diffusion Models: A Unified Perspective
Pros and Cons of VAEs, Flow-based Models, and GANs
Given observed samples $x$ from a distribution of interest, the goal of a generative model is to learn to model its true data distribution $p(x)$. Once the model is learned, we can generate new samples from it at will. Existing generative modeling techniques can be divided into two categories according to how they represent the probability distribution:
Likelihood-based models: learn the probability density (or mass) function of the distribution directly via (approximate) maximum likelihood. Typical likelihood-based models include autoregressive models, flow-based models, energy-based models (EBMs), and variational auto-encoders (VAEs).
Implicit generative models: the probability distribution is represented implicitly by a model of its sampling process. The most prominent example is the generative adversarial network (GAN), which synthesizes new samples from the data distribution by transforming a random Gaussian vector with a generator.
The basic idea of the VAE (Variational Auto-Encoder) is to first sample $z \sim p(z)$ and then sample $x \sim p_\theta(x|z)$ given the resulting $z$. We can think of $p_\theta(x|z)$ as a decoder that maps a standard Gaussian $z$ onto the data distribution over $x$ through a stochastic mapping. The latent variable $z$ follows $p(z)=\mathcal{N}(0,I)$, and the conditional distribution $p_\theta(x|z)$ is usually a parameterized Gaussian or Bernoulli distribution. Once the model is trained, data $x$ can be generated by ancestral sampling. So how do we express $p(x)$ in a VAE?
Mathematically, we can imagine the latent variables and the observed data as being modeled by a joint distribution $p(x, z)$. There are two ways to use this joint distribution to recover the probability of the observations $p(x)$.
Marginalize out the latent variable $z$:
$$p(x) =\int p(x, z)\, dz$$
Use the chain rule of probability:
$$p(x)=\frac {p(x,z)}{p(z| x)}$$
A VAE is trained by maximizing the likelihood $\log p(x)$, but computing and maximizing $p(x)$ directly is difficult: it either requires integrating out all latent variables $z$ from the joint distribution, which is intractable for complex models, or (in option 2) requires access to the true latent encoder $p(z|x)$. We therefore resort to variational inference, using $q_\phi(z|x)$ to approximate the true posterior, which gives the model likelihood a lower bound, the ELBO (Evidence Lower Bound):
$$\begin{aligned}
\log p(x) & =\log \int p(x, z) d z \\
& =\log \int \frac{p(x, z) q_\phi(z | x)}{q_\phi(z | x)} d z \\
& =\log \mathbb{E}_{q_\phi(z | x)}\left[\frac{p(x, z)}{q_\phi(z | x)}\right] \\
& \geq \mathbb{E}_{q_\phi(z | x)}\left[\log \frac{p(x, z)}{q_\phi(z | x)}\right] \\
& =\mathbb{E}_{q_\phi(z | x)}\left[\log \frac{p_{\theta}(x | z) p(z)}{q_\phi(z | x)}\right] \\
& =\mathbb{E}_{q_\phi(z | x)}\left[\log p_\theta(x | z)\right]+\mathbb{E}_{q_\phi(z | x)}\left[\log \frac{p(z)}{q_\phi(z | x)}\right] \\
& =\underbrace{\mathbb{E}_{q_\phi(z | x)}\left[\log p_\theta(x | z)\right]}_{\text {reconstruction term }}-\underbrace{D_{\mathrm{KL}}\left(q_\phi(z | x) \| p(z)\right)}_{\text {prior matching term }}
\end{aligned}$$
Training a VAE amounts to maximizing the right-hand side of this inequality, which requires training $p_\theta(x|z)$ and $q_\phi(z|x)$ jointly; equality holds if and only if $q_\phi(z|x)$ equals the true posterior. The long-standing core difficulty of VAEs is that this variational posterior $q_\phi(z|x)$ is hard to choose. If it is too simple, it may be unable to approximate the true posterior and the model performs poorly; if it is too complex, $\log p(x)$ becomes hard to evaluate and the model becomes hard to optimize.
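To make these two terms concrete, here is a minimal PyTorch-style sketch of how the reconstruction term and the prior matching term are typically combined into a VAE training loss (assuming a hypothetical diagonal-Gaussian encoder `enc` and a unit-variance Gaussian decoder `dec`; additive constants such as $\log 2\pi$ are dropped):

```python
import torch

def vae_negative_elbo(x, enc, dec):
    # q_phi(z|x): diagonal Gaussian whose parameters come from the encoder
    mu, logvar = enc(x)
    # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    # reconstruction term E_q[log p_theta(x|z)]: with a unit-variance Gaussian
    # decoder this is a squared error up to an additive constant
    recon = -0.5 * ((x - dec(z)) ** 2).sum(dim=-1)
    # prior matching term: KL(q_phi(z|x) || N(0, I)) in closed form
    kl = 0.5 * (mu ** 2 + logvar.exp() - logvar - 1.0).sum(dim=-1)
    # maximizing the ELBO is minimizing -(recon - kl)
    return (kl - recon).mean()
```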
Compared with the VAE, flow-based models and GANs only need a generator that maps sampled Gaussian noise $z$ to the data distribution $p_\theta(x)$, without caring what the posterior distribution is at all. But they have their own drawbacks:
Flow-based models require the generator to be an invertible function.
GANs require training an additional discriminator, which makes training take a long time and spends part of the compute on a network that is discarded afterwards.
Diffusion models emerged to overcome these problems. Recall that the biggest difficulty of the VAE is choosing the variational posterior: we first define the decoder (the conditional distribution $p_\theta(x|z)$) and only then define the encoder (the variational posterior $q_\phi(z|x)$) to fit that decoder. Could we instead first build an encoder that plays the role of the variational posterior and maps the data to a standard Gaussian, and then define a decoder to fit it? In other words, can we first define some simple process that maps the data distribution to a standard Gaussian, and then generate samples by carrying out the reverse process? This is the core idea of diffusion models: match, one small step at a time, the reverse of a simple forward process.
How do we describe this simple process? In essence it is just transforming one distribution into another. From the theory of stochastic processes, if we construct an appropriate Markov chain, then no matter which distribution we start from, sampling along the chain long enough converges to the desired stationary distribution; this is also the core of Markov chain Monte Carlo (MCMC). Diffusion therefore boils down to two questions: how do we reach the Gaussian distribution by sampling (the forward process), and how do we fit that sampling process in reverse (the reverse process)?
The forward and reverse processes of diffusion
A diffusion probabilistic model (a "diffusion model" for short) is a parameterized Markov chain trained with variational inference to produce samples matching the data after finite time. The transitions of this chain are learned to reverse a diffusion process, which is a Markov chain that gradually adds noise to the data in the direction opposite to sampling until the signal is destroyed. When the diffusion consists of small amounts of Gaussian noise, it suffices to set the sampling-chain transitions to conditional Gaussians as well, allowing a particularly simple neural network parameterization. In short, a diffusion model defines a simple forward process that keeps adding noise to map the real data to a standard Gaussian, and then defines a reverse process that denoises, where each reverse step only needs to be a simple Gaussian.
The simplest way to understand a diffusion model is as a Markovian Hierarchical Variational Autoencoder with three key restrictions:
The latent dimension is exactly equal to the data dimension.
The structure of the latent encoder at each timestep is not learned; it is pre-defined as a linear Gaussian model, i.e. a Gaussian centered around the output of the previous timestep.
The Gaussian parameters of the latent encoder vary over time in such a way that the latent distribution at the final timestep $T$ is a standard Gaussian.
Diffusion models are latent variable models of the form $p_\theta(\mathbf{x}_0)$ where, by the first key assumption, $\mathbf{x}_1,\dots,\mathbf{x}_T$ are latents with the same dimensionality as the real data sample $\mathbf{x}_0 \sim q(\mathbf{x}_0)$. The joint distribution $p_\theta(\mathbf{x}_{0:T})$ is called the reverse process; it is defined as a Markov chain with learned Gaussian transitions starting from $p(\mathbf{x}_T)=\mathcal{N}(\mathbf{x}_T;0,I)$, and it is meant to approximate the reversal of the forward process:
$$p_{\theta}(\mathbf{x}_{0: T}):=p(\mathbf{x}_T) \prod_{t=1}^T p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_t), \quad p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_t):=\mathcal{N}(\mathbf{x}_{t-1} ; \mu_{\theta}(\mathbf{x}_t, t), \Sigma_{\theta}(\mathbf{x}_t, t))$$
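Operationally, the reverse process is used by ancestral sampling: draw $\mathbf{x}_T \sim \mathcal{N}(0, I)$ and repeatedly sample from the learned Gaussian transitions. A minimal sketch, assuming a hypothetical `mean_model` that returns $\mu_\theta(\mathbf{x}_t, t)$ and a precomputed array `sigmas` of per-step standard deviations indexed by $t$:

```python
import torch

@torch.no_grad()
def ancestral_sample(mean_model, shape, T, sigmas):
    x = torch.randn(shape)                                  # x_T ~ p(x_T) = N(0, I)
    for t in range(T, 0, -1):
        mu = mean_model(x, t)                               # mu_theta(x_t, t)
        noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
        x = mu + sigmas[t] * noise                          # x_{t-1} ~ N(mu_theta, sigma_t^2 I)
    return x                                                # approximate sample from p_theta(x_0)
```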
By the second assumption, the distribution of each latent in the encoder is a Gaussian centered around its previous-level latent. At each timestep $t$ the structure of the encoder is not learned; it is fixed as a linear Gaussian model whose mean and standard deviation can either be set in advance as hyperparameters or learned as parameters. What distinguishes diffusion models from other latent variable models is that the approximate posterior $q(\mathbf{x}_{1:T} | \mathbf{x}_0)$, called the forward process or diffusion process, is fixed to a Markov chain that gradually adds Gaussian noise to the data according to a variance schedule $\beta_1,\dots,\beta_T$:
$$q(\mathbf{x}_{1: T} | \mathbf{x}_0):=\prod_{t=1}^T q(\mathbf{x}_t | \mathbf{x}_{t-1}), \quad q(\mathbf{x}_t | \mathbf{x}_{t-1}):=\mathcal{N}(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t I)$$
The variance schedule of the forward process can be learned by reparameterization or held fixed as hyperparameters, and the expressiveness of the reverse process is partly guaranteed by the choice of Gaussian conditionals in $p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)$, because when $\beta_t$ is small the two processes have the same functional form. A notable property of the forward process is that it admits sampling $\mathbf{x}_t$ at an arbitrary timestep $t$ in closed form, which is why the algorithm can inject the noise for any timestep directly instead of adding it step by step. Using the notation $\alpha_t := 1-\beta_t$ and $\bar{\alpha}_t := \prod_{s=1}^t \alpha_s$, we have:
$$q(\mathbf{x}_t | \mathbf{x}_{0}):=\mathcal{N}(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_{0}, (1-\bar{\alpha}_t) I)$$
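This closed form means that during training we never have to iterate the chain to obtain a noisy sample: $\mathbf{x}_t$ is drawn directly from $\mathbf{x}_0$ with a single Gaussian draw. A small sketch, assuming a linear $\beta_t$ schedule purely for illustration:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # assumed schedule beta_1 ... beta_T
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # alpha_bar_t = prod_{s<=t} alpha_s

def q_sample(x0, t, noise=None):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I); t is 1-indexed."""
    if noise is None:
        noise = torch.randn_like(x0)
    ab = alpha_bars[t - 1]
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise
```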
This formula follows from the reparameterized form of each step and the property of Gaussians that the sum of two independent Gaussian random variables is again Gaussian, with mean equal to the sum of the means and variance equal to the sum of the variances:
$$\begin{aligned}
\mathbf{x}_t& =\sqrt{\alpha_t} \mathbf{x}_{t-1}+\sqrt{1-\alpha_t} \epsilon_{t-1} \\
& =\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}} \epsilon_{t-2}\right)+\sqrt{1-\alpha_t} \epsilon_{t-1} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}} \epsilon_{t-2}+\sqrt{1-\alpha_t} \epsilon_{t-1} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{({\sqrt{\alpha_t-\alpha_t \alpha_{t-1}}})^2+{(\sqrt{1-\alpha_t}})^2} \epsilon_{t-2} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}+1-\alpha_t} \epsilon_{t-2} \\
& =\sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \epsilon_{t-2} \\
& =\dots \\
& =\sqrt{\prod_{i=1}^t \alpha_i} \mathbf{x}_0+\sqrt{1-\prod_{i=1}^t \alpha_i}\epsilon_0 \\
& =\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \epsilon_0 \\
& \sim \mathcal{N}\left(\mathbf{x}_t;\sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) I\right) \\
\end{aligned}$$
Here $\bar{\alpha}_t$ is a decreasing function of $t$. Moreover, with a suitable choice of $\beta_t$ we also get $\lim_{t\rightarrow \infty}\bar\alpha_t = 0$, which means:
$$\lim_{t\rightarrow \infty}q(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(0,I)$$
That is, this conditional distribution eventually converges to a distribution that no longer depends on $\mathbf{x}_0$, and one can show that the marginal distribution then also converges to the standard Gaussian:
$$\lim_{t\rightarrow\infty}q(\mathbf{x}_t)=\mathcal{N}(0,I)$$
Therefore, as long as we choose a sufficiently large final time $T$, we can gradually map the data distribution of $\mathbf{x}_0$ to an $\mathbf{x}_T$ that is very close to Gaussian. This stochastic process is called the forward process. Moreover, since $q(\mathbf{x}_t|\mathbf{x}_{t-1})$ essentially rescales $\mathbf{x}_{t-1}$ and adds a little noise, the process can also be understood as gradually adding noise to the data; it is therefore also called the diffusion process, which is where diffusion models get their name. The reverse process then approximates each step with a Gaussian, and repeatedly denoising $\mathbf{x}_T$ produces the final sample.
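One can check numerically that a standard schedule drives $\bar\alpha_t$ toward zero, so that by $t = T$ the scaled signal $\sqrt{\bar\alpha_T}\,\mathbf{x}_0$ is negligible and $\mathbf{x}_T$ is essentially pure Gaussian noise (the linear schedule below is again just an illustrative assumption):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)       # assumed linear schedule
alpha_bar = np.cumprod(1.0 - betas)      # alpha_bar_t for t = 1..T

# alpha_bar decreases monotonically and is on the order of 1e-5 at t = T for
# this schedule, so q(x_T | x_0) is very close to N(0, I) regardless of x_0.
print(alpha_bar[0], alpha_bar[T // 2 - 1], alpha_bar[-1])
```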
Optimizing diffusion models
Overall there are three ways to interpret and optimize a diffusion model, and they are equivalent: $s_\theta$ is a model that fits the score function, $\epsilon_\theta$ is a model that predicts the noise contained in the data (the parameterization used by DDPM), and $\mathbf{x}_\theta$ is a model that estimates the original data underlying the noised data (a denoising model).
Diffusion models are trained by maximizing the ELBO on the log-likelihood (equivalently, minimizing its negative):
$$\begin{aligned}
\log p(x)&=\log \int p\left(\mathbf{x}_{0: T}\right) d \mathbf{x}_{1: T} \\
& =\log \int \frac{p\left(\mathbf{x}_{0: T}\right) q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)} d \mathbf{x}_{1: T} \\
& =\log \mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\frac{p\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\right] \\
& \geq \mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right) \prod_{t=2}^T p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right) \prod_{t=1}^{T-1} q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right) \prod_{t=1}^{T-1} p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right) \prod_{t=1}^{T-1} q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)}{q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right)}\right]+\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \prod_{t=1}^{T-1} \frac{p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]+\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right)}\right]+\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\sum_{t=1}^{T-1} \log \frac{p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]+\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right)}\right]+\sum_{t=1}^{T-1} \mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_1 | \mathbf{x}_0\right)}\left[\log p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]+\mathbb{E}_{q\left(\mathbf{x}_{T-1}, \mathbf{x}_T | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right)}\right]+\sum_{t=1}^{T-1} \mathbb{E}_{q\left(\mathbf{x}_{t-1}, \mathbf{x}_t, \mathbf{x}_{t+1} | \mathbf{x}_0\right)}\left[\log \frac{p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\underbrace{\mathbb{E}_{q\left(\mathbf{x}_1 | \mathbf{x}_0\right)}\left[\log p_\theta\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]}_{\text {reconstruction term }}-\underbrace{\mathbb{E}_{q\left(\mathbf{x}_{T-1} | \mathbf{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T | \mathbf{x}_{T-1}\right) \| p\left(\mathbf{x}_T\right)\right)\right]}_{\text {prior matching term }} \\
& \quad \quad -\sum_{t=1}^{T-1} \underbrace{\mathbb{E}_{q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t+1} | \mathbf{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right) \| p_\theta\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)\right)\right]}_{\text {consistency term }}
\end{aligned}$$
In this derived form, the ELBO can be interpreted term by term:
The first term is a reconstruction term: the log-probability of the original data sample given the first-step latent. This term also appears in an ordinary VAE and can be trained in the same way.
The second term is a prior matching term; it is minimized when the final latent distribution matches the Gaussian prior. It needs no optimization because it has no trainable parameters; moreover, we have already assumed $T$ is large enough that the final distribution is Gaussian, so this term is effectively zero.
The third term is a consistency term; it tries to make the distribution at $\mathbf{x}_t$ consistent between the forward and reverse processes. That is, at every intermediate timestep, a denoising step from a noisier image should match the corresponding noising step from a cleaner image, and this is expressed mathematically through a KL divergence.
Under this derivation, every term of the ELBO is an expectation and can therefore be approximated by Monte Carlo estimation. However, optimizing the ELBO with the terms we just derived may be suboptimal: the consistency term is an expectation over two random variables $\mathbf{x}_{t-1}, \mathbf{x}_{t+1}$ at every timestep, so its Monte Carlo estimate may have higher variance than a term estimated with only one random variable per timestep. Since the ELBO sums $T-1$ such consistency terms, its final estimate can have high variance for large $T$.
Instead, let us derive a form of the ELBO in which each term is an expectation over only one random variable at a time. The key insight is that we can rewrite the encoder transitions as $q(\mathbf{x}_t|\mathbf{x}_{t-1}) = q(\mathbf{x}_t|\mathbf{x}_{t-1}, \mathbf{x}_0)$: by the Markov property the extra conditioning on $\mathbf{x}_0$ is unnecessary, but it can be added harmlessly. Then, by Bayes' rule, each transition can be rewritten as:
$$q\left(\mathbf{x}_t | \mathbf{x}_{t-1}, \mathbf{x}_0\right)=\frac{q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) q\left(\mathbf{x}_t | \mathbf{x}_0\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_0\right)}$$
Re-deriving the ELBO then gives:
$$\begin{aligned}
\log p(x)
& \geq \mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right) \prod_{t=2}^T p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{q\left(\mathbf{x}_1 | \mathbf{x}_{0}\right) \prod_{t=2}^{T} q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right) \prod_{t=1}^{T-1} p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_1 | \mathbf{x}_{0}\right) \prod_{t=2}^{T} q\left(\mathbf{x}_t | \mathbf{x}_{t-1},\mathbf{x}_0\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 | \mathbf{x}_{0}\right)}+\log \prod_{t=2}^{T} \frac{p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_t | \mathbf{x}_{t-1},\mathbf{x}_0\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 | \mathbf{x}_{0}\right)}+\log \prod_{t=2}^{T} \frac{p_{\theta}\left(\mathbf{x}_t | \mathbf{x}_{t+1}\right)}{\frac{q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) q\left(\mathbf{x}_t | \mathbf{x}_0\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_0\right)}}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right) p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)}{q\left(\mathbf{x}_T | \mathbf{x}_0\right)}+\sum_{t=2}^T \log \frac{p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]+\mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T | \mathbf{x}_0\right)}\right]+\sum_{t=2}^T \mathbb{E}_{q\left(\mathbf{x}_{1: T} | \mathbf{x}_0\right)}\left[\log \frac{p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right)}\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_1 | \mathbf{x}_0\right)}\left[\log p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]+\mathbb{E}_{q\left(\mathbf{x}_T | \mathbf{x}_0\right)}\left[\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T | \mathbf{x}_0\right)}\right]+\sum_{t=2}^T \mathbb{E}_{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} | \mathbf{x}_0\right)}\left[\log \frac{p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right)}\right] \\
& =\underbrace{\mathbb{E}_{q\left(\mathbf{x}_1 | \mathbf{x}_0\right)}\left[\log p_{\theta}\left(\mathbf{x}_0 | \mathbf{x}_1\right)\right]}_{\text {reconstruction term }}-\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T | \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{\text {prior matching term }} \\
& \quad \quad -\sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\mathbf{x}_t | \mathbf{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) \| p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)\right)\right]}_{\text {denoising matching term }} \\
&
\end{aligned}$$
We have thus derived an interpretation of the ELBO that can be estimated with lower variance, since each term is an expectation over at most one random variable at a time. This formulation also has an elegant interpretation when each individual term is examined:
The first term can be interpreted as a reconstruction term $L_0$. Like the reconstruction term of an ordinary VAE, it can be approximated and optimized with Monte Carlo estimation.
The second term, $L_T$, measures how close the distribution of the final noised input is to the standard Gaussian prior. It has no trainable parameters and, under our assumptions, equals zero. In DDPM, with the variances fixed, it can be treated as a constant.
The third term is a denoising matching term $L_{t-1}$. We learn the desired denoising transition $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ as an approximation to the tractable, ground-truth denoising transition $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)$. The transition $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)$ can act as a ground-truth signal because it specifies how to denoise the noisy image $\mathbf{x}_t$ with access to what the fully denoised image $\mathbf{x}_0$ should be. The term is therefore minimized when the two denoising steps match as closely as possible, as measured by their KL divergence.
In this derivation of the ELBO, most of the optimization cost lies in the summation term, so next we work out each of its summands. First we compute the posterior below; it involves no neural network and depends only on the chosen $\beta_t$, so its mean and variance can be computed directly from the formulas:
$$q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1} ; \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t I)$$
$$\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0):=\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0+\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \mathbf{x}_t \quad \text{and} \quad \tilde{\beta}_t:=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t$$
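Since $\tilde{\mu}_t$ and $\tilde{\beta}_t$ involve only the schedule and no neural network, they can be computed directly; a small sketch under the same assumed linear schedule as above:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # assumed schedule beta_1 ... beta_T
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def q_posterior(x0, xt, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for a 1-indexed t >= 2."""
    ab_t, ab_prev = alpha_bars[t - 1], alpha_bars[t - 2]
    beta_t, alpha_t = betas[t - 1], alphas[t - 1]
    mean = (ab_prev.sqrt() * beta_t / (1 - ab_t)) * x0 \
         + (alpha_t.sqrt() * (1 - ab_prev) / (1 - ab_t)) * xt
    var = (1 - ab_prev) / (1 - ab_t) * beta_t
    return mean, var
```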
The derivation of this posterior is as follows:
$$\begin{aligned}
& q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right)=\frac{q\left(\mathbf{x}_t | \mathbf{x}_{t-1}, \mathbf{x}_0\right) q\left(\mathbf{x}_{t-1} | \mathbf{x}_0\right)}{q\left(\mathbf{x}_t | \mathbf{x}_0\right)} \\
& =\frac{\mathcal{N}\left(\mathbf{x}_t ; \sqrt{\alpha_t} \mathbf{x}_{t-1},\left(1-\alpha_t\right) I\right) \mathcal{N}\left(\mathbf{x}_{t-1} ; \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0,\left(1-\bar{\alpha}_{t-1}\right) I\right)}{\mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) I\right)} \\
& \propto \exp \left\{-\left[\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{2\left(1-\alpha_t\right)}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0\right)^2}{2\left(1-\bar{\alpha}_{t-1}\right)}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{2\left(1-\bar{\alpha}_t\right)}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left[\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{1-\alpha_t}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left[\frac{\left(-2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2\right)}{1-\alpha_t}+\frac{\left(\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_{t-1} \mathbf{x}_0\right)}{1-\bar{\alpha}_{t-1}}+C\left(\mathbf{x}_t, \mathbf{x}_0\right)\right]\right\} \\
& \propto \exp \left\{-\frac{1}{2}\left[-\frac{2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}}{1-\alpha_t}+\frac{\alpha_t \mathbf{x}_{t-1}^2}{1-\alpha_t}+\frac{\mathbf{x}_{t-1}^2}{1-\bar{\alpha}_{t-1}}-\frac{2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_{t-1} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left[\left(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \mathbf{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left[\frac{\alpha_t\left(1-\bar{\alpha}_{t-1}\right)+1-\alpha_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \mathbf{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \mathbf{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left[\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \mathbf{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \mathbf{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left[\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \mathbf{x}_{t-1}^2-2\left(\frac{\sqrt{\alpha_t} \mathbf{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left(\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}\right)\left[\mathbf{x}_{t-1}^2-2 \frac{\left(\frac{\sqrt{\alpha_t} \mathbf{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right)}{\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}} \mathbf{x}_{t-1}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left(\frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}\right)\left[\mathbf{x}_{t-1}^2-2 \frac{\left(\frac{\sqrt{\alpha_t} \mathbf{x}_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0}{1-\bar{\alpha}_{t-1}}\right)\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_{t-1}\right]\right\} \\
& =\exp \left\{-\frac{1}{2}\left(\frac{1}{\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}}\right)\left[\mathbf{x}_{t-1}^2-2 \frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \mathbf{x}_0}{1-\bar{\alpha}_t} \mathbf{x}_{t-1}\right]\right\} \\
& \propto \mathcal{N}(\mathbf{x}_{t-1} ; \underbrace{\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \mathbf{x}_0}{1-\bar{\alpha}_t}}_{\mu_q\left(\mathbf{x}_t, \mathbf{x}_0\right)}, \underbrace{\left.\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} I\right)}_{\Sigma_q(t)} \\
&
\end{aligned}$$
Next we compute the KL divergences. Under our assumptions, every KL divergence in the ELBO is between two Gaussians, so we need the KL divergence between Gaussian distributions:
$$D_{\mathrm{KL}}\left(\mathcal{N}\left(x ; \mu_x, \Sigma_x\right) \| \mathcal{N}\left(y ; \mu_y, \Sigma_y\right)\right)=\frac{1}{2}\left[\log \frac{\left|\Sigma_y\right|}{\left|\Sigma_x\right|}-d+\operatorname{tr}\left(\Sigma_y^{-1} \Sigma_x\right)+\left(\mu_y-\mu_x\right)^T \Sigma_y^{-1}\left(\mu_y-\mu_x\right)\right]$$
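For the diagonal Gaussians used throughout this derivation, the formula above can be written directly in code; a minimal sketch for per-dimension variances (the isotropic case used below is the special case where all entries of `var` are equal):

```python
import torch

def gaussian_kl(mu_x, var_x, mu_y, var_y):
    """KL( N(mu_x, diag(var_x)) || N(mu_y, diag(var_y)) ), summed over the last dimension."""
    return 0.5 * (
        torch.log(var_y / var_x)          # log |Sigma_y| / |Sigma_x|
        - 1.0                             # contributes the -d term after summing
        + var_x / var_y                   # tr(Sigma_y^{-1} Sigma_x)
        + (mu_y - mu_x) ** 2 / var_y      # (mu_y - mu_x)^T Sigma_y^{-1} (mu_y - mu_x)
    ).sum(dim=-1)
```

When the two covariances are equal, as in the next step, the first three terms cancel and only the squared mean difference weighted by $1/\sigma_q^2(t)$ survives.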
$$\begin{aligned}
& \underset{\theta}{\arg \min } D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) \| p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)\right) \\
& =\underset{\theta}{\arg \min } D_{\mathrm{KL}}\left(\mathcal{N}\left(\mathbf{x}_{t-1} ; \mu_q, \Sigma_q(t)\right) \| \mathcal{N}\left(\mathbf{x}_{t-1} ; \mu_{\theta}, \Sigma_q(t)\right)\right) \\
& =\underset{\theta}{\arg \min } \frac{1}{2}\left[\log \frac{\left|\Sigma_q(t)\right|}{\left|\Sigma_q(t)\right|}-d+\operatorname{tr}\left(\Sigma_q(t)^{-1} \Sigma_q(t)\right)+\left(\mu_{\theta}-\mu_q\right)^T \Sigma_q(t)^{-1}\left(\mu_{\theta}-\mu_q\right)\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2}\left[\log 1-d+d+\left(\mu_{\theta}-\mu_q\right)^T \Sigma_q(t)^{-1}\left(\mu_{\theta}-\mu_q\right)\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2}\left[\left(\mu_{\theta}-\mu_q\right)^T \Sigma_q(t)^{-1}\left(\mu_{\theta}-\mu_q\right)\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2}\left[\left(\mu_{\theta}-\mu_q\right)^T\left(\sigma_q^2(t) I\right)^{-1}\left(\mu_{\theta}-\mu_q\right)\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2 \sigma_q^2(t)}\left[\left\|\mu_{\theta}-\mu_q\right\|_2^2\right]
\end{aligned}$$
Here we write $\mu_q$ for $\mu_q(\mathbf{x}_t, \mathbf{x}_0)$ and $\mu_\theta$ for $\mu_\theta(\mathbf{x}_t, t)$. In other words, we want to optimize $\mu_\theta(\mathbf{x}_t, t)$ so that it predicts $\mu_q(\mathbf{x}_t, \mathbf{x}_0)$. At this point there are several ways to set up the model:
The first is to let the neural network directly output the posterior mean $\mu_\theta(\mathbf{x}_t, t)$.
The second is to directly predict the $\mathbf{x}_0$ appearing in $\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0)$; DDPM found this to work poorly in practice (this is the form derived below).
The third is to substitute $\mathbf{x}_0$ with an expression in $\mathbf{x}_t$ and $\epsilon$, and have the neural network predict the noise $\epsilon$ (the choice made by DDPM).
The denoising objective then takes the form:
$$\begin{aligned}
& \underset{\theta}{\arg \min } D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) \| p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)\right) \\
& =\underset{\theta}{\arg \min } D_{\mathrm{KL}}\left(\mathcal{N}\left(\mathbf{x}_{t-1} ; \mu_q, \Sigma_q(t)\right) \| \mathcal{N}\left(\mathbf{x}_{t-1} ; \mu_{\theta}, \Sigma_q(t)\right)\right) \\
& =\underset{\theta}{\arg \min } \frac{1}{2 \sigma_q^2(t)}\left[\left\|\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \hat{x}_{\theta}\left(\mathbf{x}_t, t\right)}{1-\bar{\alpha}_t}-\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \mathbf{x}_0}{1-\bar{\alpha}_t}\right\|_2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2 \sigma_q^2(t)}\left[\left\|\frac{\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \hat{x}_{\theta}\left(\mathbf{x}_t, t\right)}{1-\bar{\alpha}_t}-\frac{\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \mathbf{x}_0}{1-\bar{\alpha}_t}\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2 \sigma_q^2(t)}\left[\left\|\frac{\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right)}{1-\bar{\alpha}_t}\left(\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right)\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2 \sigma_q^2(t)} \frac{\bar{\alpha}_{t-1}\left(1-\alpha_t\right)^2}{\left(1-\bar{\alpha}_t\right)^2}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
&=\underset{\theta}{\arg \min } \frac{1}{2 \frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}} \frac{\bar{\alpha}_{t-1}\left(1-\alpha_t\right)^2}{\left(1-\bar{\alpha}_t\right)^2}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2} \frac{1-\bar{\alpha}_t}{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_{t-1}\right)} \frac{\bar{\alpha}_{t-1}\left(1-\alpha_t\right)^2}{\left(1-\bar{\alpha}_t\right)^2}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min }\frac{1}{2} \frac{\bar{\alpha}_{t-1}\left(1-\alpha_t\right)}{\left(1-\bar{\alpha}_{t-1}\right)\left(1-\bar{\alpha}_t\right)}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min }\frac{1}{2} \frac{\bar{\alpha}_{t-1}-\bar{\alpha}_t}{\left(1-\bar{\alpha}_{t-1}\right)\left(1-\bar{\alpha}_t\right)}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2} \frac{\bar{\alpha}_{t-1}-\bar{\alpha}_{t-1} \bar{\alpha}_t+\bar{\alpha}_{t-1} \bar{\alpha}_t-\bar{\alpha}_t}{\left(1-\bar{\alpha}_{t-1}\right)\left(1-\bar{\alpha}_t\right)}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2} \frac{\bar{\alpha}_{t-1}\left(1-\bar{\alpha}_t\right)-\bar{\alpha}_t\left(1-\bar{\alpha}_{t-1}\right)}{\left(1-\bar{\alpha}_{t-1}\right)\left(1-\bar{\alpha}_t\right)}\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2}\left(\frac{\bar{\alpha}_{t-1}\left(1-\bar{\alpha}_t\right)}{\left(1-\bar{\alpha}_{t-1}\right)\left(1-\bar{\alpha}_t\right)}-\frac{\bar{\alpha}_t\left(1-\bar{\alpha}_{t-1}\right)}{\left(1-\bar{\alpha}_{t-1}\right)\left(1-\bar{\alpha}_t\right)}\right)\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& =\underset{\theta}{\arg \min } \frac{1}{2}\left(\frac{\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t-1}}-\frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}\right)\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right] \\
& = \underset{\theta}{\arg \min } \frac{1}{2} \left(SNR(t-1)-SNR(t)\right)\left[\left\|\hat{x}_{\theta}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|_2^2\right]
\end{aligned}$$
The signal-to-noise ratio (SNR) here is defined as follows, since $q(\mathbf{x}_t | \mathbf{x}_{0}):=\mathcal{N}(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_{0}, (1-\bar{\alpha}_t) I)$:
$$\mathrm{SNR}(t)=\frac{\mu^2}{\sigma^2}=\frac{\bar{\alpha}_t}{1-\bar{\alpha}_{t}}$$
The SNR is the ratio between the original signal and the amount of noise present: a higher SNR means more signal, a lower SNR means more noise. In a diffusion model we require the SNR to decrease monotonically as the timestep $t$ increases. We can also parameterize the SNR at each timestep directly with a neural network and learn it jointly with the diffusion model:
$$\begin{aligned}
\frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}=\exp \left(-\omega_{\boldsymbol{\eta}}(t)\right) \\
\bar{\alpha}_t=\operatorname{sigmoid}\left(-\omega_{\boldsymbol{\eta}}(t)\right) \\
1-\bar{\alpha}_t=\operatorname{sigmoid}\left(\omega_{\boldsymbol{\eta}}(t)\right)
\end{aligned}$$
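A sketch of this learned-SNR parameterization, assuming a hypothetical network `omega` whose output is monotonically increasing in $t$ (that monotonicity constraint is assumed here, not shown):

```python
import torch

def learned_snr_schedule(omega, t):
    """SNR(t) = exp(-omega(t)), with alpha_bar_t and 1 - alpha_bar_t as sigmoids of omega(t)."""
    w = omega(t)
    snr = torch.exp(-w)                     # bar{alpha}_t / (1 - bar{alpha}_t)
    alpha_bar = torch.sigmoid(-w)           # bar{alpha}_t
    one_minus_alpha_bar = torch.sigmoid(w)  # 1 - bar{alpha}_t
    return snr, alpha_bar, one_minus_alpha_bar
```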
Furthermore, minimizing the summation term of our derived ELBO across all noise levels can be approximated by minimizing an expectation over all timesteps, which can then be optimized with stochastic samples of the timestep:
$$\underset{\theta}{\arg \min }\, \mathbb{E}_{t \sim U\{2, T\}}\left[\mathbb{E}_{q\left(\mathbf{x}_t | \mathbf{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) \| p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t\right)\right)\right]\right]$$
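Putting the pieces together, one stochastic training step samples a random timestep, forms $\mathbf{x}_t$ from $\mathbf{x}_0$ in closed form, and minimizes the SNR-weighted error of the $\hat{\mathbf{x}}_\theta$ prediction. A sketch, assuming a precomputed `alpha_bars` tensor and a hypothetical $\mathbf{x}_0$-prediction network `x0_model`:

```python
import torch

def x0_prediction_loss(x0_model, x0, alpha_bars):
    T = alpha_bars.shape[0]
    t = torch.randint(2, T + 1, (x0.shape[0],))                 # t ~ U{2, ..., T}
    shape = (-1,) + (1,) * (x0.dim() - 1)
    ab = alpha_bars[t - 1].view(shape)                          # bar{alpha}_t
    ab_prev = alpha_bars[t - 2].view(shape)                     # bar{alpha}_{t-1}
    noise = torch.randn_like(x0)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * noise               # x_t ~ q(x_t | x_0)
    weight = 0.5 * (ab_prev / (1 - ab_prev) - ab / (1 - ab))    # (SNR(t-1) - SNR(t)) / 2
    return (weight * (x0_model(xt, t) - x0) ** 2).mean()
```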
An equivalent form of the objective (DDPM)
Alternatively, using $\mathbf{x}_t=\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \epsilon_0$ to express $\mathbf{x}_0$ in terms of $\mathbf{x}_t$ and $\epsilon_0$ and substituting it in, we obtain:
$$\begin{aligned}
\mu_q\left(\mathbf{x}_t, \mathbf{x}_0\right) & =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \mathbf{x}_0}{1-\bar{\alpha}_t} \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \epsilon_0}{\sqrt{\bar{\alpha}_t}}}{1-\bar{\alpha}_t} \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\left(1-\alpha_t\right) \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \epsilon_0}{\sqrt{\alpha_t}}}{1-\bar{\alpha}_t} \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t}{1-\bar{\alpha}_t}+\frac{\left(1-\alpha_t\right) \mathbf{x}_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}-\frac{\left(1-\alpha_t\right) \sqrt{1-\bar{\alpha}_t} \epsilon_0}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \\
& =\left(\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}+\frac{1-\alpha_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}\right) \mathbf{x}_t-\frac{\left(1-\alpha_t\right) \sqrt{1-\bar{\alpha}_t}}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \epsilon_0 \\
& =\left(\frac{\alpha_t\left(1-\bar{\alpha}_{t-1}\right)}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}+\frac{1-\alpha_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}\right) \mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t} \sqrt{\alpha_t}} \epsilon_0 \\
& =\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t} \sqrt{\alpha_t}} \epsilon_0 \\
& =\frac{1-\bar{\alpha}_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t} \sqrt{\alpha_t}} \epsilon_0 \\
& =\frac{1}{\sqrt{\alpha_t}} \mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t} \sqrt{\alpha_t}} \epsilon_0 \\
& =\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_0\right)
\end{aligned}$$
We are trying to use $\mu_\theta(\mathbf{x}_t, t)$ to predict $\mu_q(\mathbf{x}_t, \mathbf{x}_0)$; since $\mathbf{x}_t$ is available as input at training time, we can reparameterize the Gaussian noise term and have the network predict $\epsilon_0$ instead:
$$\mu_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(\mathbf{x}_t, t) \right)$$
$$\mathbf{x}_{t-1} \sim \mathcal{N}\left(\mathbf{x}_{t-1}; \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(\mathbf{x}_t, t) \right), \Sigma_\theta(\mathbf{x}_t, t)\right)$$
Here $\epsilon_\theta(\mathbf{x}_t, t)$ is a neural network that learns to predict the source noise $\epsilon_0$; DDPM uses a U-Net-shaped autoencoder to predict the noise. This equation shows that learning a diffusion model by predicting the original image $\mathbf{x}_0$ is equivalent to learning to predict the noise. DDPM furthermore drops the leading coefficient, giving the loss:
$$L_{\text{simple}}(\theta):=\mathbb{E}_{t, \mathbf{x}_0, \epsilon}\left[\left\|\epsilon-\epsilon_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \epsilon, t\right)\right\|^2\right]$$
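A sketch of one training step under $L_{\text{simple}}$, with the schedule precomputed as before and a hypothetical noise-prediction network `eps_model` (a U-Net in DDPM):

```python
import torch

def ddpm_simple_loss(eps_model, x0, alpha_bars):
    T = alpha_bars.shape[0]
    t = torch.randint(1, T + 1, (x0.shape[0],))            # t ~ U{1, ..., T}
    ab = alpha_bars[t - 1].view((-1,) + (1,) * (x0.dim() - 1))
    eps = torch.randn_like(x0)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps            # closed-form forward noising
    return ((eps - eps_model(xt, t)) ** 2).mean()          # plain MSE on the noise, no weighting
```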
An equivalent form of the objective (score-based generative models)
We first derive Tweedie's formula. Among the many classical results of Bayesian parameter estimation, Tweedie's formula is one. Suppose $p(x|\theta)=\mathcal{N}(\theta,\sigma^2)$. Then:
$$\begin{aligned}
\mathbb{E}[\theta | x]& =\int \theta p(\theta | x) \mathrm{d} \theta \\
& =\int \theta \frac{p(x | \theta) p(\theta)}{p(x)} \mathrm{d} \theta \\
& =\frac{\int \theta p(x | \theta) p(\theta) \mathrm{d} \theta}{p(x)} \\
& =\frac{\int \theta \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}} p(\theta) \mathrm{d} \theta}{p(x)} \\
& =\frac{\int\left[\sigma^2 \frac{\theta-x}{\sigma^2} \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}} p(\theta)+x \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}} p(\theta)\right] \mathrm{d} \theta}{p(x)} \\
& =\frac{\int \sigma^2 \frac{\theta-x}{\sigma^2} \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}} p(\theta) \mathrm{d} \theta+\int x \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}} p(\theta) \mathrm{d} \theta}{p(x)} \\
& =\frac{\sigma^2 \int \frac{\mathrm{d}\left[\frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}}\right]}{\mathrm{d} x} p(\theta) \mathrm{d} \theta+\int x \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\theta)^2}{2 \sigma^2}} p(\theta) \mathrm{d} \theta}{p(x)} \\
& =\frac{\sigma^2 \int \frac{\mathrm{d} p(x | \theta)}{\mathrm{d} x} p(\theta) \mathrm{d} \theta+\int x p(x | \theta) p(\theta) \mathrm{d} \theta}{p(x)} \\
& =\frac{\sigma^2 \frac{\mathrm{d}}{\mathrm{d} x} \int p(x | \theta) p(\theta) \mathrm{d} \theta+x \int p(x | \theta) p(\theta) \mathrm{d} \theta}{p(x)} \\
& =\frac{\sigma^2 \frac{\mathrm{d} p(x)}{\mathrm{d} x}+x p(x)}{p(x)} \\
& =x+\sigma^2 \frac{\mathrm{d}}{\mathrm{d} x} \log p(x)
\end{aligned}
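As a quick sanity check of Tweedie's formula (a toy setup of our own choosing, not from the references): with a conjugate Gaussian prior the posterior mean is available in closed form, and it agrees exactly with $x + \sigma^2 \frac{\mathrm{d}}{\mathrm{d}x}\log p(x)$.

```python
import numpy as np

# Tweedie's formula in 1-D: theta ~ N(mu0, tau^2), x | theta ~ N(theta, sigma^2),
# so the marginal is p(x) = N(mu0, tau^2 + sigma^2) and the posterior mean is known.
mu0, tau, sigma, x = 1.0, 2.0, 0.5, 2.3

# Score of the marginal: d/dx log p(x) = -(x - mu0) / (tau^2 + sigma^2).
score = -(x - mu0) / (tau**2 + sigma**2)
tweedie_mean = x + sigma**2 * score

# Classical conjugate-Gaussian posterior mean, for comparison.
posterior_mean = (tau**2 * x + sigma**2 * mu0) / (tau**2 + sigma**2)

print(tweedie_mean, posterior_mean)            # identical up to floating point
assert np.isclose(tweedie_mean, posterior_mean)
```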
In the diffusion setting, we apply this to estimate the true posterior mean of $\mathbf{x}_t$ given an observed sample:
\mathbb{E}\left[\mu_{\mathbf{x}_t} | \mathbf{x}_t\right]=\mathbf{x}_t+\left(1-\bar{\alpha}_t\right) \nabla_{\mathbf{x}_t} \log p\left(\mathbf{x}_t\right)
From the forward process we know $\mathbf{x}_t \sim \mathcal{N}(\sqrt{\bar\alpha_t}\,\mathbf{x}_0, (1-\bar\alpha_t)\mathbf{I})$, so the true mean of $\mathbf{x}_t$ is $\mu_{\mathbf{x}_t} = \sqrt{\bar\alpha_t}\,\mathbf{x}_0$. Equating this with the Tweedie estimate above and solving for $\mathbf{x}_0$ gives:
\mathbf{x}_0=\frac{\mathbf{x}_t+\left(1-\bar{\alpha}_t\right) \nabla \log p\left(\mathbf{x}_t\right)}{\sqrt{\bar{\alpha}_t}}
Substituting this expression for $\mathbf{x}_0$ back into $\mu_q(\mathbf{x}_t, \mathbf{x}_0)$, we can derive a new form of the denoising mean:
\begin{aligned}
\mu_q\left(\mathbf{x}_t, \mathbf{x}_0\right) & =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \mathbf{x}_0}{1-\bar{\alpha}_t} \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\sqrt{\bar{\alpha}_{t-1}}\left(1-\alpha_t\right) \frac{\mathbf{x}_t+\left(1-\bar{\alpha}_t\right) \nabla \log p\left(\mathbf{x}_t\right)}{\sqrt{\bar{\alpha}_t}}}{1-\bar{\alpha}_t} \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t+\left(1-\alpha_t\right) \frac{\mathbf{x}_t+\left(1-\bar{\alpha}_t\right) \nabla \log p\left(\mathbf{x}_t\right)}{\sqrt{\alpha_t}}}{1-\bar{\alpha}_t} \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right) \mathbf{x}_t}{1-\bar{\alpha}_t}+\frac{\left(1-\alpha_t\right) \mathbf{x}_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}+\frac{\left(1-\alpha_t\right)\left(1-\bar{\alpha}_t\right) \nabla \log p\left(\mathbf{x}_t\right)}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \\
& =\left(\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}+\frac{1-\alpha_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}\right) \mathbf{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}} \nabla \log p\left(\mathbf{x}_t\right) \\
& =\left(\frac{\alpha_t\left(1-\bar{\alpha}_{t-1}\right)}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}+\frac{1-\alpha_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}}\right) \mathbf{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}} \nabla \log p\left(\mathbf{x}_t\right) \\
& =\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \mathbf{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}} \nabla \log p\left(\mathbf{x}_t\right) \\
& =\frac{1-\bar{\alpha}_t}{\left(1-\bar{\alpha}_t\right) \sqrt{\alpha_t}} \mathbf{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}} \nabla \log p\left(\mathbf{x}_t\right) \\
& =\frac{1}{\sqrt{\alpha_t}} \mathbf{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}} \nabla \log p\left(\mathbf{x}_t\right)
\end{aligned}
In the score-based formulation, we approximate this mean with:
\mu_{\theta}\left(\mathbf{x}_t, t\right)=\frac{1}{\sqrt{\alpha_t}} \mathbf{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}} s_{\theta}\left(\mathbf{x}_t, t\right)
Here $s_{\theta}(\mathbf{x}_t, t)$ is a neural network that learns to predict the score $\nabla_{\mathbf{x}_t}\log p(\mathbf{x}_t)$, i.e. the gradient of the log density in data space. It is closely related to the source noise $\epsilon_0$: comparing the two expressions for $\mu_q$ shows that $\nabla_{\mathbf{x}_t}\log p(\mathbf{x}_t) = -\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_0$, so the two quantities differ only by a time-dependent scaling factor.
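A minimal sketch of this correspondence, assuming a hypothetical noise predictor `eps_model` and the schedules `alphas`, `alphas_bar`: converting the predicted noise into a score and plugging it into either expression for the mean yields the same $\mu_\theta$.

```python
import torch

# Equivalence between noise prediction and score prediction (names are assumptions).

def score_from_eps(eps_model, x_t, t, alphas_bar):
    """s_theta(x_t, t) = -eps_theta(x_t, t) / sqrt(1 - alpha_bar_t)."""
    return -eps_model(x_t, t) / (1 - alphas_bar[t]).sqrt()

def mu_from_score(x_t, score, t, alphas):
    """mu_theta = x_t / sqrt(alpha_t) + (1 - alpha_t) / sqrt(alpha_t) * score."""
    return x_t / alphas[t].sqrt() + (1 - alphas[t]) / alphas[t].sqrt() * score

def mu_from_eps(x_t, eps_hat, t, alphas, alphas_bar):
    """mu_theta = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)."""
    return (x_t - (1 - alphas[t]) / (1 - alphas_bar[t]).sqrt() * eps_hat) / alphas[t].sqrt()
```

Feeding the score obtained from `score_from_eps` into `mu_from_score` reproduces `mu_from_eps` exactly, which is the equivalence stated above.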
Improved Denoising Diffusion Probabilistic Models
Apart from $L_0$, every term in the loss is a KL divergence between two Gaussian distributions and can therefore be evaluated in closed form. To evaluate $L_0$ for images, each color component is discretized into 256 bins, and we compute the probability of $p_\theta(\mathbf{x}_0|\mathbf{x}_1)$ landing in the correct bin, which is tractable via the Gaussian CDF (see the sketch below). Also note that although $L_T$ does not depend on $\theta$, it will be close to zero as long as the forward noising process sufficiently destroys the data distribution.
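A minimal sketch of that discretized Gaussian likelihood, assuming images scaled to $[-1, 1]$ so that each of the 256 bins has half-width $1/255$; the tensor names `mean` and `log_std` are assumptions standing in for the parameters of $p_\theta(\mathbf{x}_0|\mathbf{x}_1)$.

```python
import torch

# Sketch of the discretized Gaussian log-likelihood used for L_0.
# `x0` holds pixel values in [-1, 1]; `mean` and `log_std` parameterize p_theta(x_0 | x_1).

def discretized_gaussian_log_likelihood(x0, mean, log_std):
    inv_std = torch.exp(-log_std)
    # Bin edges sit half a bin (1/255 in [-1, 1] coordinates) on each side of x0.
    plus_in = inv_std * (x0 - mean + 1.0 / 255.0)
    minus_in = inv_std * (x0 - mean - 1.0 / 255.0)
    cdf_plus = 0.5 * (1.0 + torch.erf(plus_in / 2.0 ** 0.5))
    cdf_minus = 0.5 * (1.0 + torch.erf(minus_in / 2.0 ** 0.5))
    # Probability mass of the correct bin; the edge bins absorb the tails.
    probs = torch.where(
        x0 < -0.999, cdf_plus,
        torch.where(x0 > 0.999, 1.0 - cdf_minus, cdf_plus - cdf_minus),
    )
    return torch.log(probs.clamp(min=1e-12))
```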
DDIM
DDIM modifies the posterior $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$ to:
q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\alpha_{t-1}}\mathbf{x}_0 + \sqrt{1 - \alpha_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\alpha_t}\mathbf{x}_0}{\sqrt{1 - \alpha_t}}, \sigma_t^2 \mathbf{I})
\mathbf{x}_{t-1} = \sqrt{\alpha_{t-1}}\Big(\underbrace{\frac{\mathbf{x}_t-\sqrt{1-\alpha_{t}}\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)}{\sqrt{\alpha_{t}}}}_{\text{predicted}\ \mathbf{x}_0}\Big) + \underbrace{\sqrt{1 - \alpha_{t-1} - \sigma_t^2} \cdot \mathbf{\epsilon}_\theta(\mathbf{x}_t, t)}_{\text{direction pointing to }\ \mathbf{x}_t} + \underbrace{\sigma_t\epsilon_t}_{\text {random noise}}
Note: the $\alpha$ here is the $\bar{\alpha}$ of DDPM. The generative step is decomposed into three parts: the first is produced from the predicted $\mathbf{x}_0$, the second is the direction pointing to $\mathbf{x}_t$, and the third is random noise (here $\epsilon_t$ is noise independent of $\mathbf{x}_t$). Setting $\sigma_t = 0$ removes the stochastic term and makes the update deterministic.
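A minimal sketch of one DDIM update following this three-part decomposition; `alpha` plays the role of DDPM's $\bar\alpha$, and `eps_model` is again a hypothetical noise predictor. With `sigma_t = 0` the step is deterministic, which is what allows skipping timesteps.

```python
import torch

# One DDIM update from timestep t to an earlier timestep t_prev (possibly non-adjacent).
# `alpha` is the cumulative-product schedule (DDPM's alpha_bar); names are assumptions.

@torch.no_grad()
def ddim_step(eps_model, x_t, t, t_prev, alpha, sigma_t=0.0):
    a_t, a_prev = alpha[t], alpha[t_prev]
    eps_hat = eps_model(x_t, torch.full((x_t.shape[0],), t, device=x_t.device))

    x0_pred = (x_t - (1 - a_t).sqrt() * eps_hat) / a_t.sqrt()      # "predicted x_0"
    dir_xt = (1 - a_prev - sigma_t ** 2) ** 0.5 * eps_hat          # "direction pointing to x_t"
    noise = sigma_t * torch.randn_like(x_t)                        # "random noise" (zero when sigma_t = 0)

    return a_prev.sqrt() * x0_pred + dir_xt + noise
```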
Conditional Generation (Stable Diffusion / Diffusion Models Beat GANs on Image Synthesis)
The derivations above only consider modeling the data distribution $p(\mathbf{x})$. However, we are often also interested in learning the conditional distribution $p(\mathbf{x}|y)$, which lets us explicitly control the generated data through conditioning information $y$:
p\left(\mathbf{x}_{0: T} | y\right)=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_t, y\right)
Here $y$ can be anything; it simply needs to be passed as an additional input to each of the neural networks (see the sketch below). However, a conditional diffusion model trained this way may learn to ignore or downweight any given conditioning information. Guidance was therefore proposed as a way to control more explicitly how much weight the model places on the conditioning information, at the cost of sample diversity. The two most popular forms of guidance are classifier guidance and classifier-free guidance.
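A toy sketch of "just pass $y$ in": a label embedding is added to the timestep embedding and fed to the denoiser. This is only an illustrative stand-in under assumed names, not the U-Net/cross-attention design actually used by Stable Diffusion.

```python
import torch
import torch.nn as nn

# Toy conditional noise predictor: the condition y is embedded and injected
# alongside the timestep embedding (T = 1000 and the MLP sizes are assumptions).
class ConditionalEps(nn.Module):
    def __init__(self, data_dim, n_classes, emb_dim=128):
        super().__init__()
        self.t_emb = nn.Embedding(1000, emb_dim)        # timestep embedding
        self.y_emb = nn.Embedding(n_classes, emb_dim)   # conditioning embedding
        self.net = nn.Sequential(
            nn.Linear(data_dim + emb_dim, 256), nn.SiLU(), nn.Linear(256, data_dim),
        )

    def forward(self, x_t, t, y):
        cond = self.t_emb(t) + self.y_emb(y)            # inject y alongside t
        return self.net(torch.cat([x_t, cond], dim=-1))
```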
Classifier Guidance
\begin{aligned}
\nabla \log p\left(\mathbf{x}_t | y\right) & =\nabla \log \left(\frac{p\left(\mathbf{x}_t\right) p\left(y | \mathbf{x}_t\right)}{p(y)}\right) \\
& =\nabla \log p\left(\mathbf{x}_t\right)+\nabla \log p\left(y | \mathbf{x}_t\right)-\nabla \log p(y) \\
& =\underbrace{\nabla \log p\left(\mathbf{x}_t\right)}_{\text {unconditional score }}+\underbrace{\nabla \log p\left(y | \mathbf{x}_t\right)}_{\text {adversarial gradient }}
\end{aligned}
We can also introduce a weight $\gamma$ on the learned classifier term to control how much sample diversity is traded away:
\nabla \log p\left(\mathbf{x}_t | y\right) = \nabla \log p\left(\mathbf{x}_t\right) + \gamma \nabla \log p\left(y | \mathbf{x}_t\right)
A notable drawback of classifier guidance is its reliance on a separately learned classifier. Because this classifier must handle inputs at arbitrary noise levels, which most existing pretrained classifiers are not optimized for, it has to be trained jointly with the diffusion model. A sketch of a guided sampling step follows.
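A hedged sketch of that guided step in the $\epsilon$-parameterization: the gradient of a noise-aware classifier $\log p_\phi(y|\mathbf{x}_t)$ is folded into the predicted noise with weight $\gamma$, using the relation $\nabla\log p(\mathbf{x}_t) = -\epsilon_\theta/\sqrt{1-\bar\alpha_t}$. `eps_model` and `classifier` are hypothetical networks and `alphas_bar` is the usual schedule.

```python
import torch

# Classifier guidance: shift the predicted noise by the classifier gradient.

def guided_eps(eps_model, classifier, x_t, t, y, alphas_bar, gamma=1.0):
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_p_y = torch.log_softmax(classifier(x_in, t), dim=-1)
        selected = log_p_y[torch.arange(len(y), device=x_t.device), y].sum()
        grad = torch.autograd.grad(selected, x_in)[0]          # grad_{x_t} log p(y | x_t)
    # Score correction translated into the epsilon parameterization:
    # eps_guided = eps_hat - gamma * sqrt(1 - alpha_bar_t) * grad
    return eps_model(x_t, t) - gamma * (1 - alphas_bar[t]).sqrt() * grad
```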
Classifier-Free Guidance
\begin{aligned}
\nabla \log p\left(\mathbf{x}_t | y\right) & =\nabla \log p\left(\mathbf{x}_t\right)+\gamma\left(\nabla \log p\left(\mathbf{x}_t | y\right)-\nabla \log p\left(\mathbf{x}_t\right)\right) \\
& =\nabla \log p\left(\mathbf{x}_t\right)+\gamma \nabla \log p\left(\mathbf{x}_t | y\right)-\gamma \nabla \log p\left(\mathbf{x}_t\right) \\
& =\underbrace{\gamma \nabla \log p\left(\mathbf{x}_t | y\right)}_{\text {conditional score }}+\underbrace{(1-\gamma) \nabla \log p\left(\mathbf{x}_t\right)}_{\text {unconditional score }}
\end{aligned}
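A minimal sketch of classifier-free guidance at sampling time: the same network is evaluated with and without the condition (with "no condition" represented by a learned null token, an assumption here called `y_null`), and the two noise predictions are combined with weight $\gamma$, mirroring the score combination above. $\gamma = 1$ recovers the plain conditional model; $\gamma > 1$ strengthens the conditioning at the cost of diversity.

```python
import torch

# Classifier-free guidance: mix conditional and unconditional noise predictions.
# `eps_model(x_t, t, y)` is a hypothetical conditional noise predictor trained with
# conditioning dropout, so it also accepts the null token `y_null`.

def cfg_eps(eps_model, x_t, t, y, y_null, gamma=3.0):
    eps_cond = eps_model(x_t, t, y)         # corresponds to the conditional score
    eps_uncond = eps_model(x_t, t, y_null)  # corresponds to the unconditional score
    return gamma * eps_cond + (1 - gamma) * eps_uncond
```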