I recently realized the connection between the expectation maximization algorithm (EM) and variational autoencoders (VAE).  Both optimize the same objective function where VAE performs gradient descent based on a sampling estimate of the gradient while EM performs exact alternating maximization in models where this is possible.  It turns out that the alternating optimization view of EM is described in section 9.4 of Bishop’s 2006 text on machine learning and also in section 19.2 of the text by Goodfellow, Bengio and Courville.   I had overlooked this connection between EM and VAEs, and it seems important, so here is a blog post that helps me, at least, organize my thinking.

First some background.

The expectation maximization (EM) algorithm is a time-honored cornerstone of machine learning.  Its general formulation was given in a classic 1977 paper by Dempster, Laird and Rubin, although many special cases had been developed prior to that.  The algorithm is used for estimating unknown parameters of a probability distribution involving latent variables.  The algorithm is usually introduced by looking at the special case of modeling a collection of points by a mixture of Gaussians.  Here we are given points where the model assumes that each point came from an unlabeled component of the mixture (the component associated with the point is a latent variable). Given the points one must estimate a component weight, a mean and a covariance matrix for each component of the mixture.


This is a non-convex optimization problem. EM iteratively re-estimates the parameters with a guarantee that for each iteration, unless one is already at a parameter setting with zero probability gradients (a stationary point), the probability of the points under the model improves.  There are numerous important special cases such as learning the parameters of a hidden Markov model or a probabilistic context free grammar.
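As a concrete illustration of the re-estimation loop, here is a minimal NumPy sketch of EM for a mixture of Gaussians. To keep it short it assumes one-dimensional points, so each component has a scalar variance rather than a covariance matrix; the synthetic data, the starting parameters, and the function names are my own choices for illustration.

```python
import numpy as np

def component_densities(x, w, mu, var):
    # Weighted Gaussian density of each point under each component, shape (n, k).
    return w * np.exp(-0.5 * (x[:, None] - mu) ** 2 / var) / np.sqrt(2 * np.pi * var)

def log_likelihood(x, w, mu, var):
    # Log probability of all points under the mixture.
    return np.log(component_densities(x, w, mu, var).sum(axis=1)).sum()

def em_step(x, w, mu, var):
    # E step: posterior probability of each component for each point.
    dens = component_densities(x, w, mu, var)
    r = dens / dens.sum(axis=1, keepdims=True)
    # M step: closed-form re-estimation of weights, means and variances.
    n_k = r.sum(axis=0)
    w = n_k / len(x)
    mu = (r * x[:, None]).sum(axis=0) / n_k
    var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / n_k
    return w, mu, var

# Synthetic data from two components (assumed values, for illustration only).
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2.0, 0.5, 200), rng.normal(3.0, 1.0, 300)])

w, mu, var = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
history = [log_likelihood(x, w, mu, var)]
for _ in range(30):
    w, mu, var = em_step(x, w, mu, var)
    history.append(log_likelihood(x, w, mu, var))
```

The list `history` records the log probability of the points after each iteration, so one can check directly that it never decreases.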

Variational autoencoders (VAEs) were introduced in 2014 by Kingma and Welling in the context of deep learning.  Autoencoding is related to compression.  JPEG compresses an image into an encoded form that uses less memory but can then be decompressed to get the image back with some loss of resolution.

[Figure: an autoencoder. An input x is encoded by P_\Psi(z|x) into a compressed form z, from which the model P_\Theta(z,x) produces a reconstruction x'.]

[Kevin Franz]

Information theoretically, compression is closely related to distribution modeling. Shannon’s source coding theorem states that for any probability distribution, such as the distribution on “natural” images, one can (block) code the elements using an average number of bits per element arbitrarily close to the information-theoretic entropy of the distribution.  In the above figure the average number of bits in the compressed form in the middle should ideally be equal to the entropy of the probability distribution over images.
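A tiny worked example may help make the entropy statement concrete. The four-symbol distribution and the prefix code below are assumptions chosen so that the arithmetic is exact; for this distribution the expected code length matches the entropy.

```python
import math

# Assumed toy distribution over four symbols (not from any real data).
p = [0.5, 0.25, 0.125, 0.125]

# Entropy in bits.
entropy_bits = -sum(q * math.log2(q) for q in p)

# Lengths of the prefix code 0, 10, 110, 111 for the four symbols.
code_lengths = [1, 2, 3, 3]
expected_length = sum(q * n for q, n in zip(p, code_lengths))
```

Both `entropy_bits` and `expected_length` come out to 1.75 bits per symbol here. For distributions whose probabilities are not powers of two, coding blocks of symbols is what lets the average approach the entropy.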

VAEs, however, do not directly address the compression problem.  Instead they focus on fitting the parameters \Theta of the distribution P_\Theta(z,x) so as to maximize the marginal probability P_\Theta(x) of a given set x of samples (such as a set of points or a set of images). The encoder is just a tool to make this parameter fitting more efficient. The problem being solved by a VAE is the same as the problem being solved by EM — fitting the parameters of a probability distribution to given data where the model includes latent variables not specified in the data.

VAE = EM. Both EM and VAEs are formulated in terms of the following objective function on the two distributions P_\Psi(z|y) (the encoder) and P_\Theta(y,z) (the model defining P(z) and P(y|z)). From here on y denotes an observed data point, the role played by x above.

\Psi^*,\Theta^* = \mathrm{argmax}_{\Psi,\Theta}\;\sum_y\;{\cal L}(\Psi,\Theta,y)

{\cal L}(\Psi,\Theta,y) = E_{z \sim P_\Psi(z|y)}\; \ln P_\Theta(z,y) + H(P_\Psi(z|y))\;\;\;(1)

= \ln\;P_{\Theta}(y) - KL(P_\Psi(z|y),P_{\Theta}(z|y))\;\;\;(2)

Here y is summed over training data. The equivalence of  (1) and (2) is important as they provide different outlooks on the objective.   The equivalence is implied by the following observation.

E_{z \sim P_\Psi(z|y)} \;\ln P_\Theta(z,y)
= E_{z \sim P_\Psi(z|y)} \;\ln P_\Theta(y)P_\Theta(z|y)
= E_{z \sim P_\Psi(z|y)}\;\ln P_\Theta(y) + \;E_{z \sim P_\Psi(z|y)}\;\ln P_\Theta(z|y)
= \ln P_\Theta(y) + \;E_{z \sim P_\Psi(z|y)}\;\ln P_\Theta(z|y)
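To finish the step from (1) to (2), one adds the entropy term and recognizes a KL divergence. A short sketch, using H(P_\Psi(z|y)) = -E_{z \sim P_\Psi(z|y)} \ln P_\Psi(z|y):

```latex
{\cal L}(\Psi,\Theta,y)
= \ln P_\Theta(y) + E_{z \sim P_\Psi(z|y)}\;\ln P_\Theta(z|y) + H(P_\Psi(z|y))
= \ln P_\Theta(y) - E_{z \sim P_\Psi(z|y)}\;\ln \frac{P_\Psi(z|y)}{P_\Theta(z|y)}
= \ln P_\Theta(y) - KL(P_\Psi(z|y),P_\Theta(z|y))
```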

EM can be defined as alternating optimization of \sum_y {\cal L}(\Psi,\Theta,y) where (2) supports the optimization of \Psi (the E step) and (1) supports the optimization of \Theta (the M step).  By (2), the E step is solved by setting P_\Psi(z|y) = P_\Theta(z|y), which drives the KL term to zero.  The EM algorithm applies to models where both the E step and the M step can be solved exactly and efficiently, and for such models each M step is guaranteed to improve \sum _y \ln P_\Theta(y) unless the parameters are already at a stationary point.  For these models this alternating optimization is typically far superior to any form of gradient descent on \sum_y \ln P_\Theta(y).
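The two forms of the objective can be checked numerically on a small discrete model. The sketch below assumes a random joint table over three latent values and two observed values; the table, the choice y = 1, and the function names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed toy model: a random joint table P_Theta(z, y), 3 latent x 2 observed values.
joint = rng.random((3, 2))
joint /= joint.sum()

y = 1
p_y = joint[:, y].sum()            # marginal P_Theta(y)
posterior = joint[:, y] / p_y      # exact posterior P_Theta(z | y)

def objective_1(q):
    # Form (1): expected log joint plus entropy of the encoder q(z | y).
    return np.sum(q * np.log(joint[:, y])) - np.sum(q * np.log(q))

def objective_2(q):
    # Form (2): log marginal minus KL(q, posterior).
    return np.log(p_y) - np.sum(q * np.log(q / posterior))

# An arbitrary encoder distribution over z for this y.
q = rng.random(3)
q /= q.sum()
```

For any such `q` the two functions agree and are bounded above by `np.log(p_y)`. Setting `q = posterior`, which is the exact E step, makes the bound tight.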

VAEs are used in models where the alternating optimization cannot be done efficiently in closed form.  A VAE performs gradient descent on (1) where the gradient is estimated by sampling z from the encoder P_\Psi(z|y).  This typically involves a “reparameterization trick”.  One represents P_\Psi(z|y) by taking z = g_\Psi(y,\eta) where \eta is a random variable of a fixed distribution. We can then rewrite (1) as

{\cal L}(\Psi,\Theta,y) = E_{\eta}\; \ln P_\Theta(g_\Psi(y,\eta),y) + H(P_\Psi(z|y))\;\;\;(3)

It is common to take P(\eta), P_\Psi(z|y), P_\Theta(z) and P_\Theta (y|z) to all be high dimensional Gaussians, each represented by a mean and a (diagonal) covariance. The noise variable \eta is taken to be a fixed zero mean isotropic Gaussian while P_\Psi(z|y), P_\Theta(z) and P_\Theta (y|z) have means and covariances computed by deep networks.  We can now calculate the gradient of (3) for any given sample of \eta.
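Here is a minimal sketch of the sampled gradient under the reparameterization z = g_\Psi(y,\eta). It assumes a scalar toy model in place of deep networks: P_\Theta(z) = N(0,1), P_\Theta(y|z) = N(\theta z, 1), and an encoder N(m, s^2) written as z = m + s\eta. The parameter values and function names are hypothetical, and the gradients are derived by hand for this toy model so that the sampled estimate can be compared against the exact answer.

```python
import numpy as np

# Assumed toy model (not from the post): scalar z and y with
#   P_Theta(z) = N(0, 1),  P_Theta(y | z) = N(theta * z, 1),
#   encoder P_Psi(z | y) = N(m, s^2), sampled as z = m + s * eta.
theta, m, s, y = 1.5, 0.3, 0.7, 1.0

def sampled_gradients(n_samples, seed=0):
    rng = np.random.default_rng(seed)
    eta = rng.standard_normal(n_samples)       # noise from a fixed distribution
    z = m + s * eta                            # reparameterized latent sample
    dlogp_dz = -z + theta * (y - theta * z)    # d/dz of ln P_Theta(z, y)
    grad_m = dlogp_dz.mean()                   # dz/dm = 1; entropy is independent of m
    grad_s = (dlogp_dz * eta).mean() + 1.0 / s # dz/ds = eta; Gaussian entropy adds 1/s
    grad_theta = (z * (y - theta * z)).mean()  # direct dependence of ln P_Theta on theta
    return grad_m, grad_s, grad_theta

def exact_gradients():
    # Closed-form gradients of the objective for this toy model.
    grad_m = -m + theta * (y - theta * m)
    grad_s = -(1.0 + theta ** 2) * s + 1.0 / s
    grad_theta = m * y - theta * (m ** 2 + s ** 2)
    return grad_m, grad_s, grad_theta

estimate = sampled_gradients(200_000)
exact = exact_gradients()
```

With many samples `estimate` should be close to `exact`. In an actual VAE the hand-derived derivatives would be replaced by automatic differentiation through the networks.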

It is not clear whether the term “variational autoencoder” should imply the use of Gaussians.  Goodfellow et al. seem to take the position that VAEs are defined by gradient descent on (1) where the gradient is estimated by sampling z from P_\Psi(z|y).  This does not necessarily involve Gaussians.  A version of HMM-VAEs has recently been formulated which includes discrete latent HMM states.  However the observed variable y is still continuous and Gaussians are still involved.  Presumably one can formulate VAEs where both y and z are structured discrete variables, such as a language model with discrete latent states.  These VAEs would presumably not involve Gaussians.

Other approaches to optimizing {\cal L}(\Psi,\Theta,y) are possible. The speech recognition algorithm CTC uses an exact E step but a gradient M step and achieves gradient descent on deep networks with discrete latent variables but without latent variable sampling.  See the following blog post for a discussion of the CTC approach which we might call the EG (expected gradient) algorithm.
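The idea of an exact E step followed by a gradient M step rests on the identity \nabla_\Theta \ln P_\Theta(y) = E_{z \sim P_\Theta(z|y)} \nabla_\Theta \ln P_\Theta(z,y). The sketch below is not CTC itself. It checks the identity on an assumed toy model where the joint over a discrete z and y is a softmax of a parameter table; all names are illustrative.

```python
import numpy as np

def joint(theta):
    # Assumed parameterization: P_Theta(z, y) is a softmax over the whole table.
    p = np.exp(theta - theta.max())
    return p / p.sum()

def log_marginal(theta, y):
    return np.log(joint(theta)[:, y].sum())

def expected_gradient(theta, y):
    p = joint(theta)
    posterior = p[:, y] / p[:, y].sum()   # exact E step: P_Theta(z | y)
    # For a softmax table, grad of ln P_Theta(z, y) is onehot(z, y) - p.
    grad = -p
    grad[:, y] += posterior               # posterior-weighted average of the one-hots
    return grad

def finite_difference(theta, y, eps=1e-6):
    # Numerical gradient of ln P_Theta(y), used only as a check.
    grad = np.zeros_like(theta)
    for idx in np.ndindex(*theta.shape):
        step = np.zeros_like(theta)
        step[idx] = eps
        grad[idx] = (log_marginal(theta + step, y) - log_marginal(theta - step, y)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
theta = rng.standard_normal((3, 2))       # 3 latent values, 2 observed values
y = 1
g = expected_gradient(theta, y)
theta_new = theta + 0.1 * g               # one gradient step on the model parameters
```

No latent variable is sampled: the expectation over z is computed exactly from the posterior, and `g` agrees with the numerical gradient of the log marginal.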


2 Responses to VAE = EM

  1. a k says:

    Wonderful! I wasn’t sure of this being a thing until you outlined it above. I was also trying to bridge the connection between EM and how RL, specifically policy gradient methods, learn?

    I’m still trying to understand and organize my own thinking, so excuse my ignorance if there actually is no connection between the two. I’d love to hear your thoughts!

    August Karlstedt

  2. McAllester says:

    VAEs and EM are about latent variables. A partially observed Markov decision process (POMDP) has an unobserved (latent) state. This is closely related to VRNNs (the marriage of VAEs and RNNs). But while VRNNs model sequential distributions they do not handle decision making (action policies). A marriage of VRNNs with policy gradient should be possible. But my intuition is that an RNN state already models a belief state (a distribution over the hidden state) and it might be difficult to obtain gains from the VRNN latent state formulation.
