Prerequisites
You should be able to understand this image or this image or this image.
Introduction
In this blog post, we review the original attention mechanism as published in Bahdanau et al.
Let’s start with some background on where the attention mechanism is mainly used: sequence-to-sequence (seq2seq) models. A classic example of a seq2seq model is Neural Machine Translation (NMT), so let’s review it first.
Neural Machine Translation
NMT has two components:
- Encoder: The encoder, usually an RNN or a stack of RNN layers, encodes the source-language sentence into a fixed-length vector: the last hidden state of the RNN for the given input sequence.
- Decoder: The decoder takes the last encoder hidden state as input and predicts the target-language sentence word by word, autoregressively (the previously predicted word is used to predict the current word).
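To make the fixed-length summary concrete, here is a minimal sketch of a plain encoder in Python. It assumes a simple (Elman) RNN cell, $h_t = \tanh(W x_t + U h_{t-1})$, with made-up 2-dimensional weights; real NMT encoders typically use gated cells such as GRUs or LSTMs, and the names `rnn_encode`, `W` and `U` are illustrative only.

```python
import math

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rnn_encode(xs, W, U):
    """Run a simple RNN over the input vectors xs and return every hidden state.

    Assumed cell: h_t = tanh(W x_t + U h_{t-1}), starting from h_0 = 0.
    """
    h = [0.0] * len(U)
    states = []
    for x in xs:
        h = [math.tanh(a + b) for a, b in zip(matvec(W, x), matvec(U, h))]
        states.append(h)
    return states

# Hypothetical weights and one-hot "word" vectors, for illustration only.
W = [[0.5, -0.2], [0.1, 0.4]]
U = [[0.3, 0.0], [0.0, 0.3]]
short_sentence = [[1.0, 0.0], [0.0, 1.0]]
long_sentence = short_sentence * 10   # 20 tokens

for sentence in (short_sentence, long_sentence):
    states = rnn_encode(sentence, W, U)
    # A basic encoder-decoder passes only states[-1] to the decoder.
    print(len(sentence), "tokens ->", len(states[-1]), "numbers")
```

Whether the input has 2 tokens or 20, the decoder receives the same 2 numbers: the last hidden state is the only summary it has to work from.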
The bottleneck with such a system is that the decoder has to generate the entire target-language sentence from the last encoder hidden state alone. Because this fixed-length vector must summarize the whole source sentence, however long it is, the decoder struggles to produce long sentences in the target language.
What if there was a way to supplement the last encoder hidden state with more information, to make decoding a little easier?
Enter attention!!
Learning to Align and Translate: The Attention Model
We saw in the previous section that decoding the whole output sequence from a single fixed-length vector becomes problematic when sentences are long. We can give the decoder more context by taking, at every decoding step, a dynamically weighted average of all the encoder hidden states. This technique is known as “attention” and works as follows.
Assume that $(h_1, h_2, \dots, h_{T_x})$ are the hidden states of the encoder, one per input token.
Let the decoder output be given by a generic function,
\begin{equation} y_{i} = g(y_{i-1}, s_i, c_i) \end{equation}
where
- $y_{i-1}$ is the output predicted by the decoder at the previous step,
- $s_i$ is the current hidden state of the decoder (itself computed from $s_{i-1}$, $y_{i-1}$ and $c_i$),
- $c_i$ is the dynamic context vector.
The dynamic context vector $c_i$ is a weighted average of all the encoder hidden states, with weights produced by a learned function:
\begin{equation} c_i = \sum_{j=1}^{T_{x}}\alpha_{ij}h_j. \end{equation}
where the weight $\alpha_{ij}$ of the hidden state $h_j$ is given by
\begin{equation} \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x}\exp(e_{ik})} \end{equation}
where \begin{equation} e_{ij} = a(s_{i-1}, h_j) \end{equation}
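The softmax and the weighted sum above can be sketched in plain Python. The scores `e_i` and states `hs` are made-up values for a single decoding step $i$, and the function names are illustrative. Subtracting the maximum score before exponentiating is a standard numerical-stability trick that leaves the $\alpha_{ij}$ unchanged, because the shift cancels in the ratio.

```python
import math

def softmax(scores):
    """alpha_ij = exp(e_ij) / sum_k exp(e_ik), over the scores of one step i."""
    m = max(scores)                            # shift by the max to avoid overflow
    exps = [math.exp(e - m) for e in scores]   # the shift cancels in the ratio
    total = sum(exps)
    return [x / total for x in exps]

def context_vector(alphas, hs):
    """c_i = sum_j alpha_ij * h_j, where each h_j is a list of floats."""
    dim = len(hs[0])
    return [sum(a * h[k] for a, h in zip(alphas, hs)) for k in range(dim)]

# Hypothetical scores e_i1..e_i3 and encoder states h_1..h_3 for one step i.
e_i = [2.0, 1.0, 0.1]
hs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

alphas = softmax(e_i)
print([round(a, 3) for a in alphas])   # [0.659, 0.242, 0.099]
c_i = context_vector(alphas, hs)
print([round(x, 3) for x in c_i])      # [0.758, 0.341]
```

Because the weights are positive and sum to 1, $c_i$ is a convex combination of the encoder states; with a one-hot weight vector such as `[0.0, 1.0, 0.0]` it simply returns $h_2$.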
The function $a$ is the “attention” model (the paper calls it an alignment model): it scores how well the input around position $j$ matches the output at position $i$.
One can think of $e_{ij}$ as the unnormalized score given to the encoder hidden state $h_j$ when computing the context vector $c_i$ as a linear combination of all the hidden states. The context vector is then used to obtain the decoder output at the $i^{th}$ step.
So $a$ is a learned function of $s_{i-1}$ and $h_j$ that decides how much weight to assign to the encoder hidden state $h_j$ when decoding $y_i$.
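In Bahdanau et al., $a$ is parametrized as a small feedforward network trained jointly with the encoder and decoder. The paper's appendix gives it the additive form
\begin{equation} e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh(W_a s_{i-1} + U_a h_j) \end{equation}
where the matrices $W_a$, $U_a$ and the vector $v_a$ are learned parameters. Since $U_a h_j$ does not depend on $i$, it can be computed once per input sentence and reused at every decoding step.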
Detailed Explanation of Attention for NMT
We have an input sequence $(x_1, x_2, \dots, x_n)$ with corresponding encoder hidden states $(h_1, h_2, \dots, h_n)$, so $n = T_x$ in the notation above. Let $s_i$ be the decoder hidden state at the $i^{th}$ decoding step. The context vector $c_i$ used to predict the word $y_i$ is obtained as follows:
- $e_{ij} = a(s_{i-1}, h_j)$, where $h_j$ is the encoder hidden state for the $j^{th}$ input and $s_{i-1}$ is the decoder hidden state from the previous decoding step. The previous state is used because $s_i$ itself depends on $c_i$.
- For the first decoding step ($i = 1$), we get $e_{11}, e_{12}, \dots, e_{1n}$ from $a(s_0, h_1), a(s_0, h_2), \dots, a(s_0, h_n)$ respectively, where $s_0$ is the initial decoder state.
- These $e_{ij}$ are unnormalized scores. Apply a softmax to get normalized weights: $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{n}\exp(e_{ik})}$.
- To calculate the context vector $c_i$ for the $i^{th}$ decoding step, take the weighted average of the encoder hidden states with the $\alpha_{ij}$ as weights: $c_i = \sum_{j=1}^{n} \alpha_{ij}h_j$.
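Putting these steps together, here is a minimal sketch of one attention step in plain Python. It assumes the additive score $a(s_{i-1}, h_j) = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$ from Bahdanau et al.; the names `score` and `attend` and all parameter values are hypothetical, and in a real model $W_a$, $U_a$ and $v_a$ are learned during training.

```python
import math

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def score(s_prev, h, W_a, U_a, v_a):
    """e_ij = v_a^T tanh(W_a s_{i-1} + U_a h_j) (additive attention score)."""
    z = [math.tanh(p + q) for p, q in zip(matvec(W_a, s_prev), matvec(U_a, h))]
    return sum(v * t for v, t in zip(v_a, z))

def softmax(xs):
    m = max(xs)                              # numerical-stability shift
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(s_prev, hs, W_a, U_a, v_a):
    """One attention step: return (alphas, c_i) given s_{i-1} and h_1..h_n."""
    e = [score(s_prev, h, W_a, U_a, v_a) for h in hs]   # 1. unnormalized scores
    alphas = softmax(e)                                 # 2. normalized weights
    dim = len(hs[0])
    c = [sum(a * h[k] for a, h in zip(alphas, hs)) for k in range(dim)]  # 3. context
    return alphas, c

# Hypothetical 2-d parameters and states (identity matrices for readability).
W_a = [[1.0, 0.0], [0.0, 1.0]]
U_a = [[1.0, 0.0], [0.0, 1.0]]
v_a = [1.0, 1.0]
s_prev = [0.5, -0.5]                                # s_{i-1}
hs = [[0.5, 0.5], [-0.5, 0.5], [-0.5, -0.5]]        # h_1..h_3

alphas, c_i = attend(s_prev, hs, W_a, U_a, v_a)
print([round(a, 3) for a in alphas], [round(x, 3) for x in c_i])
```

The encoder state that scores highest against $s_{i-1}$ (here $h_1$) receives the largest weight and therefore dominates $c_i$.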
Now, with the context vector $c_i$, the decoder hidden state $s_i$ and the previously predicted word $y_{i-1}$, the decoder predicts the next word $y_i$ as follows:
\begin{equation} y_i = g(y_{i-1}, s_i, c_i) \end{equation}
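The order of operations inside the decoder is easiest to see in a loop. The sketch below is a toy greedy decoder in plain Python. The dot-product score, the functions standing in for the state update and for $g$, and the three-token vocabulary are all hypothetical simplifications chosen only so that the loop runs; Bahdanau et al. use the additive score and a trained RNN decoder. What the sketch does reflect is the dependency order: attention uses $s_{i-1}$, the new state is computed from $s_{i-1}$, $y_{i-1}$ and $c_i$ (the paper writes this update as $s_i = f(s_{i-1}, y_{i-1}, c_i)$), and $y_i$ is predicted from $y_{i-1}$, $s_i$ and $c_i$.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_decode(hs, embed, s0, steps):
    """Toy greedy decoder showing the order of operations at each step i.

    Hypothetical stand-ins (not from the paper):
      - score a(s_{i-1}, h_j): plain dot product
      - state update: tanh of the element-wise mean of s_{i-1}, embed[y_{i-1}], c_i
      - g: token whose embedding has the largest dot product with
           s_i + c_i + embed[y_{i-1}]
    Token id 0 plays the role of a <start> symbol.
    """
    s, y, out = s0, 0, []
    dim = len(s0)
    for _ in range(steps):
        y_prev = y
        # 1. attention weights and context c_i, computed from s_{i-1}
        alphas = softmax([dot(s, h) for h in hs])
        c = [sum(a * h[k] for a, h in zip(alphas, hs)) for k in range(dim)]
        # 2. new decoder state from s_{i-1}, y_{i-1} and c_i
        s = [math.tanh((sp + ep + ck) / 3) for sp, ep, ck in zip(s, embed[y_prev], c)]
        # 3. next word y_i = g(y_{i-1}, s_i, c_i)
        query = [sk + ck + ek for sk, ck, ek in zip(s, c, embed[y_prev])]
        y = max(range(len(embed)), key=lambda t: dot(embed[t], query))
        out.append(y)
    return out

hs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]       # hypothetical encoder states
embed = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]    # hypothetical 3-token vocabulary
print(greedy_decode(hs, embed, s0=[0.1, -0.1], steps=4))
```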
This is the alignment/attention model as described in the paper.
This post was developed in collaboration with @nonlocal