Diffusion Models

Forward Diffusion

Forward diffusion is the process of gradually adding noise to an input $x_0$.

$$
q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1 - \beta_t}\, x_{t-1},\ \beta_t \mathbf{I}\big)
$$

Instead of computing each $x_t$ one step at a time, we can derive a closed form for an arbitrary $t$. Let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$:

$$
q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1 - \bar{\alpha}_t) \mathbf{I}\big)
$$

Thus,

$$
x_t(x_0, \epsilon) = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})
$$

  • $\beta_t$ is predefined by a noise schedule; it typically increases gradually with $t$ (e.g., linearly), so that $\bar{\alpha}_t$ decreases toward zero

Some important questions and comments:

  • This gradually pulls the mean $\sqrt{\bar{\alpha}_t}\, x_0$ down to zero as $t$ grows
  • Why does it try to satisfy $c_1^2 + c_2^2 = 1$ (as in $x_t = c_1 x_0 + c_2 \epsilon$)?
    • This preserves variance. It ensures that $x_T$ is (approximately) an isotropic Gaussian, assuming $x_0$ has unit variance (see the numerical check after this list)
  • What happens when the input is the zero vector ($x_0 = \mathbf{0}$)?
  • What happens if we pretend that we always sample at the mean of the Gaussian ($\epsilon = 0$)?
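
As a small numerical check of the closed-form forward process (and of the variance-preservation question above), here is a minimal NumPy sketch. The linear schedule below is just one common choice, not something fixed by the derivation:

```python
import numpy as np

# One common (but not mandatory) choice: beta_t increases linearly with t.
T = 1000
betas = np.linspace(1e-4, 0.02, T)      # beta_t
alphas = 1.0 - betas                    # alpha_t = 1 - beta_t
alpha_bars = np.cumprod(alphas)         # \bar{alpha}_t = prod_{s<=t} alpha_s

def q_sample(x0, t, rng):
    """Sample x_t ~ q(x_t | x_0) in one shot using the closed form."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps

rng = np.random.default_rng(0)
x0 = rng.standard_normal(10_000)        # unit-variance "data"
for t in (0, T // 2, T - 1):
    xt = q_sample(x0, t, rng)
    # The mean shrinks toward 0 while the variance stays close to 1
    # (the c_1^2 + c_2^2 = 1 property).
    print(f"t={t:4d}  mean={xt.mean():+.3f}  var={xt.var():.3f}")
```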

Reverse Diffusion

We want to model the reverse diffusion process $q(x_{t-1} \mid x_t)$, which we approximate as:

$$
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big)
$$

:::message
Notice that this is exactly variational inference, where we approximate $q(x_{t-1} \mid x_t)$ by $p_\theta(x_{t-1} \mid x_t)$.
:::

If we can obtain $\mu_\theta$ and $\Sigma_\theta$, we are done. To train them, we minimize the negative log-likelihood through its variational upper bound:

$$
\begin{align*}
\mathcal{L} &= \mathbb{E}_{x_0 \sim Pop} [-\log p_\theta (x_0)] \\
&= \mathbb{E}_{x_0 \sim Pop} \Big[-\log \bigg( \int p_\theta (x_{0:T})\, dx_{1:T} \bigg) \Big] \\
&= \mathbb{E}_{x_0 \sim Pop} \Big[-\log \bigg( \int q(x_{1:T}|x_0) \frac{p_\theta (x_{0:T})}{q(x_{1:T}|x_0)}\, dx_{1:T} \bigg)\Big] ~~~~\text{(importance sampling)}\\
&= \mathbb{E}_{x_0 \sim Pop} \Big[{\color{brown}-\log} \bigg( {\color{blue}\mathbb{E}_{x_{1:T}\sim q(\cdot|x_0)}} \bigg[\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\bigg] \bigg)\Big] \\
&\leq \mathbb{E}_{x_0 \sim Pop} \Big[{\color{blue}\mathbb{E}_{x_{1:T}\sim q(\cdot|x_0)}} \bigg[{\color{brown} -\log} \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\bigg]\Big] ~~~~~ \text{(Jensen's inequality)}\\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0),\, x_0 \sim Pop} \bigg[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\bigg]\\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0),\, x_0 \sim Pop} \bigg[-\log \frac{{\color{brown}p_\theta(x_0|x_1) \cdots p_\theta(x_{T-1}|x_T)}\, p_\theta(x_T)}{{\color{brown}q(x_T|x_{T-1}) \cdots q(x_1|x_0)}}\bigg]\\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0),\, x_0 \sim Pop} \bigg[-\log p_\theta(x_T) -\log {\color{brown}\prod_{t=1}^T \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}} \bigg]\\
&= \cdots \\
&= \mathbb{E}_{x_{1:T}\sim q(\cdot|x_0),\, x_0 \sim Pop} \bigg[-\log \frac{p(x_T)}{q(x_T|x_0)} - \log \prod_{{\color{brown}t=2}}^T \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} - \log p_\theta(x_0| x_1) \bigg]\\
&= \underbrace{{\color{blue} \text{KL}\big(q(x_T|x_0)~||~p(x_T) \big)}}_{L_T~:~\text{constant}} + \sum_{t=2}^T \underbrace{\text{KL}\big(q(x_{t-1}| x_t, x_0) ~||~ p_\theta(x_{t-1}| x_t) \big)}_{L_{t-1}} + \underbrace{\mathbb{E}_{x_1 \sim q(\cdot| x_0),\, x_0 \sim Pop} \big[ -\log p_\theta(x_0 | x_1) \big]}_{L_0~:~\text{VAE}}\\
&= \sum_{t=2}^T \text{KL}\big(q(x_{t-1}| x_t, x_0) ~||~ p_\theta(x_{t-1}| x_t) \big) + \mathbb{E}_{x_1 \sim q(\cdot| x_0),\, x_0 \sim Pop} \big[ -\log p_\theta(x_0 | x_1) \big] + {\color{blue} \text{const}.}
\end{align*}
$$

Each term can be computed in closed form. Note that $q(x_{t-1} \mid x_t, x_0)$ is a little tricky; by Bayes' rule it is also a Gaussian:

$$
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t \mathbf{I}\big), \quad
\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t, \quad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t
$$
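
Since both distributions in each $L_{t-1}$ term are Gaussian, and (as in the next section) the variance of $p_\theta$ is fixed to $\sigma_t^2 \mathbf{I}$, each KL reduces to a squared difference of means up to a constant; this is the step that turns the bound into a regression problem:

$$
L_{t-1} = \mathbb{E}_q \left[ \frac{1}{2\sigma_t^2} \big\| \tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t) \big\|^2 \right] + \text{const}.
$$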

Training Loss

Parameterization of mean

$$
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta (x_t, t) \right)
$$

The variance is taken as a constant, $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$, so one reverse step samples:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta (x_t, t) \right) + \sigma_t z, \quad z \sim \mathcal{N}(0, \mathbf{I})
$$
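
A minimal sampling-loop sketch of this reverse step, assuming a trained noise predictor `eps_theta(x, t)` (a hypothetical callable) and the `betas` / `alphas` / `alpha_bars` arrays from the forward-process sketch above. Setting $\sigma_t^2 = \beta_t$ is one common choice for the constant variance:

```python
import numpy as np

def p_sample_loop(eps_theta, shape, betas, alphas, alpha_bars, seed=0):
    """Generate a sample by running the reverse chain from pure noise x_T."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)                         # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        z = rng.standard_normal(shape) if t > 0 else 0.0   # no noise on the final step
        mean = (x - (1 - alphas[t]) / np.sqrt(1 - alpha_bars[t]) * eps_theta(x, t)) \
               / np.sqrt(alphas[t])
        sigma_t = np.sqrt(betas[t])                        # sigma_t^2 = beta_t (one option)
        x = mean + sigma_t * z
    return x
```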

Simplified loss:

$$
\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon} \big[\| \epsilon - \epsilon_\theta (\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\ t) \|^2\big]
$$
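
A sketch of one training step for this simplified loss in PyTorch. Here `eps_theta` stands for any noise-prediction network that takes $(x_t, t)$; its architecture, and the name itself, are placeholders rather than a fixed API:

```python
import torch

def simple_diffusion_loss(eps_theta, x0, alpha_bars):
    """E_{t, x0, eps} || eps - eps_theta( sqrt(a_bar_t) x0 + sqrt(1 - a_bar_t) eps, t ) ||^2."""
    b = x0.shape[0]
    t = torch.randint(0, alpha_bars.shape[0], (b,), device=x0.device)  # t ~ Uniform
    a_bar = alpha_bars[t].view(b, *([1] * (x0.dim() - 1)))             # broadcast to data dims
    eps = torch.randn_like(x0)                                         # the noise to predict
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps                 # closed-form forward sample
    return ((eps - eps_theta(x_t, t)) ** 2).mean()

# Usage (hypothetical model / optimizer):
#   loss = simple_diffusion_loss(model, batch, alpha_bars)
#   loss.backward(); optimizer.step()
```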

Another view

Forward process:

$$
x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})
$$

Let's solve this for $x_{t-1}$:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \sqrt{1 - \alpha_t}\, \epsilon \right), \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})
$$

Thus, they model $\epsilon$ with a network (the decoder) and obtain $x_{t-1}$ as:

$$
\text{dec}(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \sqrt{1 - \alpha_t}\, {\color{blue}\epsilon_\Phi(x_t, t)} \right) + s \delta, \quad \delta \sim \mathcal{N}(0, \mathbf{I}),\ s \in \mathbb{R}
$$
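
A one-step sketch of this decoder, with `eps_phi` as the (hypothetical) learned noise predictor and `s` the scalar standard deviation; the point is simply that the output is stochastic because of the `s * delta` term:

```python
import numpy as np

def dec(eps_phi, x_t, t, alphas, s, rng=np.random.default_rng()):
    """Invert one forward step with the predicted noise, then add s * delta."""
    mean = (x_t - np.sqrt(1 - alphas[t]) * eps_phi(x_t, t)) / np.sqrt(alphas[t])
    delta = rng.standard_normal(np.shape(x_t))   # delta ~ N(0, I)
    return mean + s * delta                      # stochastic, not a point estimate
```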

:::message
Questions

  • Why do we have the last noise term $s\delta$?
    • This ensures the decoder is not a point estimate (i.e., deterministic). Without it, the KL divergence would go to infinity.
    • Instead of a scalar $s$, some work learns a pixel- and channel-wise variance $\tilde{s}(x_t, t)$.
      • This corresponds to the level of uncertainty in the decoder output $\epsilon_\Phi(x_t, t)$.
:::

The naive loss with this decoder, $\|x_{t-1} - \text{dec}(x_t, t)\|^2$, is problematic because it scales the gradients on $\epsilon$ differently for different $t$.
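
To make the scaling explicit, ignore the $s\delta$ term and plug the decoder into the loss (a rough calculation under this view, where the same $\epsilon$ appears in both $x_{t-1}$ and $x_t$):

$$
\|x_{t-1} - \text{dec}(x_t, t)\|^2 \approx \left\| \sqrt{\tfrac{1-\alpha_t}{\alpha_t}}\, \big(\epsilon - \epsilon_\Phi(x_t, t)\big) \right\|^2 = \frac{1-\alpha_t}{\alpha_t}\, \|\epsilon - \epsilon_\Phi(x_t, t)\|^2,
$$

so each timestep is weighted by $\frac{1-\alpha_t}{\alpha_t}$. The simplified loss above drops this $t$-dependent weight and treats all $t$ equally.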