Diffusion Models
July 11, 2022
Forward Diffusion
Forward diffusion gradually adds Gaussian noise to an input x_0 over T steps:
q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t \mathbf{I})
Instead of computing each x_t step by step, we can derive a closed form for arbitrary t.
Let \alpha_t = 1 - \beta_t and \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s :
q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) \mathbf{I})
Thus,
x_t(x_0, \epsilon) = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, ~~ \epsilon \sim \mathcal{N}(0, \mathbf{I})
\beta_t follows a predefined schedule that typically increases with time (for example, linearly), so that more noise is added at later steps.
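The closed form above can be sketched in a few lines of standard-library Python. This is a minimal illustration, not a reference implementation: the linear schedule endpoints (1e-4 and 0.02), the function names, and the toy input are assumptions made here, not values from the text.

```python
import math
import random

def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Return beta_1..beta_T, linearly spaced (assumed schedule and endpoints)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]

def cumulative_alphas(betas):
    """Return alpha_bar_1..alpha_bar_T, the running product of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars

def q_sample(x0, t, alpha_bars, rng):
    """Draw x_t ~ q(x_t | x_0) in one shot; t is 1-indexed as in the text."""
    a_bar = alpha_bars[t - 1]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e
           for x, e in zip(x0, eps)]
    return x_t, eps

betas = linear_beta_schedule(1000)
alpha_bars = cumulative_alphas(betas)
rng = random.Random(0)
x0 = [1.0, -0.5, 2.0]  # toy input, purely illustrative
x_t, eps = q_sample(x0, t=500, alpha_bars=alpha_bars, rng=rng)
```

Because \bar{\alpha}_t is a product of factors below one, it shrinks monotonically, so x_t is dominated by noise for large t.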
Some important questions and comments:
The factor \sqrt{1 - \beta_t} < 1 gradually shrinks the mean toward zero.
Why are the coefficients chosen so that c_1^2 + c_2^2 = 1 (writing x_t = c_1 x_0 + c_2 \epsilon)?
This preserves variance, which ensures that x_T is approximately an isotropic Gaussian with unit variance, assuming x_0 has unit variance.
What happens when the input is the zero vector (x_0 = \mathbf{0})?
What happens if we pretend that we always sample the mean of the Gaussian (\epsilon = 0)?
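The questions above can be worked out directly from the closed form x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon. The following sketch assumes that x_0 and \epsilon are independent and that each coordinate of x_0 has unit variance:

```latex
\begin{align*}
\text{Var}[x_t] &= c_1^2 \,\text{Var}[x_0] + c_2^2 \,\text{Var}[\epsilon] = \bar{\alpha}_t \cdot 1 + (1 - \bar{\alpha}_t) \cdot 1 = 1 \\
x_0 = \mathbf{0}: \quad x_t &= \sqrt{1 - \bar{\alpha}_t}\, \epsilon \sim \mathcal{N}(0, (1 - \bar{\alpha}_t) \mathbf{I}) \\
\epsilon = 0: \quad x_t &= \sqrt{\bar{\alpha}_t}\, x_0 \to \mathbf{0} ~~\text{as}~~ \bar{\alpha}_t \to 0
\end{align*}
```

In other words, a zero input yields pure noise whose variance grows toward one, while zero noise yields a deterministic decay of the input toward the origin.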
Reverse Diffusion
We want to model the reverse diffusion process q(x_{t-1}|x_t), which is intractable because it depends on the whole data distribution. We approximate it with a learned Gaussian:
p_\theta(x_{t-1}| x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))
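If we obtain \mu_\theta and \Sigma_\theta, we are done. To learn them, we minimize the negative log-likelihood of the data (Pop denotes the data distribution), which can be bounded from above: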
\begin{align*}
\mathcal{L} &= \mathbb{E}_{x_0 \sim Pop} [-\log p_\theta (x_0)] \\
&= \mathbb{E}_{x_0 \sim Pop} [-\log \bigg( \int p_\theta (x_{0:T}) dx_{1:T} \bigg) ] \\
&= \mathbb{E}_{x_0 \sim Pop} [-\log \bigg( \int q(x_{1:T}|x_0) \frac{p_\theta (x_{0:T})}{q(x_{1:T}|x_0)} dx_{1:T} \bigg)] ~~~~\text{(Importance sampling)}\\
&= \mathbb{E}_{x_0 \sim Pop} [{\color{brown}-\log} \bigg( {\color{blue}\mathbb{E}_{x_{1:T}\sim q(\cdot|x_0)}} \bigg[\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\bigg] \bigg)] \\
&\leq \mathbb{E}_{x_0 \sim Pop} [{\color{blue}\mathbb{E}_{x_{1:T}\sim q(\cdot|x_0)}} \bigg[{\color{brown} -\log} \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\bigg]] ~~~~~ \text{(Jensen's inequality)}\\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0), x_0 \sim Pop} \bigg[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\bigg]\\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0), x_0 \sim Pop} \bigg[-\log \frac{{\color{brown}p_\theta(x_0|x_1) \cdots p_\theta(x_{T-1}|x_T)} p(x_T)}{{\color{brown}q(x_T|x_{T-1}) \cdots q(x_1|x_0)}}\bigg]\\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0), x_0 \sim Pop} \bigg[-\log p(x_T) -\log {\color{brown}\prod_{t=1}^T \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}} \bigg]\\
&= \cdots \\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0), x_0 \sim Pop} \bigg[-\log \frac{p(x_T)}{q(x_T|x_0)} - \log \prod_{{\color{brown}t=2}}^T \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} - \log p_\theta(x_0| x_1) \bigg]\\
&= \underbrace{{\color{blue} \text{KL}\big(q(x_T|x_0)~||~p(x_T) \big)}}_{L_T~:~\text{constant}} + \sum_{t=2}^T \underbrace{\text{KL}\big(q(x_{t-1}| x_t, x_0) ~||~ p_\theta(x_{t-1}| x_t) \big)}_{L_{t-1}} + \underbrace{\mathbb{E}_{x_1 \sim q(\cdot| x_0), x_0 \sim Pop} \big[ -\log p_\theta(x_0 | x_1) \big]}_{L_0~:~\text{VAE}}\\
&= \sum_{t=2}^T \text{KL}\big(q(x_{t-1}| x_t, x_0) ~||~ p_\theta(x_{t-1}| x_t) \big) + \mathbb{E}_{x_1 \sim q(\cdot| x_0), x_0 \sim Pop} \big[ -\log p_\theta(x_0 | x_1) \big] + {\color{blue} \text{const}.} \\
\end{align*}
Each KL term compares two Gaussians, so it can be computed in closed form.
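To make the closed form concrete, here is a small sketch of the KL divergence between two scalar Gaussians. The function name and the test values are illustrative assumptions. When both variances are equal, the result reduces to a scaled squared distance between the means, which is why the KL terms turn into a mean-matching objective once the variance is fixed.

```python
import math

def kl_gaussian(mu1, var1, mu2, var2):
    """KL( N(mu1, var1) || N(mu2, var2) ) for scalar Gaussians."""
    return 0.5 * (math.log(var2 / var1)
                  + (var1 + (mu1 - mu2) ** 2) / var2
                  - 1.0)

# Identical distributions: the divergence vanishes.
same = kl_gaussian(0.3, 0.5, 0.3, 0.5)

# Equal variances: KL = (mu1 - mu2)^2 / (2 * var).
shifted = kl_gaussian(1.0, 0.5, 0.0, 0.5)
```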
Note that q(x_{t-1}|x_t, x_0) is a little tricky:
Conditional probability of a reverse step
Note that we only know q(x_t|x_{t-1}) and q(x_t|x_0), so we apply Bayes' rule:
\begin{align*}
q(x_{t-1} | x_t, x_0) &= \frac{q(x_{t-1}, x_t | x_0)}{q(x_t | x_0)} \\
&= \frac{q(x_t| x_{t-1}, x_0) q(x_{t-1}| x_0)}{q(x_t | x_0)} \\
&= \frac{q(x_t| x_{t-1}) q(x_{t-1}| x_0)}{q(x_t | x_0)} ~~~~\text{(Markov property)}\\
\end{align*}
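Since the denominator q(x_t | x_0) does not depend on x_{t-1}, the posterior is proportional to a product of two Gaussians in x_{t-1} and is therefore Gaussian itself. The sketch below checks numerically, in the scalar case, that the commonly quoted closed form for its mean and variance agrees with a direct precision-weighted combination. Function names and the numeric values are assumptions chosen for illustration.

```python
import math

def posterior_closed_form(x_t, x0, alpha_t, abar_t, abar_prev):
    """Mean and variance of q(x_{t-1} | x_t, x_0) from the standard closed form."""
    beta_t = 1.0 - alpha_t
    mean = (math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)) * x0 \
         + (math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)) * x_t
    var = beta_t * (1.0 - abar_prev) / (1.0 - abar_t)
    return mean, var

def posterior_by_bayes(x_t, x0, alpha_t, abar_prev):
    """Same quantity via the product of the two Gaussian factors in x_{t-1}."""
    beta_t = 1.0 - alpha_t
    # q(x_t | x_{t-1}) as a function of x_{t-1}:
    # centre x_t / sqrt(alpha_t), variance beta_t / alpha_t.
    lik_mean, lik_var = x_t / math.sqrt(alpha_t), beta_t / alpha_t
    # q(x_{t-1} | x_0): mean sqrt(abar_prev) * x0, variance 1 - abar_prev.
    pri_mean, pri_var = math.sqrt(abar_prev) * x0, 1.0 - abar_prev
    var = 1.0 / (1.0 / lik_var + 1.0 / pri_var)
    mean = var * (lik_mean / lik_var + pri_mean / pri_var)
    return mean, var

alpha_t, abar_prev = 0.98, 0.6
abar_t = alpha_t * abar_prev  # alpha_bar_t = alpha_t * alpha_bar_{t-1}
m1, v1 = posterior_closed_form(0.7, -1.2, alpha_t, abar_t, abar_prev)
m2, v2 = posterior_by_bayes(0.7, -1.2, alpha_t, abar_prev)
```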
Training Loss
Parameterization of the mean:
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta (x_t, t) )
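The variance is taken as a constant \sigma_t^2 \mathbf{I} rather than learned, so a reverse step samples:
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1-\alpha_t}{\sqrt{1 - \bar{\alpha}_t }} \epsilon_\theta (x_t, t) ) + \sigma_t z, ~~ z \sim \mathcal{N}(0, \mathbf{I})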
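One reverse step can be sketched as follows. This is a hedged illustration: `eps_model` is a hypothetical callable standing in for the trained network, the choice \sigma_t^2 = \beta_t is one common option assumed here, and skipping the noise at t = 1 is a widely used convention rather than something stated in the text.

```python
import math
import random

def reverse_step(x_t, t, eps_model, betas, alpha_bars, rng):
    """One sampling step x_t -> x_{t-1}; t is 1-indexed as in the text."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    abar_t = alpha_bars[t - 1]
    eps_hat = eps_model(x_t, t)  # hypothetical noise predictor
    coef = beta_t / math.sqrt(1.0 - abar_t)  # (1 - alpha_t) / sqrt(1 - abar_t)
    mean = [(x - coef * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps_hat)]
    if t == 1:
        return mean  # assumed convention: no noise on the final step
    sigma_t = math.sqrt(beta_t)  # assumed choice: sigma_t^2 = beta_t
    return [m + sigma_t * rng.gauss(0.0, 1.0) for m in mean]

# Toy check with a two-step schedule and an oracle that knows the true noise.
betas = [0.1, 0.2]
alpha_bars = [0.9, 0.9 * 0.8]
x0 = [1.0, -2.0]
true_eps = [0.5, -0.3]
x1 = [math.sqrt(0.9) * x + math.sqrt(0.1) * e for x, e in zip(x0, true_eps)]
oracle = lambda x_t, t: true_eps  # hypothetical perfect predictor
x0_rec = reverse_step(x1, 1, oracle, betas, alpha_bars, random.Random(0))
x1_sample = reverse_step([0.2, 0.4], 2, oracle, betas, alpha_bars, random.Random(0))
```

With a perfect noise prediction, the step at t = 1 recovers x_0 exactly, because \bar{\alpha}_1 = \alpha_1 makes the mean coincide with the inverted forward step.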
Simplified loss:
\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon} [\| \epsilon - \epsilon_\theta (\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, t) \|^2]
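A single-sample Monte Carlo estimate of this loss can be sketched as below. The names `simple_loss`, `oracle`, and `zero_model`, as well as the four-step \bar{\alpha} values, are assumptions for illustration; a real setup would use a neural network for `eps_model` and average over a batch.

```python
import math
import random

def simple_loss(x0, eps_model, alpha_bars, rng):
    """One-sample estimate of || eps - eps_model(x_t, t) ||^2 for a single x0."""
    t = rng.randint(1, len(alpha_bars))  # t ~ Uniform{1, ..., T}
    abar = alpha_bars[t - 1]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(abar) * x + math.sqrt(1.0 - abar) * e
           for x, e in zip(x0, eps)]
    eps_hat = eps_model(x_t, t)
    return sum((e - eh) ** 2 for e, eh in zip(eps, eps_hat))

alpha_bars = [0.9, 0.7, 0.4, 0.1]  # toy values, not from the text
x0 = [1.0, -0.5]

def oracle(x_t, t):
    """Hypothetical predictor that cheats by knowing x0; its loss should be ~0."""
    abar = alpha_bars[t - 1]
    return [(xt - math.sqrt(abar) * x) / math.sqrt(1.0 - abar)
            for xt, x in zip(x_t, x0)]

zero_model = lambda x_t, t: [0.0 for _ in x_t]  # always predicts zero noise

loss_oracle = simple_loss(x0, oracle, alpha_bars, random.Random(0))
loss_zero = simple_loss(x0, zero_model, alpha_bars, random.Random(0))
```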
Another view
Forward process:
x_t = \sqrt{\alpha_t} x_{t-1} + \sqrt{1 - \alpha_t} \epsilon, ~~\epsilon \sim \mathcal{N}(0, \mathbf{I})
Let's solve this for x_{t-1}:
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}( x_t - \sqrt{1 - \alpha_t} \epsilon), ~~\epsilon \sim \mathcal{N}(0, \mathbf{I})
Thus, the noise \epsilon is modeled by a network \epsilon_\Phi, and the decoder produces x_{t-1} as:
\text{dec}(x_t, t) = \frac{1}{\sqrt{\alpha_t}} (x_t - \sqrt{1 - \alpha_t} {\color{blue}\epsilon_\Phi(x_t, t)}) + s \delta, ~~ \delta \sim \mathcal{N}(0, I), s \in \mathbb{R}
Questions
Why do we have the last term \delta ?
This ensures the decoder is not a point estimate (i.e., deterministic). Without it, the decoder distribution has zero variance and the KL divergence to it goes to infinity.
Instead of a scalar s, some works learn a pixel- and channel-wise variance \tilde{s}(x_t, t).
This corresponds to the level of uncertainty in the decoder value \epsilon_\Phi(x_t, t).
The loss with this decoder, \|x_{t-1} - \text{dec}(x_t, t)\|^2, is problematic because it scales the gradients with respect to \epsilon_\Phi differently for different t.
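The t-dependent scaling can be made explicit by substituting the inverted forward step for x_{t-1}. As a simplifying assumption made here to isolate the effect, the s\delta term of the decoder is dropped:

```latex
\begin{align*}
x_{t-1} - \text{dec}(x_t, t) &= \frac{\sqrt{1 - \alpha_t}}{\sqrt{\alpha_t}} \big(\epsilon_\Phi(x_t, t) - \epsilon\big) \\
\|x_{t-1} - \text{dec}(x_t, t)\|^2 &= \frac{1 - \alpha_t}{\alpha_t} \|\epsilon - \epsilon_\Phi(x_t, t)\|^2
\end{align*}
```

The prefactor (1 - \alpha_t) / \alpha_t changes with t, so steps with small \beta_t contribute much weaker gradients than steps with large \beta_t. Dropping this weight gives the unweighted noise-prediction error.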