# Variational Inference

Jul 07, 2022

## References
- David Blei has many good write-ups / talks on this topic
- Others
## Intro
Assumptions:

- $x = x_{1:n}$: observations
- $z = z_{1:m}$: hidden variables
We want to model the posterior distribution (notice that this is a form of *inference*: estimating hidden variables from observations):
$$
p(z|x) = \frac{p(z, x)}{\int_z p(z, x)\, dz}
$$
The posterior links the data and a model.
In most interesting problems, calculating the denominator is intractable.
$x$ is the evidence about $z$.
The main idea: pick a family of distributions over the latent variables, parameterized by its own *variational parameters*:

$$
q(z_{1:m} | \nu)
$$

then find the setting of $\nu$ that makes $q$ closest to the posterior of interest.
![concept](/media/posts/vi/vi-concept.png =400x)
The closeness can be measured by the Kullback-Leibler (KL) divergence:
$$
\textrm{KL}(q\|p) = E_{Z\sim q} \bigg[ \log \frac{q(Z)}{p(Z|x)} \bigg]
$$
We use the arguments ($q$ and $p$) in this order specifically so that the expectation is taken over $q$. If you flip the order (i.e., $\textrm{KL}(p\|q)$), you get *expectation propagation*. That is a different kind of variational inference, and it is generally more computationally expensive.
We cannot evaluate (let alone minimize) this KL divergence directly, because it involves the intractable posterior $p(z|x)$ and, through it, the evidence $\log p(x)$. But we can maximize a quantity that equals its negative up to a constant: the ELBO.
## The Evidence Lower Bound (ELBO)
$$
\begin{align*}
\textrm{KL}\big(q(z)\|p(z|x) \big) &= E_q \bigg[ \log \frac{q(Z)}{p(Z|x)} \bigg] \\
&= E_q [\log q(Z)] - E_q[\log p(Z|x)] \\
&= E_q [\log q(Z)] - E_q[\log p(Z, x) - \log p(x)] \\
&= E_q [\log q(Z)] - E_q[\log p(Z, x)] + \log p(x) \\
&= - \big(\underline{E_q[\log p(Z, x)] - E_q [\log q(Z)]}\big) + \log p(x)
\end{align*}
$$
Notes:

- The underlined term is the ELBO; rearranging gives $\log p(x) = \mathrm{ELBO}(q) + \textrm{KL}\big(q(z)\|p(z|x)\big)$.
- The last term $\log p(x)$ is independent of $q$; thus minimizing the KL divergence is equivalent to maximizing the ELBO.
:::message
ELBO derivation
Using Jensen's inequality ($\log E[\cdot] \geq E[\log \cdot]$, since $\log$ is concave):

$$
\begin{align}
\log p(x) &= \log \int_z p(x, z)\, dz \\
&= \log \int_z q(z) \frac{p(x, z)}{q(z)}\, dz \\
&= \log\bigg( E_q\Big[\frac{p(x, Z)}{q(Z)}\Big] \bigg) \\
&\geq \underline{E_q [\log p(x, Z)] - E_q[\log q(Z)]} ~~~(\because \text{Jensen's inequality}).
\end{align}
$$
Another derivation: forcibly extract the KL divergence.
$$
\begin{align}
\log p(x) &= E_q \big[\log p(x)\big] \\
&= E_q \Big[\log \Big\{ {\color{blue} p(x)} \cdot \frac{ {\color{blue} p(z|x)}\, q(z)}{p(z|x)\, {\color{green} q(z)}} \Big\}\Big] ~~~(\text{a trick to force a KL term to appear})\\
&= E_q \Big[\log \frac{ {\color{blue} p(x, z)} }{ {\color{green} q(z)} } + \log \frac{q(z)}{p(z|x)}\Big] \\
&= E_q [\log p(x, z) - \log q(z)] + E_q \Big[\log \frac{q(z)}{p(z|x)}\Big] \\
&= \underline{E_q [\log p(x, z) - \log q(z)]} + \textrm{KL}\big( q(z)\| p(z|x) \big) \\
&\geq \underline{E_q [\log p(x, Z)] - E_q[\log q(Z)]} ~~~(\because \text{KL divergence is non-negative}).
\end{align}
$$
The left-hand side is the (log) evidence, and the underlined term bounds it from below; hence the name Evidence Lower BOund. The bound is tight exactly when $q(z) = p(z|x)$, i.e., when the KL term vanishes.
:::
From the derivation above, $\log p(x) = \mathrm{ELBO}(q) + \textrm{KL}\big(q(z)\|p(z|x)\big)$: the ELBO and the negative KL divergence differ only by the log normalizer $\log p(x)$, and since the KL term is non-negative, this log normalizer is exactly what the ELBO bounds.
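To see this identity concretely, here is a small sanity-check sketch (my own toy example, not from the original notes) on a conjugate Gaussian model where everything has a closed form: $z \sim N(0, 1)$ and $x|z \sim N(z, 1)$, so the posterior is $N(x/2, 1/2)$ and the evidence is $N(0, 2)$.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Toy conjugate model: z ~ N(0, 1), x | z ~ N(z, 1).
# Posterior: z | x ~ N(x/2, 1/2). Evidence: x ~ N(0, 2).
x = 1.3

def elbo(m, s, n=200_000):
    """Monte Carlo estimate of E_q[log p(x, Z)] - E_q[log q(Z)] for q = N(m, s^2)."""
    z = rng.normal(m, s, size=n)
    log_joint = norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1)  # log p(z) + log p(x|z)
    return np.mean(log_joint - norm.logpdf(z, m, s))

def kl_gauss(m1, s1, m2, s2):
    """Closed-form KL( N(m1, s1^2) || N(m2, s2^2) )."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2)**2) / (2 * s2**2) - 0.5

log_evidence = norm.logpdf(x, 0, np.sqrt(2))
post_m, post_s = x / 2, np.sqrt(0.5)

for m, s in [(0.0, 1.0), (0.5, 0.9), (post_m, post_s)]:
    e, kl = elbo(m, s), kl_gauss(m, s, post_m, post_s)
    print(f"q=N({m:.2f}, {s:.2f}^2): ELBO+KL={e + kl:.4f}  log p(x)={log_evidence:.4f}")
```

Up to Monte Carlo noise, every row prints the same two numbers, and for $q$ equal to the posterior the ELBO alone matches $\log p(x)$.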
## Variational Autoencoder (pretty much the same thing)
Latent variable models:
$$
\begin{align*}
P_{\Phi, \Theta}(z, x) &= P_\Phi(z) P_\Theta(x|z) \\
P_{\Phi, \Theta}(z | x) &= \frac{P_{\Phi, \Theta}(z, x)}{\int_z P_{\Phi, \Theta}(z, x)\, dz}
\end{align*}
$$
We have a data population, and we want to estimate $\Phi$ and $\Theta$ from it:
$$
\Phi^*, \Theta^* = \textrm{argmin}_{\Phi, \Theta}~ E_{x \sim \mathrm{Pop}} \big[ -\log P_{\Phi, \Theta}(x) \big]
$$
The problem is that we typically can't compute $P_{\Phi, \Theta}(x)$.
- $P_{\Phi, \Theta}(x) = \int_z P_\Phi (z) P_\Theta (x|z)\, dz$ doesn't work: the sum/integral over $z$ is too large, and a naive Monte Carlo estimate with $z \sim P_\Phi$ has very high variance, because almost no $z$ drawn from the prior explains $x$ well.
- The same sum with importance sampling from $P_{\Phi, \Theta}(z|x)$ would be ideal (every importance weight then equals $P_{\Phi, \Theta}(x)$ exactly), but it doesn't work either: evaluating $P_{\Phi, \Theta}(z|x)$ requires the very normalizer $P_{\Phi, \Theta}(x)$ we are trying to compute. See the sketch below.
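Continuing the toy conjugate model from above (again my own illustration), we can cheat and use the known closed-form posterior as the importance proposal to see why it would be the ideal choice:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x, n = 1.3, 10_000  # toy model: z ~ N(0, 1), x | z ~ N(z, 1)

# Naive Monte Carlo: p(x) = E_{z ~ p(z)}[ p(x|z) ].
z = rng.normal(0, 1, size=n)
naive = norm.pdf(x, z, 1)

# Importance sampling with the exact posterior N(x/2, 1/2) as proposal:
# each weight p(z) p(x|z) / p(z|x) equals p(x) exactly, so the variance is zero.
zp = rng.normal(x / 2, np.sqrt(0.5), size=n)
w = norm.pdf(zp, 0, 1) * norm.pdf(x, zp, 1) / norm.pdf(zp, x / 2, np.sqrt(0.5))

print("true p(x)    :", norm.pdf(x, 0, np.sqrt(2)))
print("naive MC     :", naive.mean(), "+/-", naive.std() / np.sqrt(n))
print("posterior IS :", w.mean(), "+/-", w.std() / np.sqrt(n))  # ~ zero error
```

In higher dimensions the naive estimator degrades catastrophically, while the posterior proposal stays exact; the catch is that outside toy models we cannot evaluate (or sample from) that posterior.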
Variational Bayes sidesteps this with a model $P_\Psi(z|x)$ that approximates $P_{\Phi, \Theta}(z|x)$.
The ELBO:
$$
\begin{align*}
\log P_{\Phi, \Theta}(x) &\geq E_{z \sim P_{\Psi}} \big[ \log P_{\Phi, \Theta}(z, x) \big] - E_{z \sim P_{\Psi}} \big[ \log P_\Psi(z|x) \big] \\
&= E_{z \sim P_{\Psi}} \big[ \log P_{\Theta}(x|z)P_{\Phi}(z) \big] - E_{z \sim P_{\Psi}} \big[ \log P_\Psi(z|x) \big] \\
&= E_{z \sim P_{\Psi}} \bigg[ - \bigg( \log \frac{P_\Psi(z|x)}{P_{\Phi}(z)} - \log P_{\Theta}(x|z) \bigg) \bigg]
\end{align*}
$$
Thus,
$$
\Phi^*, \Theta^*, \Psi^* = \textrm{argmin}~E_{x \sim \mathrm{Pop},~z \sim P_\Psi} \bigg[
\log \frac{P_\Psi(z|x)}{P_{\Phi}(z)} - \log P_{\Theta}(x|z) \bigg]
$$
Minor but important: we can't do gradient descent w.r.t. $\Psi$ directly, because there is a sampling step $z \sim P_\Psi$ inside the objective. The re-parameterization trick circumvents this, as sketched below.
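A minimal PyTorch sketch of the trick (the layer sizes, the diagonal-Gaussian encoder, and the Bernoulli decoder are my own assumptions; the notes don't fix an architecture):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    """Minimal Gaussian VAE sketch (hypothetical sizes: 784 <-> 16)."""
    def __init__(self, x_dim=784, z_dim=16):
        super().__init__()
        self.enc = nn.Linear(x_dim, 2 * z_dim)  # P_Psi(z|x): outputs (mu, log sigma^2)
        self.dec = nn.Linear(z_dim, x_dim)      # P_Theta(x|z): Bernoulli logits

    def forward(self, x):
        mu, logvar = self.enc(x).chunk(2, dim=-1)
        # Re-parameterization: z = mu + sigma * eps with eps ~ N(0, I).
        # The randomness is isolated in eps, so gradients flow to (mu, logvar).
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.dec(z), mu, logvar

model = TinyVAE()
x = torch.rand(8, 784)  # dummy batch
logits, mu, logvar = model(x)

# Rate term: closed-form KL( N(mu, sigma^2) || N(0, I) ) against a standard normal prior.
rate = 0.5 * torch.sum(mu**2 + logvar.exp() - logvar - 1, dim=-1).mean()
# Distortion term: E[-log P_Theta(x|z)] under the Bernoulli decoder.
distortion = F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1).mean()
(rate + distortion).backward()  # gradients reach the encoder thanks to the trick
```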
- $P_\Phi(z)$: the prior
- $P_\Psi(z|x)$: the encoder
- $P_\Theta(x|z)$: the decoder
- $E[\log P_\Psi(z|x)/P_\Phi(z)]$: the *rate* term, a KL divergence
- $E[-\log P_\Theta(x|z)]$: the *distortion* term, a conditional entropy
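For the common special case in the sketch above (diagonal-Gaussian encoder, standard normal prior; an assumption, not something the notes fix), the rate term has a well-known closed form per latent dimension:

$$
\textrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \frac{1}{2}\big(\mu^2 + \sigma^2 - \log \sigma^2 - 1\big),
$$

summed over latent dimensions; this is exactly the `rate` expression in the code.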
:::message
More topics covered in TTIC31230
The EM (Expectation-Maximization) algorithm is in fact a specific instantiation of VAE!

EM corresponds to minimizing the VAE objective coordinate-wise:

- First w.r.t. the encoder ($\Psi$), setting it to the exact posterior: the inference step (E step)
- Then w.r.t. $\Phi$ and $\Theta$, while fixing $\Psi$: the update step (M step)

A VAE is exactly the same as a Rate-Distortion Autoencoder (RDA). (A minimal EM sketch follows after this block.)
:::
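To make the EM correspondence concrete, here is a minimal EM sketch for a 1-D two-component Gaussian mixture (my own toy example): the E step plays the $\Psi$ role (the encoder is set to the exact posterior over the latent assignment), and the M step updates the generative parameters ($\Phi$, $\Theta$).

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy data: a two-component 1-D Gaussian mixture.
x = np.concatenate([rng.normal(-2, 1, 300), rng.normal(3, 1, 700)])

pi, mu, var = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
for _ in range(50):
    # E step: set q(z|x) to the exact posterior over component assignments
    # (the role Psi plays in the VAE objective).
    dens = pi * np.exp(-0.5 * (x[:, None] - mu) ** 2 / var) / np.sqrt(2 * np.pi * var)
    resp = dens / dens.sum(axis=1, keepdims=True)
    # M step: maximize the ELBO w.r.t. the generative parameters (Phi, Theta)
    # while holding q fixed.
    nk = resp.sum(axis=0)
    pi = nk / len(x)
    mu = (resp * x[:, None]).sum(axis=0) / nk
    var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk

print(pi, mu, var)  # approaches [0.3, 0.7], [-2, 3], [1, 1]
```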
## Rate-Distortion Autoencoders (mathematically the same as VAE)
Setting: image compression, where an image $x$ is compressed to $z$.
We assume a stochastic compression algorithm (the *encoder*): $P_\text{enc}(z|x)$
- $H(z)$: the number of bits needed for the compressed file. This is the *rate* (bits/image) for transmitting compressed images. (A toy illustration follows below.)
  - Modeled with a prior model $P_\text{pri}(z)$
- $H(x|z)$: the number of additional bits needed to exactly recover $x$. This is a measure of the *distortion* of $x$.
  - Modeled with a decoder model $P_\text{dec}(x|z)$
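As a toy illustration of the bits-per-image reading (my own example, using the standard Shannon code-length identity): coding a discrete $z$ under the model $P_\text{pri}$ costs about $-\log_2 P_\text{pri}(z)$ bits, so the expected rate is a cross-entropy, exceeding $H(z)$ by exactly a KL term.

```python
import numpy as np

# Hypothetical discrete code with 4 symbols.
q_marginal = np.array([0.1, 0.2, 0.3, 0.4])    # encoder's marginal over codes
p_prior = np.array([0.25, 0.25, 0.25, 0.25])   # prior model P_pri(z)

rate = -(q_marginal * np.log2(p_prior)).sum()        # expected bits under P_pri
entropy = -(q_marginal * np.log2(q_marginal)).sum()  # H(z), the unbeatable minimum
print(rate, entropy, rate - entropy)  # the gap is KL(q || P_pri) in bits
```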