Takuma Yoneda

Variational Inference

References

David Blei has many good write-ups / talks on this topic


Intro

Assumption:

  • x = x_{1:n}: observations
  • z = z_{1:m}: hidden variables

We want to model the posterior distribution (notice that this is a form of inference: estimating a hidden variable from observations):

p(z|x) = \frac{p(z, x)}{\int_z p(z, x) \, dz}

The posterior links the data and the model.
In most interesting problems, computing the denominator (the evidence) is intractable.
Here, x is the evidence about z.

The main idea: pick a family of distributions over the latent variables, parameterized by its own variational parameters:

q(z_{1:m}| \nu)

then find a setting of \nu which makes it closest to the posterior of interest.
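
For concreteness, a common choice (a standard example, not something prescribed by the text above) is a fully factorized "mean-field" Gaussian family, where \nu is simply a mean and a log-standard-deviation per latent dimension. A minimal numpy sketch, with all names my own:

```python
import numpy as np

class MeanFieldGaussian:
    """q(z_{1:m} | nu) = prod_i N(z_i | mu_i, sigma_i^2).

    The variational parameters nu are (mu, log_sigma), one pair per latent dimension.
    """

    def __init__(self, m, seed=0):
        self.mu = np.zeros(m)
        self.log_sigma = np.zeros(m)
        self.rng = np.random.default_rng(seed)

    def sample(self, n):
        sigma = np.exp(self.log_sigma)
        return self.mu + sigma * self.rng.standard_normal((n, len(self.mu)))

    def log_prob(self, z):
        sigma = np.exp(self.log_sigma)
        return np.sum(
            -0.5 * ((z - self.mu) / sigma) ** 2 - self.log_sigma - 0.5 * np.log(2 * np.pi),
            axis=-1,
        )
```

Fitting q then amounts to adjusting mu and log_sigma so that this distribution gets as close as possible to p(z|x).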


The closeness can be measured by Kullback-Leibler (KL) divergence:

\textrm{KL}(q\|p) = E_{Z\sim q} \big[ \log \frac{q(Z)}{p(Z|x)} \big]

We use the arguments (q and p) in this order specifically so that the expectation is taken over q. If you flip the order (i.e., \textrm{KL}(p\|q)), you get expectation propagation, a different kind of variational inference that is generally more computationally expensive.

We cannot minimize this KL divergence directly, because evaluating it requires the posterior p(z|x) itself (and hence the intractable evidence p(x)).
But we can instead maximize a quantity that equals the negative KL divergence up to a constant: the ELBO.

The Evidence Lower Bound (ELBO)

\begin{align*} \textrm{KL}\big(q(z)\|p(z|x) \big) &= E_q \bigg[ \log \frac{q(Z)}{p(Z|x)} \bigg] \\ &= E_q [\log q(Z)] - E_q[\log p(Z|x)] \\ &= E_q [\log q(Z)] - E_q[\log p(Z, x) - \log p(x)] \\ &= E_q [\log q(Z)] - E_q[\log p(Z, x)] + \log p(x) \\ &= - \underbrace{\big( E_q[\log p(Z, x)] - E_q [\log q(Z)] \big)}_{\textrm{ELBO}(q)} + \log p(x) \end{align*}

Notes:

  • The last term \log p(x) does not depend on q, thus:
  • Minimizing the KL divergence is equivalent to maximizing the ELBO

Rearranging, \log p(x) = \textrm{ELBO}(q) + \textrm{KL}(q\|p): the difference between the log normalizer and the ELBO is the KL divergence. Since the KL divergence is non-negative, the ELBO is a lower bound on the log normalizer (the evidence) --- hence the name.
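
To make the bound concrete, here is a toy sketch of my own (not from the text): take p(z) = \mathcal{N}(0, 1) and p(x|z) = \mathcal{N}(z, 1), for which \log p(x) is available in closed form (marginally x \sim \mathcal{N}(0, 2), and the posterior is \mathcal{N}(x/2, 1/2)). A Monte Carlo estimate of the ELBO stays below \log p(x) for any q, and the gap (the KL divergence) vanishes when q is the exact posterior.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = 1.3  # a single observation

def elbo(q_mu, q_sigma, n=100_000):
    """Monte Carlo estimate of E_q[log p(z, x)] - E_q[log q(z)]."""
    z = rng.normal(q_mu, q_sigma, size=n)
    log_joint = norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)  # log p(z) + log p(x|z)
    log_q = norm.logpdf(z, q_mu, q_sigma)
    return np.mean(log_joint - log_q)

log_evidence = norm.logpdf(x, 0.0, np.sqrt(2.0))      # closed-form log p(x)

print(elbo(0.0, 1.0), "<=", log_evidence)             # a mismatched q: strictly below
print(elbo(x / 2, np.sqrt(0.5)), "==", log_evidence)  # q = exact posterior: the gap (KL) is zero
```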

Variational Autoencoder (pretty much the same thing)

Latent variable models:

\begin{align*} P_{\Phi, \Theta}(z, x) &= P_\Phi(z) P_\Theta(x|z) \\ P_{\Phi, \Theta}(z | x) &= \frac{P_{\Phi, \Theta}(z, x)}{\int_z P_{\Phi, \Theta}(z, x) \, dz} \end{align*}

We have a data population, so we want to estimate \Phi and \Theta from it:

\Phi^*, \Theta^* = \textrm{argmin}_{\Phi, \Theta}~E_{x \sim Pop} \big[ - \log P_{\Phi, \Theta}(x) \big]

The problem is that we can't typically compute P_{\Phi, \Theta}(x).

  • Computing P_{\Phi, \Theta}(x) = \int_z P_\Phi (z) P_\Theta (x|z) dz directly doesn't work: the integral (a huge sum over all configurations of z) is intractable to evaluate
  • Estimating the same integral with importance sampling would work best with P_{\Phi, \Theta}(z|x) as the proposal, but that proposal is exactly the intractable posterior we cannot compute or sample from

Variational Bayes sidesteps this with a model P_\Psi(z|x) that approximates P_{\Phi, \Theta}(z|x).

The ELBO:

\begin{align*} \log P_{\Phi, \Theta}(x) &\geq E_{z \sim P_{\Psi}} \big[ \log P_{\Phi, \Theta}(z, x) \big] - E_{z \sim P_{\Psi}} \big[ \log P_\Psi(z|x) \big] \\ &= E_{z \sim P_{\Psi}} \big[ \log P_{\Theta}(x|z)P_{\Phi}(z) \big] - E_{z \sim P_{\Psi}} \big[ \log P_\Psi(z|x) \big] \\ &= E_{z \sim P_{\Psi}} \bigg[ - \bigg( \log \frac{P_\Psi(z|x)}{P_{\Phi}(z)} - \log P_{\Theta}(x|z) \bigg) \bigg] \end{align*}

Thus,

\Phi^*, \Theta^*, \Psi^* = \textrm{argmin}_{\Phi, \Theta, \Psi}~E_{x \sim Pop,~z \sim P_\Psi} \big[ \log \frac{P_\Psi(z|x)}{P_{\Phi}(z)} - \log P_{\Theta}(x|z) \big]

Minor but important: we can't naively take gradients w.r.t. \Psi, because the objective involves a sampling procedure z \sim P_\Psi(z|x). The reparameterization trick circumvents this.
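
A minimal PyTorch illustration of the trick (the toy numbers are mine): instead of sampling z \sim \mathcal{N}(\mu, \sigma^2) directly, sample \epsilon \sim \mathcal{N}(0, 1) and set z = \mu + \sigma \epsilon, so the sampling step becomes a deterministic, differentiable function of \mu and \sigma.

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)
log_sigma = torch.tensor(-1.0, requires_grad=True)

eps = torch.randn(1000)                      # noise sampled outside the computation graph
z = mu + torch.exp(log_sigma) * eps          # z ~ N(mu, sigma^2), but differentiable in mu, sigma

loss = (z ** 2).mean()                       # any loss that depends on the samples z
loss.backward()
print(mu.grad, log_sigma.grad)               # gradients flow through the "sampling" step
```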

  • P_\Phi(z): the prior
  • P_\Psi(z|x): the encoder
  • P_\Theta(x|z): the decoder
  • E[\log P_\Psi(z|x)/P_\Phi(z)]: the rate term (a KL divergence)
  • E[- \log P_\Theta (x | z)]: the distortion term (a conditional entropy)
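
Putting the pieces together, here is a hedged PyTorch sketch of the resulting training loss (a generic Gaussian-encoder / Bernoulli-decoder VAE; the module names, sizes, and distribution choices are my own, not from the text). The loss is the rate term, which has a closed form when the encoder is Gaussian and the prior is a standard normal, plus the distortion term, the decoder's negative log-likelihood:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyVAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=16, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * z_dim))
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, x_dim))

    def loss(self, x):
        mu, log_sigma = self.enc(x).chunk(2, dim=-1)           # P_Psi(z|x) = N(mu, sigma^2)
        z = mu + torch.exp(log_sigma) * torch.randn_like(mu)   # reparameterized sample
        logits = self.dec(z)                                   # P_Theta(x|z) = Bernoulli(logits)

        # Rate term: KL( N(mu, sigma^2) || N(0, I) ), available in closed form.
        rate = 0.5 * (mu ** 2 + torch.exp(2 * log_sigma) - 2 * log_sigma - 1).sum(-1)
        # Distortion term: -log P_Theta(x|z), i.e. the reconstruction negative log-likelihood.
        distortion = F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)
        return (rate + distortion).mean()                      # negative ELBO, averaged over the batch

vae = ToyVAE()
x = torch.rand(32, 784)   # stand-in batch of "images" in [0, 1]
print(vae.loss(x))
```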

Rate-Distortion Autoencoders (mathematically the same as VAE)

Setting: Image compression where an image x is compressed to z.

We assume a stochastic compression algorithm (encoder): P_\text{enc}(z|x)

  • H(z): The number of bits needed for the compressed file. This is the rate (bits / image) for transmitting compressed images
    • This is modeled with a prior model P_\text{pri}(z)
  • H(x|z): The number of additional bits needed to exactly recover x. This is a measure of the distortion of x
    • This is modeled with a decoder model P_\text{dec}(x|z)
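
Under this reading, both quantities can be estimated from encoder samples by averaging negative log-probabilities and converting nats to bits; a small sketch of my own (function and variable names are placeholders, not from the text):

```python
import math
import torch

def rate_and_distortion_bits(log_p_pri_z, log_p_dec_x_given_z):
    """Estimate bits/image from per-example log-probabilities (in nats),
    with each z drawn from the encoder P_enc(z|x) for its x."""
    nats_to_bits = 1.0 / math.log(2.0)
    rate = (-log_p_pri_z).mean() * nats_to_bits                 # ~ H(z): bits to transmit z
    distortion = (-log_p_dec_x_given_z).mean() * nats_to_bits   # ~ H(x|z): extra bits to recover x
    return rate, distortion

# Stand-in values for log P_pri(z) and log P_dec(x|z) on a batch of 32 images.
print(rate_and_distortion_bits(torch.randn(32) - 50.0, torch.randn(32) - 500.0))
```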