Variational Inference

References

David Blei has many good write-ups / talks on this topic

Others

Intro

Assumption:

  • $x = x_{1:n}$: observations
  • $z = z_{1:m}$: hidden variables

We want to model the posterior distribution (notice that this is a form of inference: estimating a hidden variable from observations):

$$
p(z|x) = \frac{p(z, x)}{\int_z p(z, x)\, dz}
$$

The posterior links the data and the model. In most interesting problems, calculating the denominator is intractable. $x$ is the evidence about $z$.

The main idea: pick a family of distributions over the latent variables, parameterized by its own variational parameters:

$$
q(z_{1:m} | \nu)
$$

then find the setting of $\nu$ that makes $q$ closest to the posterior of interest.

![concept](/media/posts/vi/vi-concept.png =400x)

The closeness can be measured by Kullback-Leibler (KL) divergence:

$$
\textrm{KL}(q\|p) = E_{Z\sim q} \bigg[ \log \frac{q(Z)}{p(Z|x)} \bigg]
$$

We use the arguments ($q$ and $p$) in this order specifically so that the expectation is taken over $q$. If you flip the order (i.e., $\textrm{KL}(p\|q)$), you get the objective minimized by expectation propagation, a different kind of variational inference that is generally more computationally expensive.
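
As a quick numerical illustration of the asymmetry (my own toy example, not from Blei's notes): estimate both orderings by Monte Carlo for two arbitrary 1-D Gaussians standing in for $q$ and $p(z|x)$.

```python
# Monte Carlo estimates of KL(q||p) and KL(p||q) for two arbitrary 1-D Gaussians.
# The particular choices of p and q are illustrative assumptions.
import numpy as np
from scipy.stats import norm

p = norm(loc=0.0, scale=1.0)   # stand-in for the target posterior p(z|x)
q = norm(loc=1.0, scale=0.5)   # variational approximation q(z)

def kl_mc(a, b, n=200_000, seed=0):
    """KL(a||b) = E_{z~a}[log a(z) - log b(z)], estimated with samples from a."""
    z = a.rvs(size=n, random_state=seed)
    return np.mean(a.logpdf(z) - b.logpdf(z))

print("KL(q||p) ~", kl_mc(q, p))   # ~0.82: the ordering used here (expectation over q)
print("KL(p||q) ~", kl_mc(p, q))   # ~2.81: the ordering targeted by expectation propagation
```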

We cannot minimize this KL divergence directly, because evaluating it requires the intractable evidence $\log p(x)$. But we can maximize a quantity that equals the negative KL divergence up to a constant: the ELBO.

The Evidence Lower Bound (ELBO)

$$
\begin{align*}
\textrm{KL}\big(q(z)\|p(z|x)\big) &= E_q \bigg[ \log \frac{q(Z)}{p(Z|x)} \bigg] \\
&= E_q [\log q(Z)] - E_q[\log p(Z|x)] \\
&= E_q [\log q(Z)] - E_q[\log p(Z, x) - \log p(x)] \\
&= E_q [\log q(Z)] - E_q[\log p(Z, x)] + \log p(x) \\
&= - \big(\underline{E_q[\log p(Z, x)] - E_q [\log q(Z)]}\big) + \log p(x)
\end{align*}
$$

Notes:

  • The last term $\log p(x)$ does not depend on $q$, thus:
  • minimizing the KL divergence is equivalent to maximizing the ELBO (the underlined term); see the numerical check below
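
A tiny discrete example (mine, with made-up numbers) makes the identity $\log p(x) = \text{ELBO} + \textrm{KL}\big(q(z)\|p(z|x)\big)$ concrete:

```python
# Numerical check of  log p(x) = ELBO + KL(q(z) || p(z|x))  for a small discrete model.
# The joint table p(z, x) and the variational q(z) are made up.
import numpy as np

p_zx = np.array([0.10, 0.25, 0.05])      # p(z, x) for the single observed x, z in {0,1,2}
p_x = p_zx.sum()                          # evidence p(x)
p_z_given_x = p_zx / p_x                  # exact posterior

q = np.array([0.5, 0.3, 0.2])             # some variational distribution over z

elbo = np.sum(q * (np.log(p_zx) - np.log(q)))          # E_q[log p(z,x)] - E_q[log q(z)]
kl   = np.sum(q * (np.log(q) - np.log(p_z_given_x)))   # KL(q || p(z|x))

print(elbo + kl, np.log(p_x))             # the two numbers agree
assert np.isclose(elbo + kl, np.log(p_x))
```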

:::message
ELBO derivation using Jensen's inequality:

$$
\begin{align}
\log p(x) &= \log \int_z p(x, z)\, dz \\
&= \log \int_z q(z) \frac{p(x, z)}{q(z)}\, dz \\
&= \log\bigg( E_q\bigg[\frac{p(x, Z)}{q(Z)}\bigg] \bigg) \\
&\geq \underline{E_q [\log p(x, Z)] - E_q[\log q(Z)]} ~~~(\because \text{Jensen's inequality}).
\end{align}
$$

Another derivation: forcibly extract the KL divergence.

$$
\begin{align}
\log p(x) &= E_q \big[\log p(x)\big] \\
&= E_q \bigg[\log \bigg\{ {\color{blue} p(x)} \cdot \frac{ {\color{blue} p(z|x)}\, q(z)}{p(z|x)\, {\color{green} q(z)}}\bigg\}\bigg] ~~~(\text{multiply by } 1 \text{ to create a KL term})\\
&= E_q \bigg[\log \frac{ {\color{blue} p(x, z)} }{ {\color{green} q(z)} } + \log \frac{q(z)}{p(z|x)}\bigg] \\
&= E_q [\log p(x, z) - \log q(z)] + E_q \bigg[\log \frac{q(z)}{p(z|x)}\bigg] \\
&= \underline{E_q [\log p(x, z) - \log q(z)]} + \textrm{KL}\big( q(z)\| p(z|x) \big) \\
&\geq \underline{E_q [\log p(x, z) - \log q(z)]} ~~~(\because \text{KL divergence is non-negative}).
\end{align}
$$

The left-hand side is the (log) evidence, hence the name: Evidence Lower BOund.
:::

The ELBO and the KL divergence sum to the log normalizer $\log p(x)$ (the evidence), which is exactly the quantity the ELBO bounds from below.

Variational Autoencoder (pretty much the same thing)

Latent variable models:

$$
\begin{align*}
P_{\Phi, \Theta}(z, x) &= P_\Phi(z) P_\Theta(x|z) \\
P_{\Phi, \Theta}(z | x) &= \frac{P_{\Phi, \Theta}(z, x)}{\int_z P_{\Phi, \Theta}(z, x)\, dz}
\end{align*}
$$

We have a data population, and we want to estimate $\Phi$ and $\Theta$ from it:

$$
\Phi^*, \Theta^* = \textrm{argmin}_{\Phi, \Theta}\; E_{x \sim \text{Pop}} \big[ -\log P_{\Phi, \Theta}(x) \big]
$$

The problem is that we can't typically compute $P_{\Phi, \Theta}(x)$.

  • $P_{\Phi, \Theta}(x) = \int_z P_\Phi(z) P_\Theta(x|z)\, dz$ doesn't work: the integral (or sum) is over a space far too large, and naive Monte Carlo samples from the prior rarely land on latent values that explain $x$
  • The same integral estimated with importance sampling, using $P_{\Phi, \Theta}(z|x)$ as the proposal, would be a much better idea, but it doesn't work either: the exact posterior contains the very normalizer $P_{\Phi, \Theta}(x)$ we are trying to compute (see the toy comparison below)
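
The toy comparison mentioned above, using a linear-Gaussian model I picked because everything is available in closed form: naive Monte Carlo from the prior is extremely noisy, while importance sampling with the exact posterior is essentially exact. The catch is that in real models that posterior is unavailable.

```python
# Toy linear-Gaussian latent variable model where p(x) and p(z|x) are tractable,
# so we can compare the two estimators of p(x). (My own illustration; numbers are arbitrary.)
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
sigma = 0.1        # decoder noise:  x | z ~ N(z, sigma^2);  prior:  z ~ N(0, 1)
x_obs = 3.0        # an observation out in the tail

exact = norm(0.0, np.sqrt(1 + sigma**2)).pdf(x_obs)          # closed-form p(x)

# 1) naive Monte Carlo:  p(x) ~= mean_i p(x | z_i)  with  z_i ~ prior
z = rng.normal(0.0, 1.0, size=1_000)
naive = norm(z, sigma).pdf(x_obs).mean()                     # huge relative variance:
                                                             # almost no z_i lands near x_obs

# 2) importance sampling with the *exact* posterior as proposal
post = norm(x_obs / (1 + sigma**2), sigma / np.sqrt(1 + sigma**2))
z = post.rvs(size=1_000, random_state=1)
weights = norm(0, 1).pdf(z) * norm(z, sigma).pdf(x_obs) / post.pdf(z)
importance = weights.mean()                                  # essentially exact

print(exact, naive, importance)
# In a real model we cannot sample from p(z|x) -- that is what
# the variational approximation stands in for.
```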

Variational Bayes sidesteps this with a model $P_\Psi(z|x)$ that approximates $P_{\Phi, \Theta}(z|x)$.

The ELBO:

$$
\begin{align*}
\log P_{\Phi, \Theta}(x) &\geq E_{z \sim P_{\Psi}} \big[ \log P_{\Phi, \Theta}(z, x) \big] - E_{z \sim P_{\Psi}} \big[ \log P_\Psi(z|x) \big] \\
&= E_{z \sim P_{\Psi}} \big[ \log P_{\Theta}(x|z)P_{\Phi}(z) \big] - E_{z \sim P_{\Psi}} \big[ \log P_\Psi(z|x) \big] \\
&= E_{z \sim P_{\Psi}} \bigg[ - \bigg( \log \frac{P_\Psi(z|x)}{P_{\Phi}(z)} - \log P_{\Theta}(x|z) \bigg) \bigg]
\end{align*}
$$

Thus,

$$
\Phi^*, \Theta^*, \Psi^* = \textrm{argmin}~E_{x \sim \text{Pop},~z \sim P_\Psi} \bigg[ \log \frac{P_\Psi(z|x)}{P_{\Phi}(z)} - \log P_{\Theta}(x|z) \bigg]
$$

Minor but important: we can't do gradient descent w.r.t. $\Psi$ directly because of the sampling step $z \sim P_\Psi(z|x)$. The reparameterization trick circumvents this, as sketched below.
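
A minimal sketch of the trick, assuming a diagonal-Gaussian encoder (the notes don't commit to a particular family): all the randomness goes into $\epsilon \sim N(0, I)$, and $z$ becomes a deterministic, differentiable function of the encoder outputs.

```python
# Reparameterization: z = mu + sigma * eps with eps ~ N(0, I),
# so gradients flow through mu and log_var (i.e., through Psi).
import torch

def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    eps = torch.randn_like(mu)                    # randomness, independent of Psi
    return mu + torch.exp(0.5 * log_var) * eps    # differentiable in (mu, log_var)

# gradient check: gradients reach the encoder outputs through the sample
mu = torch.zeros(4, requires_grad=True)
log_var = torch.zeros(4, requires_grad=True)
reparameterize(mu, log_var).sum().backward()
assert mu.grad is not None and log_var.grad is not None
```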

  • $P_\Phi(z)$: the prior
  • $P_\Psi(z|x)$: the encoder
  • $P_\Theta(x|z)$: the decoder
  • $E[\log P_\Psi(z|x)/P_\Phi(z)]$: the rate term, a KL divergence
  • $E[-\log P_\Theta(x|z)]$: the distortion term, a conditional entropy (both terms appear in the sketch below)
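
Putting the pieces together, a sketch of the objective above for one common set of choices (standard-normal prior, diagonal-Gaussian encoder, Bernoulli decoder; the architecture sizes are placeholders I made up, not from the notes): the rate term has a closed form, and the distortion term is the reconstruction negative log-likelihood.

```python
# Sketch of the VAE objective  E[ log P_psi(z|x)/P_phi(z) - log P_theta(x|z) ],
# assuming a standard-normal prior, diagonal-Gaussian encoder and Bernoulli decoder.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=16, h=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h), nn.ReLU(), nn.Linear(h, 2 * z_dim))
        self.dec = nn.Sequential(nn.Linear(z_dim, h), nn.ReLU(), nn.Linear(h, x_dim))

    def loss(self, x):
        mu, log_var = self.enc(x).chunk(2, dim=-1)                  # encoder P_psi(z|x)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)    # reparameterized sample

        # rate: KL( N(mu, sigma^2) || N(0, I) ), available in closed form
        rate = 0.5 * (mu.pow(2) + log_var.exp() - 1.0 - log_var).sum(-1)

        # distortion: -log P_theta(x|z) for a Bernoulli decoder over values in [0, 1]
        logits = self.dec(z)
        distortion = F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)

        return (rate + distortion).mean()    # minimizing this maximizes the ELBO

# usage (random data just to show the call)
vae = TinyVAE()
x = torch.rand(8, 784)
vae.loss(x).backward()
```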

:::message
A few more points covered in TTIC31230:

  • The EM (Expectation-Maximization) algorithm is in fact a specific instantiation of the VAE! EM corresponds to minimizing the VAE objective in alternation (a small worked example follows this box):
    • first w.r.t. the encoder ($\Psi$): the inference step -- E step
    • then w.r.t. $\Phi$ and $\Theta$, while fixing $\Psi$: the update step -- M step
  • The VAE is exactly the same as the Rate-Distortion Autoencoder (RDA)
:::
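
The worked example promised above: EM for a 1-D mixture of two unit-variance Gaussians (a textbook case I'm using for illustration, not taken from the course). The E step computes the exact posterior over the latent component, playing the encoder's role; the M step updates the generative parameters with that posterior held fixed.

```python
# EM for a 2-component 1-D Gaussian mixture with unit variances.
# E step = exact inference over the latent component z (the "encoder" role),
# M step = update the generative parameters (means and mixing weight).
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 300), rng.normal(3, 1, 700)])  # synthetic data

w, mu = 0.5, np.array([-1.0, 1.0])          # initial parameters
for _ in range(50):
    # E step: responsibilities r[i, k] = P(z = k | x_i) under current parameters
    like = np.stack([(1 - w) * norm(mu[0], 1).pdf(x),
                     w       * norm(mu[1], 1).pdf(x)], axis=1)
    r = like / like.sum(axis=1, keepdims=True)

    # M step: maximize E_{z~r}[log p(z, x)] w.r.t. the parameters, with r held fixed
    mu = (r * x[:, None]).sum(axis=0) / r.sum(axis=0)
    w = r[:, 1].mean()

print(mu, w)    # close to (-2, 3) and 0.7
```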

Rate-Distortion Autoencoders (mathematically the same as VAE)

Setting: image compression, where an image $x$ is compressed to $z$.

We assume a stochastic compression algorithm (encoder): $P_\text{enc}(z|x)$

  • $H(z)$: the number of bits needed for the compressed file. This is the rate (bits/image) for transmitting compressed images (a toy bits calculation follows this list)
    • This is modeled with a prior model $P_\text{pri}(z)$
  • $H(x|z)$: the number of additional bits needed to exactly recover $x$. This is a measure of the distortion of $x$
    • This is modeled with a decoder model $P_\text{dec}(x|z)$
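
A toy computation of the rate term in bits (the prior table and the codes are made-up values, just to pin down the units): the average cost of transmitting $z$ under the prior model is $E[-\log_2 P_\text{pri}(z)]$ bits/image, and the distortion would be the analogous $E[-\log_2 P_\text{dec}(x|z)]$.

```python
# Average bits/image needed to transmit codes z under a prior model P_pri(z).
# The prior table and the codes are made-up illustrative values.
import numpy as np

p_pri = np.array([0.5, 0.25, 0.125, 0.125])   # prior model over 4 possible codes
codes = np.array([0, 0, 1, 2, 3, 0, 1, 1])    # codes the encoder produced for 8 images

rate_bits = -np.log2(p_pri[codes]).mean()     # cross-entropy code length, ~1.875 bits/image
print(rate_bits)
```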