Section 20.9.1 of Goodfellow, Bengio and Courville is titled “Back-Propagating through Discrete Stochastic Operations”. It discusses the REINFORCE algorithm, which requires sampling sequences of discrete latent choices. It is well known that sampling choices yields a high-variance gradient estimate. While such sampling seems necessary for reinforcement learning applications (such as games), there are other settings where one can train discrete latent choices using a closed-form gradient and avoid high-variance choice sampling. This is done in the CTC algorithm (Connectionist Temporal Classification, Graves et al. 2006). CTC is arguably the currently dominant approach to speech recognition. It is mentioned in Goodfellow et al. in the section on speech recognition but not in the general context of training discrete latent choices.
This post is also a sequel to the previous post on variational auto-encoders and EM. In CTC, as in various other models, we are interested in a marginal distribution over latent variables. In this case we have training pairs $(x, y)$ and are interested in maximizing the marginal distribution

$$P_\Phi(y \mid x) \;=\; \sum_z P_\Phi(z, y \mid x).$$

In CTC $x$ is a speech signal, $y$ is a (time-compressed) phonetic transcription, and $z$ is a sequence of latent “emission decisions”. This is described in more detail below.

The previous post noted that both VAEs and EM can be phrased in terms of a unified objective function involving an “encoder distribution” $\hat{P}_\Psi(z \mid x, y)$:

$$\Psi^*, \Phi^* \;=\; \mathop{\mathrm{argmax}}_{\Psi, \Phi}\; E_{(x,y)}\; E_{z \sim \hat{P}_\Psi(z \mid x, y)}\left[\ln P_\Phi(z, y \mid x) \;-\; \ln \hat{P}_\Psi(z \mid x, y)\right] \qquad (1)$$

$$\phantom{\Psi^*, \Phi^*} \;=\; \mathop{\mathrm{argmax}}_{\Psi, \Phi}\; E_{(x,y)}\left[\ln P_\Phi(y \mid x) \;-\; KL\!\left(\hat{P}_\Psi(z \mid x, y),\, P_\Phi(z \mid x, y)\right)\right] \qquad (2)$$

Here we have made everything conditional on an input $x$.
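As a quick check of the notation, (2) follows from (1) by writing $\ln P_\Phi(z, y \mid x) = \ln P_\Phi(y \mid x) + \ln P_\Phi(z \mid x, y)$ inside the inner expectation:

$$E_{z \sim \hat{P}_\Psi(z \mid x, y)}\left[\ln P_\Phi(z, y \mid x) - \ln \hat{P}_\Psi(z \mid x, y)\right] \;=\; \ln P_\Phi(y \mid x) \;-\; KL\!\left(\hat{P}_\Psi(z \mid x, y),\, P_\Phi(z \mid x, y)\right).$$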
The equivalence between (1) and (2) is derived in the previous post. In EM one performs alternating optimization where (2) is used to optimize the encoder distribution $\hat{P}_\Psi(z \mid x, y)$ while holding the model distribution $P_\Phi(z, y \mid x)$ fixed, and (1) is used to optimize the model distribution $P_\Phi(z, y \mid x)$ while holding the encoder distribution $\hat{P}_\Psi(z \mid x, y)$ fixed. In EM each of these optimizations is done exactly in closed form, yielding an exact alternating maximization algorithm. In CTC the E step optimization $\hat{P}_\Psi(z \mid x, y) = P_\Phi(z \mid x, y)$ is also computed exactly in closed form, but the M step is replaced by an exact closed-form calculation of the gradient of (1) with respect to $\Phi$. I will call this the EG algorithm (expected gradient).
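As a concrete sanity check of the expected-gradient idea, here is a small numerical sketch on a toy discrete model (my own illustrative example, not code from the post): holding the exact posterior fixed, the posterior-weighted expected gradient of $\ln P_\Phi(z, y)$ equals the gradient of $\ln P_\Phi(y)$ computed directly.

```python
# Toy model: latent z ~ softmax(prior_logits), observation y ~ softmax(emit_logits[z]).
# Check that the EG gradient matches the direct gradient of log P(y).
import torch

torch.manual_seed(0)
K, V = 3, 5                      # number of latent values and output symbols
prior_logits = torch.randn(K, requires_grad=True)
emit_logits = torch.randn(K, V, requires_grad=True)
y = 2                            # an observed symbol

def log_joint(z):
    # log P(z, y) for a particular latent value z
    return (torch.log_softmax(prior_logits, 0)[z]
            + torch.log_softmax(emit_logits[z], 0)[y])

# Direct gradient of the marginal log-likelihood log P(y) = log sum_z P(z, y).
log_marginal = torch.logsumexp(torch.stack([log_joint(z) for z in range(K)]), 0)
g_direct = torch.autograd.grad(log_marginal, [prior_logits, emit_logits])

# EG: the E step gives the exact posterior (held fixed), the G step takes the
# posterior-weighted expected gradient of log P(z, y).
posterior = torch.softmax(
    torch.stack([log_joint(z) for z in range(K)]), 0).detach()
g_eg = [torch.zeros_like(prior_logits), torch.zeros_like(emit_logits)]
for z in range(K):
    gz = torch.autograd.grad(log_joint(z), [prior_logits, emit_logits])
    g_eg = [a + posterior[z] * b for a, b in zip(g_eg, gz)]

print(torch.allclose(g_direct[0], g_eg[0], atol=1e-6),
      torch.allclose(g_direct[1], g_eg[1], atol=1e-6))
```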
Example: A Deep HMM Language Model: A deep language model is typically an auto-regressive RNN.[1]

[Figure: an auto-regressive RNN language model; image credit Andrej Karpathy]
Here the model stochastically generates a word and then uses the generated word as input in computing the next hidden state. In contrast to a classical finite-state HMM, the hidden state is determined by the observable word sequence. In a classical HMM many different hidden state sequences can generate the same word sequence; the hidden states are latent. We can introduce discrete latent hidden states such that the full hidden state is a pair $(h_t, s_t)$ where $h_t$ is an RNN state vector determined by the observable words and $s_t$ is a latent discrete state not determined by the observable sequence. We allow the discrete hidden state transition probabilities to depend on the hidden vector of the RNN. In equations, the RNN state is updated as $h_{t+1} = \mathrm{RNNcell}_\Phi(h_t,\, e_\Phi(w_t))$, the next discrete state is drawn from $P_\Phi(s_{t+1} \mid h_{t+1}, s_t)$, and the next word is drawn from $P_\Phi(w_{t+1} \mid h_{t+1}, s_{t+1})$. Here $\Phi$ is the system of network parameters and $e_\Phi(w)$ and $e_\Phi(s)$ are vector embeddings of the word $w$ and discrete state $s$. This model allows $P_\Phi(w_1, \ldots, w_T)$ to be computed by dynamic programming over the possible hidden state sequences using the forward-backward procedure. The gradient of this probability can also be computed in closed form and we can do SGD without sampling the discrete hidden states. For dynamic programming to work it is important that the state vector $h_t$ is determined by the initial state and the observed words $w_1, \ldots, w_{t-1}$ and is not influenced by the hidden state sequence $s_1, \ldots, s_t$.
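A sketch of the forward computation for such a deep HMM follows. The parameterization (linear scores from $h_t$, a tanh cell) is my own toy stand-in, not the post's; the point is only that, because $h_t$ never depends on the discrete states, the sum over state sequences reduces to a standard trellis recursion.

```python
# Forward algorithm for a deep HMM language model with toy parameters standing in for Phi.
import numpy as np

rng = np.random.default_rng(0)
K, H, V, E = 4, 16, 10, 8          # discrete states, RNN width, vocabulary, embedding dim

W_h = rng.normal(size=(H, H + E)) * 0.1   # toy RNN cell weights
emb_w = rng.normal(size=(V, E))           # word embeddings e(w)
W_trans = rng.normal(size=(K, K, H))      # scores for s -> s' given h
W_emit = rng.normal(size=(K, V, H))       # scores for emitting w given s', h

def log_softmax(a, axis):
    a = a - a.max(axis=axis, keepdims=True)
    return a - np.log(np.exp(a).sum(axis=axis, keepdims=True))

def log_prob_words(words):
    """log P(w_1..w_T) summed over all discrete state sequences (forward algorithm)."""
    h = np.zeros(H)
    log_alpha = np.full(K, -np.log(K))            # uniform over the initial discrete state
    for w in words:
        trans = log_softmax(W_trans @ h, axis=1)  # [K, K]: log P(s' | s, h)
        emit = log_softmax(W_emit @ h, axis=1)    # [K, V]: log P(w | s', h)
        # alpha'(s') = sum_s alpha(s) P(s' | s, h) P(w | s', h)
        scores = log_alpha[:, None] + trans + emit[None, :, w]
        log_alpha = np.logaddexp.reduce(scores, axis=0)
        # The RNN state is updated from the observed word only, never from s.
        h = np.tanh(W_h @ np.concatenate([h, emb_w[w]]))
    return np.logaddexp.reduce(log_alpha)

print(log_prob_words([1, 3, 2, 0]))
```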
The general case: The general EG algorithm alternates an E step, which is an exact optimization as in EM, setting $\hat{P}_\Psi(z \mid x, y) = P_\Phi(z \mid x, y)$, and a G step in which we take a gradient step in $\Phi$ defined by computing the gradient of (1) (the entropy term of (1) does not depend on $\Phi$):

$$\Delta\Phi \;\propto\; \nabla_\Phi\; E_{z \sim \hat{P}_\Psi(z \mid x, y)}\, \ln P_\Phi(z, y \mid x) \qquad (3)$$

By (2), and the fact that the gradient of the KL divergence is zero when $\hat{P}_\Psi(z \mid x, y) = P_\Phi(z \mid x, y)$, for the G step we also have

$$\nabla_\Phi \left[\ln P_\Phi(y \mid x) \;-\; KL\!\left(\hat{P}_\Psi(z \mid x, y),\, P_\Phi(z \mid x, y)\right)\right] \;=\; \nabla_\Phi\, \ln P_\Phi(y \mid x) \qquad (4)$$

Combining (3) and (4), together with the equivalence of (1) and (2), we see that the gradient computed in the G step equals the gradient of the top-level objective $\ln P_\Phi(y \mid x)$.
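The vanishing KL gradient is worth spelling out: the encoder distribution $\hat{P}$ is held fixed while differentiating with respect to $\Phi$, so at the E-step optimum $\hat{P} = P_\Phi(z \mid x, y)$ we have

$$\nabla_\Phi\, KL\!\left(\hat{P},\, P_\Phi(z \mid x, y)\right) \;=\; -\,E_{z \sim \hat{P}}\, \nabla_\Phi \ln P_\Phi(z \mid x, y) \;=\; -\sum_z P_\Phi(z \mid x, y)\, \frac{\nabla_\Phi P_\Phi(z \mid x, y)}{P_\Phi(z \mid x, y)} \;=\; -\,\nabla_\Phi \sum_z P_\Phi(z \mid x, y) \;=\; 0.$$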
We now consider the case where $P_\Phi(z, y \mid x)$ is defined by a generative process in which the pair $(z, y)$ is generated stochastically by a series of discrete decisions. Each decision is assumed to be a draw from a certain probability distribution. I will write $c$ for the decision (event) of drawing a particular value from a particular distribution. Every decision made in the above HMM has the form of drawing $s_{t+1} = s'$ from $P_\Phi(s_{t+1} \mid h_{t+1},\, s_t = s)$, where $s$ and $s'$ are particular discrete hidden states. For a sequence of length $T$ and for $K$ discrete hidden states we have a total of $T K^2$ possible choice events, and these choice events form the edges of an HMM trellis. Note that it is important that $h_{t+1}$ does not depend on the path through the trellis.
But getting back to the general case, let $c_1, \ldots, c_n$ be the choices made in generating $(z, y)$ under the model $P_\Phi(z, y \mid x)$. We then have

$$P_\Phi(z, y \mid x) \;=\; \prod_{i=1}^{n} P_\Phi(c_i).$$

We now have

$$\nabla_\Phi\; E_{z \sim \hat{P}_\Psi(z \mid x, y)}\, \ln P_\Phi(z, y \mid x) \;=\; E_{z \sim \hat{P}_\Psi(z \mid x, y)} \sum_{i=1}^{n} \nabla_\Phi \ln P_\Phi(c_i) \qquad (5)$$
Now let $\mathcal{C}$ be the set of all possible choice events. In the HMM example $\mathcal{C}$ is the set of all edges in the HMM trellis while $c_1, \ldots, c_n$ is one particular path through the trellis. Also let $\mathbb{1}[c]$ be the indicator function that the choice $c$ occurs when generating $(z, y)$ from $x$. We can now rewrite (5) as

$$\nabla_\Phi\; E_{z \sim \hat{P}_\Psi(z \mid x, y)}\, \ln P_\Phi(z, y \mid x) \;=\; \sum_{c \in \mathcal{C}} E_{z \sim \hat{P}_\Psi(z \mid x, y)}\big[\mathbb{1}[c]\big]\; \nabla_\Phi \ln P_\Phi(c) \qquad (6)$$

The expectations $E_{z \sim \hat{P}_\Psi(z \mid x, y)}[\mathbb{1}[c]]$ are assumed to be computable by dynamic programming. In the HMM example these are the edge probabilities in the HMM lattice.
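To make the expectations in (6) concrete, here is a sketch (my own toy setup, with fixed edge log-weights standing in for the neural scores) that computes the posterior edge probabilities of a trellis by forward-backward and checks them against brute-force enumeration over all state sequences.

```python
# Posterior edge marginals of a trellis: edge_marginal[t, s, s'] is the
# probability that the latent path uses the edge s -> s' at step t.
import itertools
import numpy as np

rng = np.random.default_rng(0)
T, K = 4, 3
# log_w[t, s, s'] = log weight of the edge s -> s' at step t
# (in the HMM example this is log P(s'|h, s) + log P(w_{t+1}|h, s')).
log_w = rng.normal(size=(T, K, K))
log_init = np.full(K, -np.log(K))

# Forward and backward passes over the trellis.
alpha = np.zeros((T + 1, K))
alpha[0] = log_init
for t in range(T):
    alpha[t + 1] = np.logaddexp.reduce(alpha[t][:, None] + log_w[t], axis=0)
beta = np.zeros((T + 1, K))                      # beta[T] = log 1
for t in reversed(range(T)):
    beta[t] = np.logaddexp.reduce(log_w[t] + beta[t + 1][None, :], axis=1)
log_Z = np.logaddexp.reduce(alpha[T])

# These are the expectations E[1[c]] appearing in (6).
edge_marginal = np.exp(alpha[:T, :, None] + log_w + beta[1:, None, :] - log_Z)

# Brute-force check by enumerating all K^(T+1) state sequences.
brute = np.zeros((T, K, K))
total = 0.0
for path in itertools.product(range(K), repeat=T + 1):
    lp = log_init[path[0]] + sum(log_w[t, path[t], path[t + 1]] for t in range(T))
    total += np.exp(lp)
    for t in range(T):
        brute[t, path[t], path[t + 1]] += np.exp(lp)
print(np.allclose(edge_marginal, brute / total))
```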
CTC: In CTC we are given a sequence $x_1, \ldots, x_T$ (typically acoustic feature vectors sampled at 100 frames per second) and we must output a sequence of symbols $y_1, \ldots, y_N$ (typically letters or phonetic symbols) with $N \leq T$ and typically $N \ll T$. More precisely, for network parameters $\Phi$ the model defines $P_\Phi(y_1, \ldots, y_N \mid x_1, \ldots, x_T)$. However, in the model there is a latent sequence $z_1, \ldots, z_T$ where each $z_t$ is either an output symbol or the special blank symbol $\bot$. The sequence $y_1, \ldots, y_N$ is derived from $z_1, \ldots, z_T$ by removing the occurrences of $\bot$. The model can be written as

$$h_t \;=\; \mathrm{RNNcell}_\Phi(h_{t-1},\, x_t)$$

$$P_\Phi(z_1, \ldots, z_T \mid x_1, \ldots, x_T) \;=\; \prod_{t=1}^{T} P_\Phi(z_t \mid h_t)$$

$$P_\Phi(y_1, \ldots, y_N \mid x_1, \ldots, x_T) \;=\; \sum_{z_1, \ldots, z_T\ \text{reducing to}\ y_1, \ldots, y_N} P_\Phi(z_1, \ldots, z_T \mid x_1, \ldots, x_T)$$

We can use the EG algorithm to compute the gradient of $\ln P_\Phi(y_1, \ldots, y_N \mid x_1, \ldots, x_T)$, where EG sums over all possible sequences $z_1, \ldots, z_T$ that reduce to $y_1, \ldots, y_N$. I leave it to the reader to work out the details of the dynamic programming table. For the dynamic programming to work it is important that $h_t$ does not depend on the latent sequence $z_1, \ldots, z_{t-1}$.
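One possible sketch of that dynamic programming table follows (my own, under the blank-removal reduction described above, not Graves et al.'s exact formulation, which also collapses repeated symbols): alpha[n] tracks the total log probability of all length-$t$ prefixes of $z$ that reduce to $y_1, \ldots, y_n$.

```python
# Forward DP for log P(y | x) in the simplified CTC model above.
# log_probs[t, v] = log P(z_t = v | h_t), with index 0 reserved for blank.
import numpy as np

def ctc_log_likelihood(log_probs, y):
    """log P(y | x) = log sum over all z of length T that reduce to y."""
    T, V = log_probs.shape
    N = len(y)
    alpha = np.full(N + 1, -np.inf)   # alpha[n]: z_1..z_t reduces to y_1..y_n
    alpha[0] = 0.0
    for t in range(T):
        new_alpha = np.full(N + 1, -np.inf)
        for n in range(N + 1):
            if alpha[n] == -np.inf:
                continue
            # emit blank: the reduced prefix stays at length n
            new_alpha[n] = np.logaddexp(new_alpha[n], alpha[n] + log_probs[t, 0])
            # emit the next output symbol y_{n+1}: the reduced prefix grows
            if n < N:
                new_alpha[n + 1] = np.logaddexp(
                    new_alpha[n + 1], alpha[n] + log_probs[t, y[n]])
        alpha = new_alpha
    return alpha[N]

# Toy usage: T = 6 frames, 4 output symbols plus blank (index 0).
rng = np.random.default_rng(0)
scores = rng.normal(size=(6, 5))
log_probs = scores - np.logaddexp.reduce(scores, axis=1, keepdims=True)
print(ctc_log_likelihood(log_probs, [2, 4, 1]))
```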
Footnotes:
[1] A week after this post was originally published, Gu et al. published a non-autoregressive language model.