## CTC and the EG Algorithm: Discrete Latent Choices without Reinforcement Learning

Section 20.9.1 of Goodfellow, Bengio and Courville is titled “Back-Propagating through Discrete Stochastic Operations”.  It discusses the REINFORCE algorithm, which requires sampling sequences of discrete latent choices.  It is well known that sampling choices yields a high variance gradient estimate.  While such sampling seems necessary for reinforcement learning applications (such as games), there are other settings where one can train discrete latent choices using a closed form gradient and avoid high variance choice sampling.  This is done in the CTC algorithm (Connectionist Temporal Classification, Graves et al. 2006).  CTC is arguably the currently dominant approach to speech recognition.  It is mentioned in Goodfellow et al. in the section on speech recognition but not in the general context of training discrete latent choices.

This post is also a sequel to the previous post on variational auto-encoders and EM.  In CTC, as in various other models, we are interested in a distribution obtained by marginalizing over latent variables.  In this case we have training pairs $(x, y)$ and are interested in maximizing the marginal log-likelihood of $y$ given $x$:

$\Theta^* = \mathrm{argmax}_\Theta\; E_{(x,y)}\;\ln P_\Theta(y|x)$

$= \mathrm{argmax}_\Theta \; E_{(x,y)}\;\ln\; \sum_z \;P_\Theta(y,z|x)$

In CTC $x$ is a speech signal, $y$ is a (time-compressed) phonetic transcription, and $z$ is a sequence of latent “emission decisions”.  This is described in more detail below.

The previous post noted that both VAEs and EM can be phrased in terms of a unified objective function involving an “encoder distribution” $P_\Psi(z|x,y)$.  Here we have made everything conditional on an input $x$.

$\Psi^*,\Theta^* = \mathrm{argmax}_{\Psi,\Theta}\;E_{(x,y)}\;{\cal L}(\Psi,\Theta,x,y)$

${\cal L}(\Psi,\Theta,x,y) = E_{z \sim P_\Psi(z|x,y)}\; \ln P_\Theta(z,y|x) + H(P_\Psi(z|x,y))\;\;\;(1)$

$= \ln\;P_{\Theta}(y|x) - KL(P_\Psi(z|x,y),P_{\Theta}(z|x,y))\;\;\;(2)$

The equivalence between (1) and (2) is derived in the previous post.  In EM one performs alternating optimization where (2) is used to optimize the encoder distribution $P_\Psi(z|x,y)$ while holding the model distribution $P_\Theta(y,z|x)$ fixed, and (1) is used to optimize the model distribution $P_\Theta(y,z|x)$ while holding the encoder distribution $P_\Psi(z|x,y)$ fixed.  In EM each of these optimizations is done exactly in closed form, yielding an exact alternating maximization algorithm.  In CTC the E step optimization is computed exactly in closed form, but the M step is replaced by an exact closed form calculation of the gradient of (1) with respect to $\Theta$.  I will call this the EG algorithm (expected gradient).
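As a quick sanity check on the equivalence of (1) and (2), here is a minimal numerical sketch for a single fixed pair $(x,y)$ with three latent values.  The probability tables and the names `joint`, `encoder`, `objective_1` and `objective_2` are made up for illustration and are not from any library.

```python
import math

# Hypothetical toy values of P_Theta(y, z | x) for one fixed (x, y) and three latent z.
joint = {"z0": 0.10, "z1": 0.25, "z2": 0.05}
marginal = sum(joint.values())                            # P_Theta(y | x)
posterior = {z: p / marginal for z, p in joint.items()}   # P_Theta(z | x, y)

# An arbitrary (hypothetical) encoder distribution P_Psi(z | x, y).
encoder = {"z0": 0.5, "z1": 0.3, "z2": 0.2}

def objective_1(q):
    """Form (1): expected log joint under q plus the entropy of q."""
    expected_log_joint = sum(q[z] * math.log(joint[z]) for z in q)
    entropy = -sum(q[z] * math.log(q[z]) for z in q)
    return expected_log_joint + entropy

def objective_2(q):
    """Form (2): log marginal minus KL(q, posterior)."""
    kl = sum(q[z] * math.log(q[z] / posterior[z]) for z in q)
    return math.log(marginal) - kl

# The two forms agree for any encoder distribution.
print(objective_1(encoder), objective_2(encoder))
# With the encoder set to the exact posterior (the E step) the KL term vanishes.
print(objective_1(posterior), math.log(marginal))
```

The second print illustrates why the E step matters: once the encoder equals the model posterior, the objective coincides with $\ln P_\Theta(y|x)$.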

Example: A Deep HMM Language Model: A deep language model is typically an auto-regressive RNN (see footnote 1).

[Figure: an auto-regressive RNN language model.  Image credit: Andrej Karpathy]

Here the model stochastically generates a word and then uses the generated word as input in computing the next hidden state.  In contrast to a classical finite-state HMM, the hidden state is determined by the observable word sequence.  In a classical HMM many different hidden state sequences can generate the same word sequence — the hidden states are latent.  We can introduce discrete latent hidden states $H_0, \ldots, H_T$ such that the full hidden state is a pair $(h_t, H_t)$ where $h_t$ is an RNN state vector determined by observable words and $H_t$ is a latent discrete state not determined by the observable sequence.  We allow the discrete hidden state transition probabilities to depend on the hidden vector of an RNN.  In equations we have

$h_{t+1} = \mathrm{RNNcell}_\Theta(h_t, e(w_t))$

$P_\Theta(H_{t+1} \;|\;w_1,\ldots,\;w_t,\;H_1,\ldots,\;H_t) = P_\Theta(H_{t+1}\;|\;h_{t+1},e(H_t))$

$P_\Theta(w_{t+1} \;|w_1,\ldots,w_t,\;H_1,\ldots,\;H_t) = P_\Theta(w_{t+1}\;|\;h_{t+1},e(H_{t+1})).$

Here $\Theta$ is the system of network parameters and $e(w)$ and $e(H)$ are vector embeddings of the word $w$ and discrete state $H$. This model allows $P_\Theta(w_1,\ldots,w_T)$ to be computed by dynamic programming over the possible hidden state sequences using the forward-backward procedure.  The gradient of the probability can also be computed in closed form and we can do SGD without sampling the discrete hidden states.  For dynamic programming to work it is important that the state vector $h_t$ is determined by $h_0$ and the observed words $w_1,\ldots,w_{t-1}$ and is not influenced by the hidden state sequence $H_1,\ldots,H_{t-1}$.
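The dynamic program can be sketched as follows.  Because $h_t$ depends only on the observed words, for a fixed word sequence the network outputs reduce to time-indexed transition and emission tables.  The tables below (`init`, `trans`, `emit`) are hypothetical stand-ins for those network outputs, and the brute-force function is included only to check the forward recursion on a tiny case.

```python
import itertools

def forward_marginal(init, trans, emit):
    """Sum over all hidden state paths using the forward recursion.

    init[j]        : probability of the first discrete hidden state j
    trans[t][i][j] : probability of moving from state i at step t to state j at step t+1
    emit[t][j]     : probability of the observed word at step t given state j
    The tables are indexed by time because they are functions of h_t.
    """
    n = len(init)
    alpha = [init[j] * emit[0][j] for j in range(n)]
    for t in range(1, len(emit)):
        alpha = [
            sum(alpha[i] * trans[t - 1][i][j] for i in range(n)) * emit[t][j]
            for j in range(n)
        ]
    return sum(alpha)

def brute_force_marginal(init, trans, emit):
    """Enumerate every hidden path explicitly (exponential; for checking only)."""
    n, T = len(init), len(emit)
    total = 0.0
    for path in itertools.product(range(n), repeat=T):
        p = init[path[0]] * emit[0][path[0]]
        for t in range(1, T):
            p *= trans[t - 1][path[t - 1]][path[t]] * emit[t][path[t]]
        total += p
    return total

# Hypothetical tables for n = 2 hidden states and T = 3 observed words.
init = [0.6, 0.4]
trans = [[[0.7, 0.3], [0.2, 0.8]], [[0.5, 0.5], [0.9, 0.1]]]
emit = [[0.1, 0.4], [0.3, 0.2], [0.5, 0.6]]

print(forward_marginal(init, trans, emit), brute_force_marginal(init, trans, emit))
```

The forward recursion costs on the order of $Tn^2$ operations, while the enumeration costs $n^T$.  If the tables depended on the hidden path, the recursion would no longer be valid.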

The general case: The general EG algorithm alternates an E step, which is an exact optimization as in EM, and a G step in which we take a gradient step in $\Theta$ defined by computing the gradient of (1):

$\nabla_\Theta {\cal L}(\Psi,\Theta,x,y) = E_{z \sim P_\Psi(z|x,y)}\; \nabla_\Theta\;\ln\;P_\Theta(y,z|x)\;\;\;\;\;(3)$

By (2), and the fact that the gradient of the KL divergence with respect to $\Theta$ is zero when $P_\Psi(z|x,y) = P_\Theta(z|x,y)$, which is exactly what the E step achieves, for the G step we also have

$\nabla_\Theta {\cal L}(\Psi, \Theta,x,y) = \nabla_\Theta \ln P_\Theta(y|x)\;\;\;(4)$

Combining (3) and (4) we see that the gradient defined by (3) in the G step equals the gradient of the top level objective for the training pair $(x,y)$; averaging over training pairs gives the gradient of the full objective.
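The identity between the expected gradient and the gradient of the log marginal can be checked numerically on a toy model.  In this sketch, which is an assumption for illustration only, $\Theta$ is a vector of logits defining a softmax prior over three latent values, and `LIK` is a fixed, made-up table of $P(y|z)$.

```python
import math

def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    s = sum(exps)
    return [e / s for e in exps]

LIK = [0.9, 0.2, 0.5]  # hypothetical fixed P(y | z) for the observed y

def log_marginal(theta):
    """ln P_Theta(y | x) = ln sum_z softmax(theta)[z] * P(y | z)."""
    prior = softmax(theta)
    return math.log(sum(p * l for p, l in zip(prior, LIK)))

def expected_gradient(theta):
    """Expectation over the exact posterior of the gradient of the log joint."""
    prior = softmax(theta)
    joint = [p * l for p, l in zip(prior, LIK)]
    total = sum(joint)
    posterior = [j / total for j in joint]
    # d/d theta_k of ln P(y, z) is 1[z == k] - prior[k]; average over z ~ posterior.
    return [posterior[k] - prior[k] for k in range(len(theta))]

def finite_difference_gradient(theta, eps=1e-6):
    """Central differences on the log marginal, used only as a reference."""
    grad = []
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += eps
        down[k] -= eps
        grad.append((log_marginal(up) - log_marginal(down)) / (2 * eps))
    return grad

theta = [0.3, -1.2, 0.8]
print(expected_gradient(theta))
print(finite_difference_gradient(theta))
```

The two printed vectors should agree up to finite-difference error, and no latent value is ever sampled.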

We now consider the case where $P_\Theta(z,y|x)$ is defined by a generative process where the pair $(z,y)$ is generated stochastically by a series of discrete decisions.  Each decision is assumed to be a draw from a certain probability distribution.  I will write $Q \leadsto j$ for the decision (event) of drawing $j$ from distribution $Q$.  Every decision made in the above HMM has the form

$P_\Theta(H|h_{t+1},e(H^i)) \leadsto H^j$

where $H^i$ and $H^j$ are particular discrete hidden states.  For a sequence of length $T$ and for $n$ discrete hidden states we have a total of $Tn^2$ possible choice events and these choice events form the edges of an HMM trellis.  Note that it is important that $h_t$ does not depend on the path through the trellis.

But getting back to the general case, let ${\cal C}(z,y|x)$ be the choices made in generating $(z,y)$ under the model $P_\Theta(z,y|x)$.  We then have

$\ln P_\Theta(z,y|x) = \sum_{(Q \leadsto j) \in {\cal C}(z,y|x)}\;\ln P(j\;|\;Q).$

We now have

$\nabla_\Theta \ln P_\Theta (y|x) = E_{z \sim P_\Theta(z|x,y)}\;\nabla_\Theta \;\sum_{(Q \leadsto j) \in {\cal C}(y,z|x)} \;\; \ln P(j\;|\;Q)\;\;\;\;(5)$

Now let ${\cal C}(y|x) = \bigcup_z \;{\cal C}(y,z|x)$.  In the HMM example ${\cal C}(y|x)$ is the set of all edges in the HMM trellis while ${\cal C}(z,y|x)$ is one particular path through the trellis.  Also let $1((Q\leadsto j) \in {\cal C}(z,y|x))$ be the indicator function that the choice $Q \leadsto j$ occurs when generating $(z,y)$ from $x$. We can now rewrite (5) as

$\nabla_\Theta \ln P_\Theta(y|x) = \sum_{(Q \leadsto j) \in{\cal C}(y|x)} \left(E_{z \sim P_\Theta(z|x,y)} \; 1((Q \leadsto j)\in{\cal C}(y,z|x))\right)\; \nabla_\Theta\;\ln P_\Theta(j|Q)$

The expectations $E_z 1((Q \leadsto j) \in {\cal C}(y,z|x))$ are assumed to be computable by dynamic programming.  In the HMM example these are the posterior edge probabilities in the HMM trellis.
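For the HMM example, the edge expectations can be sketched with the standard forward-backward recursions.  As before, the tables `init`, `trans` and `emit` are hypothetical stand-ins for network outputs computed from the observed words, and `edge_posteriors` is a name chosen for this illustration.

```python
def edge_posteriors(init, trans, emit):
    """Posterior probability of each trellis edge given the observed words.

    Returns post[t][i][j] = P(state i at step t, state j at step t+1 | words).
    """
    n, T = len(init), len(emit)
    # alpha[t][j]: probability of the words up to step t with state j at step t.
    alpha = [[init[j] * emit[0][j] for j in range(n)]]
    for t in range(1, T):
        alpha.append([
            sum(alpha[t - 1][i] * trans[t - 1][i][j] for i in range(n)) * emit[t][j]
            for j in range(n)
        ])
    # beta[t][i]: probability of the words after step t given state i at step t.
    beta = [[1.0] * n for _ in range(T)]
    for t in range(T - 2, -1, -1):
        beta[t] = [
            sum(trans[t][i][j] * emit[t + 1][j] * beta[t + 1][j] for j in range(n))
            for i in range(n)
        ]
    z = sum(alpha[T - 1])  # marginal probability of the observed words
    return [
        [[alpha[t][i] * trans[t][i][j] * emit[t + 1][j] * beta[t + 1][j] / z
          for j in range(n)] for i in range(n)]
        for t in range(T - 1)
    ]

# Hypothetical tables for n = 2 hidden states and T = 3 observed words.
init = [0.6, 0.4]
trans = [[[0.7, 0.3], [0.2, 0.8]], [[0.5, 0.5], [0.9, 0.1]]]
emit = [[0.1, 0.4], [0.3, 0.2], [0.5, 0.6]]

posteriors = edge_posteriors(init, trans, emit)
for t, layer in enumerate(posteriors):
    print(t, layer)
```

Each layer of the result sums to one, since exactly one edge is traversed between consecutive steps.  These weights are what multiply $\nabla_\Theta \ln P_\Theta(j|Q)$ in the sum over trellis edges.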

CTC: In CTC we are given a sequence $x_1, \ldots, x_T$ (typically acoustic feature vectors sampled at 100 frames per second) and we must output a sequence of symbols (typically letters or phonetic symbols) $y_1, \ldots, y_N$ with $N \leq T$ and typically $N \ll T$.  More precisely, for network parameters $\Theta$ the model defines $P_\Theta(y_1,\ldots,y_N|x_1,\ldots,x_T)$.  However, in the model there is a latent sequence $\hat{y}_1,\ldots,\hat{y}_T$ where each $\hat{y}_t$ is either an output symbol or the special symbol $\bot$.  The sequence $y_1,\ldots,y_N$ is derived from $\hat{y}_1,\ldots,\hat{y}_T$ by removing the occurrences of $\bot$. The model can be written as

$h_{t+1} = \mathrm{RNNcell}_\Theta(h_t, x_t)$

$P_\Theta(\hat{y}_{t+1}|x_1,\ldots,x_t) = P_\Theta(\hat{y}_{t+1}|h_{t+1})$

We can use the EG algorithm to compute the gradient of $\ln P_\Theta(y_1,\ldots,y_N|x_1,\ldots,x_T)$ where EG sums over all possible sequences $\hat{y}_1,\ldots,\hat{y}_T$ that reduce to $y_1,\ldots,y_N$.  I leave it to the reader to work out the details of the dynamic programming table. For the dynamic programming to work it is important that $h_t$ does not depend on the latent sequence $\hat{y}_1,\ldots,\hat{y}_T$.
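One possible shape for the dynamic programming table is sketched here, under the simplified reduction described in this post where only the blank symbol is removed.  The original CTC formulation of Graves et al. also collapses repeated symbols, which this sketch does not model.  The per-frame distributions in `frame_probs` are hypothetical stand-ins for $P_\Theta(\hat{y}_t|h_t)$, and `"_"` stands in for $\bot$.

```python
import itertools

BLANK = "_"  # stands in for the special symbol (bottom)

def ctc_marginal(frame_probs, target):
    """Probability that the latent frame sequence reduces to target.

    frame_probs[t][s] is the probability of emitting symbol s at frame t.
    """
    N = len(target)
    # alpha[k]: probability that the frames seen so far emitted exactly target[:k].
    alpha = [1.0] + [0.0] * N
    for probs in frame_probs:
        new_alpha = [alpha[0] * probs[BLANK]]
        for k in range(1, N + 1):
            stay = alpha[k] * probs[BLANK]                   # frame emits a blank
            advance = alpha[k - 1] * probs[target[k - 1]]    # frame emits target[k-1]
            new_alpha.append(stay + advance)
        alpha = new_alpha
    return alpha[N]

def brute_force_marginal(frame_probs, target):
    """Enumerate all latent sequences (exponential; for checking only)."""
    symbols = list(frame_probs[0])
    total = 0.0
    for latent in itertools.product(symbols, repeat=len(frame_probs)):
        if [s for s in latent if s != BLANK] == list(target):
            p = 1.0
            for probs, s in zip(frame_probs, latent):
                p *= probs[s]
            total += p
    return total

# Hypothetical per-frame output distributions for T = 4 frames.
frame_probs = [
    {"_": 0.5, "a": 0.3, "b": 0.2},
    {"_": 0.2, "a": 0.6, "b": 0.2},
    {"_": 0.4, "a": 0.1, "b": 0.5},
    {"_": 0.7, "a": 0.1, "b": 0.2},
]
target = ["a", "b"]

print(ctc_marginal(frame_probs, target), brute_force_marginal(frame_probs, target))
```

The table has one entry per (frame, prefix length) pair, so the cost is on the order of $TN$.  A matching backward pass would give the posterior probability of each emission decision, which is what the EG gradient needs.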

Footnotes:

1. A week after this post was originally published, Gu et al. published a non-autoregressive language model.