Diffusion Models
Forward Diffusion
Forward diffusion is a process that gradually adds noise to an input $x_0$:
$$
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)
$$
Instead of computing each $x_t$ one step at a time, we can derive a closed form for arbitrary $t$.
Let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$:
$$
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\,I\right)
$$
Thus,
$$
x_t(x_0, \epsilon) = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
$$
- $\beta_t$ follows a predefined noise schedule; it typically increases gradually with $t$ (e.g., linearly in DDPM), so $\bar{\alpha}_t$ shrinks toward zero
Some important questions and comments:
- This gradually brings the mean down to zero, since $\sqrt{\bar{\alpha}_t} \to 0$ as $t$ grows
- Why does it satisfy $c_1^2 + c_2^2 = 1$ (writing $x_t = c_1 x_0 + c_2 \epsilon$)?
  - This preserves variance, which ensures that $x_T$ is an isotropic Gaussian, assuming $x_0$ has unit variance (see the sketch below)
- What happens when the input is the zero vector ($x_0 = 0$)?
- What happens if we pretend that we always sample the mean of the Gaussian ($\epsilon = 0$)?
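As a sanity check on the closed form and the variance-preservation point above, here is a minimal PyTorch sketch of sampling from $q(x_t \mid x_0)$. The linear schedule endpoints ($10^{-4}$ to $0.02$) and $T = 1000$ are assumed values in the spirit of DDPM, not something fixed by the text above.

```python
import torch

T = 1000
# Assumed linear schedule (DDPM-style); the text above leaves the schedule unspecified.
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # \bar{\alpha}_t = \prod_{s<=t} \alpha_s

def q_sample(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) in closed form: sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps."""
    eps = torch.randn_like(x0)
    a_bar = alpha_bars[t]
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps

# Variance preservation: if x0 has unit variance, x_t keeps (roughly) unit variance,
# and for large t the x0 term vanishes, leaving approximately N(0, I).
x0 = torch.randn(10_000)             # toy "data" with unit variance
print(q_sample(x0, t=10).var())      # ~1.0
print(q_sample(x0, t=T - 1).var())   # ~1.0
```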
Reverse Diffusion
We want to model the reverse diffusion process $q(x_{t-1} \mid x_t)$. We approximate it with:
$$
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)
$$
:::message
Notice that this is exactly variational inference, where we approximate $q(x_{t-1} \mid x_t)$ by $p_\theta(x_{t-1} \mid x_t)$.
:::
If we can obtain $\mu_\theta$ and $\Sigma_\theta$, we are done. We train them by minimizing an upper bound on the negative log-likelihood:
$$
\begin{aligned}
L &= \mathbb{E}_{x_0 \sim \mathrm{Pop}}\left[-\log p_\theta(x_0)\right] \\
  &= \mathbb{E}_{x_0 \sim \mathrm{Pop}}\left[-\log\left(\int p_\theta(x_{0:T})\,dx_{1:T}\right)\right] \\
  &= \mathbb{E}_{x_0 \sim \mathrm{Pop}}\left[-\log\left(\int q(x_{1:T} \mid x_0)\,\frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\,dx_{1:T}\right)\right] \quad \text{(importance sampling)} \\
  &= \mathbb{E}_{x_0 \sim \mathrm{Pop}}\left[-\log\left(\mathbb{E}_{x_{1:T} \sim q(\cdot \mid x_0)}\left[\frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right]\right)\right] \\
  &\le \mathbb{E}_{x_0 \sim \mathrm{Pop}}\left[\mathbb{E}_{x_{1:T} \sim q(\cdot \mid x_0)}\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right]\right] \quad \text{(Jensen's inequality)} \\
  &= \mathbb{E}_{x_{1:T} \sim q(\cdot \mid x_0),\,x_0 \sim \mathrm{Pop}}\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] \\
  &= \mathbb{E}_{x_{1:T} \sim q(\cdot \mid x_0),\,x_0 \sim \mathrm{Pop}}\left[-\log \frac{p_\theta(x_T)\,p_\theta(x_{T-1} \mid x_T)\cdots p_\theta(x_0 \mid x_1)}{q(x_T \mid x_{T-1})\cdots q(x_1 \mid x_0)}\right] \\
  &= \mathbb{E}_{x_{1:T} \sim q(\cdot \mid x_0),\,x_0 \sim \mathrm{Pop}}\left[-\log p_\theta(x_T) - \log \prod_{t=1}^{T} \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1})}\right] \\
  &= \cdots \\
  &= \mathbb{E}_{x_{1:T} \sim q(\cdot \mid x_0),\,x_0 \sim \mathrm{Pop}}\left[-\log \frac{p(x_T)}{q(x_T \mid x_0)} - \log \prod_{t=2}^{T} \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)} - \log p_\theta(x_0 \mid x_1)\right] \\
  &= \underbrace{\mathrm{KL}\!\left(q(x_T \mid x_0)\,\|\,p(x_T)\right)}_{L_T:\ \text{constant}} + \sum_{t=2}^{T} \underbrace{\mathrm{KL}\!\left(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1} \mid x_t)\right)}_{L_{t-1}} + \underbrace{\mathbb{E}_{x_1 \sim q(\cdot \mid x_0),\,x_0 \sim \mathrm{Pop}}\left[-\log p_\theta(x_0 \mid x_1)\right]}_{L_0:\ \text{VAE}} \\
  &= \sum_{t=2}^{T} \mathrm{KL}\!\left(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1} \mid x_t)\right) + \mathbb{E}_{x_1 \sim q(\cdot \mid x_0),\,x_0 \sim \mathrm{Pop}}\left[-\log p_\theta(x_0 \mid x_1)\right] + \text{const.}
\end{aligned}
$$
Each term can be computed in closed form, since every distribution involved is Gaussian.
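In particular, when the two Gaussians in a KL term are taken to share the same fixed isotropic covariance $\sigma_t^2 I$ (as in the constant-variance choice used later), the KL reduces to a squared distance between means:

$$
\mathrm{KL}\!\left(\mathcal{N}(\mu_q, \sigma_t^2 I)\,\|\,\mathcal{N}(\mu_\theta, \sigma_t^2 I)\right) = \frac{1}{2\sigma_t^2}\,\|\mu_q - \mu_\theta\|^2
$$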
Note that $q(x_{t-1} \mid x_t, x_0)$ is a little tricky, since it conditions on $x_0$ as well as $x_t$.
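For reference, applying Bayes' rule to the Gaussian forward transitions gives the posterior below (this is the standard form derived in the DDPM paper):

$$
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\right),
\qquad
\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t,
\qquad
\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t
$$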
Training Loss
Parameterization of the mean:
$$
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)
$$
The variance is taken to be a constant $\sigma_t^2$, so one reverse sampling step becomes:
$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \quad z \sim \mathcal{N}(0, I)
$$
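A sketch of one reverse step with the constant-variance parameterization above, reusing the `betas` / `alphas` / `alpha_bars` schedule from the earlier sketch. Here `eps_model` stands in for the learned network $\epsilon_\theta$, and $\sigma_t^2 = \beta_t$ is an assumed choice (the posterior variance $\tilde{\beta}_t$ is another common one).

```python
import torch

@torch.no_grad()
def p_sample_step(eps_model, x_t: torch.Tensor, t: int) -> torch.Tensor:
    """One reverse step x_t -> x_{t-1} with a fixed variance sigma_t^2 (here beta_t)."""
    beta_t, alpha_t, alpha_bar_t = betas[t], alphas[t], alpha_bars[t]

    eps_hat = eps_model(x_t, t)  # predicted noise eps_theta(x_t, t)
    # Note beta_t = 1 - alpha_t, matching the mean parameterization above.
    mean = (x_t - beta_t / (1.0 - alpha_bar_t).sqrt() * eps_hat) / alpha_t.sqrt()

    if t == 0:
        return mean              # final step: return the mean, no extra noise
    z = torch.randn_like(x_t)
    return mean + beta_t.sqrt() * z
```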
Simplified loss:
$$
L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\|^2\right]
$$
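A sketch of the simplified objective as a training loss, again reusing the schedule from the first sketch. The `eps_model(x_t, t)` signature and the uniform sampling of $t$ are assumptions following the usual recipe; `F.mse_loss` averages over elements, which only rescales the objective by a constant.

```python
import torch
import torch.nn.functional as F

def simple_loss(eps_model, x0: torch.Tensor) -> torch.Tensor:
    """E_{t, x0, eps} || eps - eps_theta(sqrt(a_bar_t) x0 + sqrt(1 - a_bar_t) eps, t) ||^2"""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,))                            # one uniform timestep per example
    a_bar = alpha_bars[t].view(b, *([1] * (x0.dim() - 1)))   # broadcast over non-batch dims
    eps = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps     # forward process in closed form
    return F.mse_loss(eps_model(x_t, t), eps)
```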
Another view
Forward process:
$$
x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
$$
Let's solve this for $x_{t-1}$:
$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \sqrt{1-\alpha_t}\,\epsilon\right), \quad \epsilon \sim \mathcal{N}(0, I)
$$
Thus, they model $\epsilon$ with a decoder network and get $x_{t-1}$ as:
$$
\mathrm{dec}(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \sqrt{1-\alpha_t}\,\epsilon_\Phi(x_t, t)\right) + s\delta, \quad \delta \sim \mathcal{N}(0, I),\ s \in \mathbb{R}
$$
:::message
Questions
- Why do we have the last term $s\delta$?
  - This is to make sure the decoder is not a point estimate (i.e., deterministic). Without it, the KL divergence goes to infinity.
  - Instead of a scalar $s$, some work learns a pixel- and channel-wise variance $\tilde{s}(x_t, t)$
    - This corresponds to the level of uncertainty in the decoder output $\epsilon_\Phi(x_t, t)$
:::
The loss with this decoder, $\|x_{t-1} - \mathrm{dec}(x_t, t)\|^2$, is problematic because it scales the gradients on $\epsilon$ differently for different $t$.
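To see the scaling explicitly under the notation above, substitute $x_{t-1}$ and $\mathrm{dec}(x_t, t)$ and ignore the injected noise $s\delta$; the residual is the noise-prediction error weighted by a $t$-dependent factor:

$$
x_{t-1} - \mathrm{dec}(x_t, t) = \frac{\sqrt{1-\alpha_t}}{\sqrt{\alpha_t}}\left(\epsilon_\Phi(x_t, t) - \epsilon\right) - s\delta
\quad\Rightarrow\quad
\left\|x_{t-1} - \mathrm{dec}(x_t, t)\right\|^2 \approx \frac{1-\alpha_t}{\alpha_t}\left\|\epsilon - \epsilon_\Phi(x_t, t)\right\|^2,
$$

which is why the simplified loss above trains on $\|\epsilon - \epsilon_\theta\|^2$ directly, without the $t$-dependent weight.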