I recently realized the connection between the expectation maximization algorithm (EM) and variational autoencoders (VAE). Both optimize the same objective function: a VAE performs gradient ascent using a sampling estimate of the gradient, while EM performs exact alternating maximization in models where this is possible. The alternating optimization view of EM is described in Section 9.4 of Bishop’s 2006 text on machine learning and also in Section 19.2 of the text by Goodfellow, Bengio and Courville. I had overlooked this connection between EM and VAEs, and it seems important, so here is a blog post that helps me, at least, organize my thinking.
First some background.
The expectation maximization (EM) algorithm is a time-honored cornerstone of machine learning. Its general formulation was given in a classic 1977 paper by Dempster, Laird and Rubin, although many special cases had been developed prior to that. The algorithm is used for estimating unknown parameters of a probability distribution involving latent variables. It is usually introduced through the special case of modeling a collection of points by a mixture of Gaussians. Here we are given points $y_1, \ldots, y_N$ and the model assumes that each point came from an unlabeled component of the mixture (the component associated with the point is a latent variable). Given the points one must estimate a component weight, a mean and a covariance matrix for each component of the mixture.
This is a non-convex optimization problem. EM iteratively re-estimates the parameters with the guarantee that each iteration improves the probability of the points under the model, unless one is already at a stationary point (a parameter setting where the gradient of that probability is zero). There are numerous important special cases such as learning the parameters of a hidden Markov model or a probabilistic context free grammar.
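As a rough illustration of the mixture-of-Gaussians case, here is a minimal sketch of EM in one dimension. The function name `em_gmm_1d`, the quantile-based initialization and the synthetic data are all assumptions made for this example, and the code has no safeguards against collapsing components. It records the log-likelihood at each iteration so that the monotone improvement described above can be checked.

```python
import numpy as np

def em_gmm_1d(y, k, n_iter=50):
    """Fit a 1-D mixture of k Gaussians by EM (illustrative sketch, no safeguards)."""
    n = len(y)
    weights = np.full(k, 1.0 / k)
    # Hypothetical initialization: spread the means over quantiles of the data.
    means = np.quantile(y, (np.arange(k) + 0.5) / k)
    variances = np.full(k, np.var(y))
    log_likelihoods = []
    for _ in range(n_iter):
        # log(weight_j * N(y_i | mean_j, variance_j)), shape (n, k)
        log_joint = (np.log(weights)
                     - 0.5 * np.log(2.0 * np.pi * variances)
                     - 0.5 * (y[:, None] - means) ** 2 / variances)
        # log P(y_i) under the current parameters, marginalizing the component
        log_marginal = np.logaddexp.reduce(log_joint, axis=1)
        log_likelihoods.append(log_marginal.sum())
        # E step: posterior probability of each component for each point
        resp = np.exp(log_joint - log_marginal[:, None])
        # M step: closed-form re-estimation from the soft assignments
        counts = resp.sum(axis=0)
        weights = counts / n
        means = resp.T @ y / counts
        variances = (resp * (y[:, None] - means) ** 2).sum(axis=0) / counts
    return weights, means, variances, np.array(log_likelihoods)

# Synthetic data: two well-separated clusters (an assumption for the demo).
rng = np.random.default_rng(0)
y = np.concatenate([rng.normal(-3.0, 1.0, 200), rng.normal(3.0, 1.0, 200)])
weights, means, variances, log_likelihoods = em_gmm_1d(y, k=2)
```

The entries of `log_likelihoods` should never decrease from one iteration to the next, up to floating point error.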
Variational autoencoders (VAEs) were introduced in 2014 by Kingma and Welling in the context of deep learning. Autoencoding is related to compression. JPEG compresses an image into an encoded form that uses less memory but can then be decompressed to get the image back with some loss of resolution.
Information theoretically, compression is closely related to distribution modeling. Shannon’s source coding theorem states that for any probability distribution, such as the distribution on “natural” images, one can (block) code the elements using a number of bits per element arbitrarily close to the information-theoretic entropy of the distribution. In the figure above, the average number of bits in the compressed form in the middle should ideally be equal to the entropy of the probability distribution over images.
VAEs, however, do not directly address the compression problem. Instead they focus on fitting the parameters of the distribution so as to maximize the marginal probability of a given set of samples (such as a set of points or a set of images). The encoder is just a tool to make this parameter fitting more efficient. The problem being solved by a VAE is the same as the problem being solved by EM: fitting the parameters of a probability distribution to given data where the model includes latent variables not specified in the data.
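To make the link between entropy and code length concrete, here is a small sketch. The four-symbol distribution and its prefix code are hypothetical, chosen with probabilities that are powers of 1/2 so that a symbol-by-symbol code already meets the entropy exactly; for general distributions the bound is only approached through block coding.

```python
import math

def entropy_bits(probs):
    """Shannon entropy in bits of a discrete distribution."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

# Hypothetical source with four symbols and the prefix code 0, 10, 110, 111.
probs = [0.5, 0.25, 0.125, 0.125]
code_lengths = [1, 2, 3, 3]

entropy = entropy_bits(probs)
expected_length = sum(p * n for p, n in zip(probs, code_lengths))
print(entropy, expected_length)
```

Both quantities come out to 1.75 bits per symbol for this distribution.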
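VAE = EM. Both EM and VAEs are formulated in terms of the following objective function on the two distributions $Q(z|y)$ (the encoder) and $P(z,y)$ (the model defining $P(z)$ and $P(y|z)$).
$$\mathcal{L}(Q,P) = E_y\, E_{z \sim Q(z|y)} \left[\ln P(z,y) - \ln Q(z|y)\right] \qquad (1)$$
$$\phantom{\mathcal{L}(Q,P)} = E_y \left[\ln P(y) - KL\big(Q(z|y),\, P(z|y)\big)\right] \qquad (2)$$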
Here the expectation $E_y$ is taken over the training data. The equivalence of (1) and (2) is important as they provide different outlooks on the objective. The equivalence is implied by the following observation.
$$P(z,y) = P(z)\,P(y|z) = P(y)\,P(z|y)$$
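For readers who want the intermediate steps, the equivalence can be spelled out for a single data point $y$. The notation here is an assumption: $Q(z|y)$ is the encoder, $P(z,y)$ is the model, and $KL(Q,P) = E_{z \sim Q} \ln \frac{Q(z)}{P(z)}$. The derivation uses only the factorization $P(z,y) = P(y)P(z|y)$ and the fact that $\ln P(y)$ does not depend on $z$.

$$\begin{aligned}
E_{z \sim Q(z|y)} \left[\ln P(z,y) - \ln Q(z|y)\right]
&= E_{z \sim Q(z|y)} \left[\ln P(y) + \ln P(z|y) - \ln Q(z|y)\right] \\
&= \ln P(y) - E_{z \sim Q(z|y)} \ln \frac{Q(z|y)}{P(z|y)} \\
&= \ln P(y) - KL\big(Q(z|y),\, P(z|y)\big)
\end{aligned}$$

Averaging both sides over the training data gives the two forms of the objective. Since the KL term is nonnegative, the objective is a lower bound on the average log-likelihood $E_y \ln P(y)$, with equality when $Q(z|y) = P(z|y)$.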
EM can be defined as alternating optimization of $\mathcal{L}$ where (2) supports the optimization of $Q$ (the E step) and (1) supports the optimization of $P$ (the M step). The EM algorithm applies to models where both the E step and the M step can be solved exactly and efficiently, and for such models each M step is guaranteed to improve $E_y \ln P(y)$. For these models this alternating optimization is typically far superior to any form of gradient ascent on $\mathcal{L}$.
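The role of the E step can be checked numerically on a tiny discrete model. This is only a sketch: the three-state latent variable, the prior and the likelihood values below are made up for illustration. It evaluates the objective for one observed $y$ in both forms and confirms that setting the encoder to the exact posterior makes the bound tight.

```python
import numpy as np

def objective_form1(q, joint):
    """E_{z~Q}[ln P(z,y) - ln Q(z|y)] for a single observed y."""
    return float(np.sum(q * (np.log(joint) - np.log(q))))

def objective_form2(q, joint):
    """ln P(y) - KL(Q(z|y), P(z|y)) for the same observed y."""
    p_y = joint.sum()
    posterior = joint / p_y
    kl = np.sum(q * (np.log(q) - np.log(posterior)))
    return float(np.log(p_y) - kl)

# Hypothetical model with three latent states and one observed y.
prior = np.array([0.5, 0.3, 0.2])        # P(z)
likelihood = np.array([0.1, 0.6, 0.3])   # P(y|z) at the observed y
joint = prior * likelihood               # P(z,y)
log_p_y = float(np.log(joint.sum()))     # ln P(y)

q_arbitrary = np.array([0.2, 0.5, 0.3])  # some encoder distribution
q_posterior = joint / joint.sum()        # exact E step: Q(z|y) = P(z|y)

print(objective_form1(q_arbitrary, joint), objective_form2(q_arbitrary, joint))
print(objective_form1(q_posterior, joint), log_p_y)
```

For the arbitrary encoder the two forms agree and sit strictly below $\ln P(y)$; for the posterior they equal $\ln P(y)$.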
VAEs are used in models where the alternating optimization cannot be done efficiently in closed form. A VAE performs stochastic gradient ascent on (1), where the gradient is estimated by sampling from the encoder $Q(z|y)$. This typically involves a “reparameterization trick”: one represents $z$ by taking $z = f(y,\epsilon)$, where $f$ is a deterministic function defined by the encoder and $\epsilon$ is a random variable with a fixed distribution. We can then rewrite (1) as
$$\mathcal{L}(Q,P) = E_y\, E_\epsilon \left[\ln P(z,y) - \ln Q(z|y)\right], \quad z = f(y,\epsilon) \qquad (3)$$
It is common to take $Q(z|y)$, $P(z)$, and $P(y|z)$ to all be high dimensional Gaussians, each represented by a mean $\mu$ and a (diagonal) covariance $\Sigma$. The noise variable $\epsilon$ is taken to be a fixed zero mean isotropic Gaussian while $Q(z|y)$, $P(z)$ and $P(y|z)$ have means and covariances computed by deep networks. We can now calculate the gradient of (3) for any given sample of $\epsilon$.
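The reparameterization idea can be illustrated on a toy problem where the exact gradient is known. This sketch is not a VAE: the integrand $h(z) = z^2$ is a hypothetical stand-in for the term inside the expectation, the distribution is a one-dimensional Gaussian with mean `mu` and standard deviation `sigma`, and derivatives are written by hand instead of being computed by automatic differentiation through deep networks. Writing $z = \mu + \sigma\epsilon$ moves the parameters inside the expectation, so the gradient can be estimated by averaging over samples of $\epsilon$.

```python
import numpy as np

def reparam_gradients(mu, sigma, n_samples=200_000, seed=0):
    """Estimate d/dmu and d/dsigma of E_{z~N(mu, sigma^2)}[z^2] by reparameterization."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n_samples)  # noise from a fixed distribution
    z = mu + sigma * eps                  # z = f(mu, sigma, eps)
    dh_dz = 2.0 * z                       # derivative of h(z) = z**2
    grad_mu = np.mean(dh_dz)              # chain rule with dz/dmu = 1
    grad_sigma = np.mean(dh_dz * eps)     # chain rule with dz/dsigma = eps
    return grad_mu, grad_sigma

# E[z^2] = mu^2 + sigma^2, so the exact gradients are 2*mu and 2*sigma.
grad_mu, grad_sigma = reparam_gradients(mu=1.5, sigma=0.5)
print(grad_mu, grad_sigma)
```

The estimates should land close to the exact values 3.0 and 1.0, with the remaining gap due to sampling noise.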
It is not clear whether the term “variational auto encoder” should imply the use of Gaussians. Goodfellow et al. seem to take the position that VAEs are defined by gradient ascent on (1) where the gradient is estimated by sampling from $Q(z|y)$. This does not necessarily involve Gaussians. A version of HMM-VAEs has recently been formulated which includes discrete latent HMM states. However the observed variable is still continuous and Gaussians are still involved. Presumably one can formulate VAEs where both $z$ and $y$ are structured discrete variables, such as a language model with discrete latent states. These VAEs would presumably not involve Gaussians.
Other approaches to optimizing $\mathcal{L}$ are possible. The speech recognition algorithm CTC uses an exact E step but a gradient M step and achieves gradient-based training of deep networks with discrete latent variables but without latent variable sampling. See the following blog post for a discussion of the CTC approach which we might call the EG (expected gradient) algorithm.
Wonderful! I wasn’t sure of this being a thing until you outlined it above. I was also trying to work out the connection between EM and how RL methods, specifically policy gradient methods, learn.
I’m still trying to understand and organize my own thinking, so excuse my ignorance if there actually is no connection between the two. I’d love to hear your thoughts!
VAEs and EM are about latent variables. A partially observed Markov decision process (POMDP) has an unobserved (latent) state. This is closely related to VRNNs (the marriage of VAEs and RNNs, https://arxiv.org/abs/1506.02216). But while VRNNs model sequential distributions they do not handle decision making (action policies). A marriage of VRNNs with policy gradient should be possible. But my intuition is that an RNN state already models a belief state (a distribution over the hidden state) and it might be difficult to obtain gains from the VRNN latent state formulation.