Attention Networks with Keras

Note: A Jupyter notebook with example Python code can be found here: link

One of the most interesting advancements in natural language processing is the idea of attention networks. They've been used successfully in translation services, medical diagnosis, and other tasks.

Today we'll be learning what makes an attention network tick, why it's special, and the implementation details behind one.

Going into this tutorial, I'll assume some prior experience with neural networks.


A traditional recurrent neural network has some significant limitations. In an encoder-decoder layout, the encoder must compress the entire input into a single fixed-length vector, and it is hard to fit everything the decoder needs into that one vector. The longer the input is, the harder learning becomes. Previous studies have shown that performance drops sharply once inputs grow longer than roughly 30 words.

To combat this, we have the attention network. The inspiration comes from human translation. A translator doesn't read an entire text and then write the whole translation from memory. Instead, they read part of the text, write part of the translation, and repeat until the work is done. In other words, they pay attention to only part of the text at any given moment. That is the key idea behind attention networks.

In attention networks, each input step has an attention weight. The weight is high if the input step is relevant to the current output step and close to 0 otherwise. Because the weights are normalized to sum to 1, several relevant steps share the weight between them. These attention weights are recalculated for each output step, which allows the network's attention to shift over time.
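To make the weights concrete, here is a minimal plain-Python sketch (no Keras) of how raw relevance scores are typically turned into attention weights with a softmax. The scores are made-up numbers for illustration; in a real network they are computed by the network itself.

```python
import math

def softmax(scores):
    """Map raw relevance scores to weights that are positive and sum to 1."""
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up relevance scores for a 5-step input at a single output step:
# input steps 0 and 3 are relevant, the others are not.
scores = [4.0, -2.0, -2.0, 4.0, -2.0]
weights = softmax(scores)
print([round(w, 3) for w in weights])  # prints [0.498, 0.001, 0.001, 0.498, 0.001]
```

Because the weights must sum to 1, the two relevant steps share the weight (about 0.5 each) while the irrelevant ones get almost nothing. At the next output step the scores, and therefore the weights, are recomputed.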

For example, consider an attention network that translates English sentences into Spanish. Each output word can depend on several words in the input sentence because of conjugation, tense, punctuation, etc. At t=0, an attention network would assign high attention weights to "Have", "you", and "library?", because each of these influences the first output word "¿Has". It would assign low attention weights to all other words, effectively ignoring them.

A diagram of a sentence being translated.

For another example, consider an attention network that converts times written by humans into military time. The graph below shows the attention weights at each output time step. Notice what the network is focusing on.

Attention Weights per Timestep

The Inner Workings

Before continuing, let's list some notation:

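\(T_x\) - The number of input time steps
\(T_y\) - The number of output time steps
\(x_j\) - The input at time step j, for j = 1, ..., \(T_x\)
\(y_i\) - The output at time step i, for i = 1, ..., \(T_y\)
\(attention_i\) - The attention weights at output time step i, one weight per input step
\(c_i\) - The context at output time step i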

With this notation in hand, here is a diagram of an attention layer:

A diagram of an attention network.

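Let's walk through this part by part and understand the math involved. The first step is to calculate \(attention_i\). This can be done in a variety of ways, as long as \(attention_i\) has shape \((T_x)\), its entries are non-negative, and they sum to 1. One common approach applies a Dense layer to the input and the previous output \(y_{i-1}\), then normalizes the resulting scores with a softmax: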

$$ attention_i = softmax(Dense(x, y_{i-1})) $$

There is no previous output at the first step, so \(y_0\) is taken to be the zero vector \(\vec{0}\).
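Below is a minimal plain-Python sketch of this step. It reads \(Dense(x, y_{i-1})\) as a small Dense layer applied to each input step \(x_j\) concatenated with \(y_{i-1}\), producing one score per input step. The tanh hidden layer, the layer sizes, and all weight values are assumptions chosen for illustration, and the names `score` and `attention_weights` are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def score(features, W, v):
    """Dense layer with tanh, then a single linear unit -> one scalar score."""
    hidden = [math.tanh(sum(w * f for w, f in zip(row, features))) for row in W]
    return sum(v_h * h for v_h, h in zip(v, hidden))

def attention_weights(x, y_prev, W, v):
    """attention_i: softmax over j of score([x_j; y_{i-1}]), for j = 1..T_x."""
    # Concatenate each input step with the previous output, score every pair
    # with the same weights, then normalize across all T_x input steps.
    return softmax([score(x_j + y_prev, W, v) for x_j in x])

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # T_x = 3 input steps, 2 features each
W = [[1.0, -1.0, 1.0, 0.0],                # hypothetical weights: 2 hidden units,
     [-1.0, 1.0, 0.0, 1.0]]                # each seeing 2 input + 2 output features
v = [1.0, 0.5]                             # hypothetical output-unit weights
alpha_1 = attention_weights(x, [0.0, 0.0], W, v)   # y_0 is the zero vector
alpha_2 = attention_weights(x, [1.0, -1.0], W, v)  # a made-up later y_{i-1}
print([round(a, 2) for a in alpha_1], [round(a, 2) for a in alpha_2])
```

The hidden tanh layer is a deliberate choice. If \(Dense\) were a single linear unit, \(y_{i-1}\) would add the same amount to every input step's score, and since softmax is unaffected by adding a constant to all scores, the previous output could never shift the attention.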

The next step is to compute the weighted sum of the inputs, using the attention weights as the weights. The result is called the context, hence the shorthand \(c_i\). Writing \(attention_{i,j}\) for the weight that output step i places on input step j:

$$ c_i = \sum_{j=1}^{T_x} attention_{i,j} \, x_j $$
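The weighted sum takes only a line of code per feature. This plain-Python sketch assumes each \(x_j\) is a list of numbers and that the attention weights have already been computed; the values and the helper name `context_vector` are made up for illustration.

```python
def context_vector(weights, x):
    """c_i = sum over j of attention_{i,j} * x_j, computed feature by feature."""
    n_features = len(x[0])
    return [sum(a * x_j[k] for a, x_j in zip(weights, x)) for k in range(n_features)]

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # T_x = 3 input steps, 2 features each
weights = [0.5, 0.25, 0.25]                # made-up attention_i (sums to 1)
c = context_vector(weights, x)
print(c)  # prints [0.75, 0.5]
```

The context always has the size of a single input step, no matter how long the input is. If the weights were one-hot, for example [0, 1, 0], the context would simply be that input step, here [0.0, 1.0].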

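The final step is to feed the context into an RNN layer, which produces the output for this time step. The type of recurrent layer is flexible (for example, a simple RNN, GRU, or LSTM).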

$$ y_i = RNN(c_i) $$
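Putting the three equations together gives the loop below, a minimal plain-Python sketch of one decoding pass. It assumes a simple tanh RNN cell whose hidden state is its previous output, so \(y_i = \tanh(W_c c_i + U y_{i-1})\). The cell type, the tanh scoring layer, the matrices `W`, `v`, `W_c`, `U`, and the helper names are all hypothetical choices for illustration, not code taken from the linked notebook.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(M, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * e for m, e in zip(row, vec)) for row in M]

def score(features, W, v):
    """Dense layer with tanh, then one linear unit -> a scalar score."""
    return sum(v_h * math.tanh(z) for v_h, z in zip(v, matvec(W, features)))

def decode(x, T_y, W, v, W_c, U):
    """Run T_y steps of: attention -> context -> simple tanh RNN cell."""
    y_prev = [0.0] * len(U)                # y_0 is the zero vector
    outputs, attentions = [], []
    for _ in range(T_y):
        # 1. attention_i = softmax(Dense([x_j; y_{i-1}])) over the T_x input steps
        alpha = softmax([score(x_j + y_prev, W, v) for x_j in x])
        # 2. c_i = sum_j attention_{i,j} * x_j
        c = [sum(a * x_j[k] for a, x_j in zip(alpha, x)) for k in range(len(x[0]))]
        # 3. y_i = RNN(c_i): the cell's state is its previous output y_{i-1}
        y = [math.tanh(p + q) for p, q in zip(matvec(W_c, c), matvec(U, y_prev))]
        outputs.append(y)
        attentions.append(alpha)
        y_prev = y                         # the new output steers the next attention
    return outputs, attentions

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # T_x = 3 input steps, 2 features each
W = [[1.0, -1.0, 1.0, 0.0], [-1.0, 1.0, 0.0, 1.0]]  # hypothetical attention weights
v = [1.0, 0.5]
W_c = [[1.0, 0.0], [0.0, 1.0]]             # hypothetical RNN input weights
U = [[0.0, 1.0], [-1.0, 0.0]]              # hypothetical RNN recurrent weights
ys, alphas = decode(x, T_y=4, W=W, v=v, W_c=W_c, U=U)
for i, alpha in enumerate(alphas, start=1):
    print(i, [round(a, 2) for a in alpha])
```

With these particular weights, the first output step attends mostly to the first input step; once \(y_1\) is fed back in, the second step's attention moves to the third input step. In a Keras model, the same pieces would be built from layers (for example Dense for the scores and an LSTM or GRU for the recurrent part) rather than written by hand.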

Take a moment to see how the attention weights and context tie back to the examples above.

Let's address the elephant in the room. Why is the RNN layer necessary?

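Without the RNN layer, nothing would change from one output time step to the next: the attention weights, and therefore the context, would be the same at every step. That's like handing a translator a text with 90% of the words blacked out, where the same words stay hidden for the whole job. The RNN supplies a time-varying input, its previous output \(y_{i-1}\), which lets the attention shift over time.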

One other important point: attention takes quadratic time. Every output step scores every input step, so a single sequence needs \(T_x \times T_y\) attention computations, and doubling the length of both input and output roughly quadruples the work. This makes attention networks hard to train on long inputs. Hence they are usually used in a sliding-window fashion, reading words [0, 20], then [1, 21], [2, 22], and so on. There is ongoing research aimed at reducing this training cost.
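A quick plain-Python sketch of both points: the number of attention scores grows as \(T_x \times T_y\), and a sliding window keeps each chunk short. The helper names, the window size of 21 words (reading the [0, 20] example as inclusive), and the `stride` parameter are assumptions for illustration.

```python
def attention_scores_needed(T_x, T_y):
    """One score per (output step, input step) pair."""
    return T_x * T_y

def sliding_windows(tokens, size, stride=1):
    """Yield overlapping windows of `size` tokens, advancing by `stride`."""
    for start in range(0, max(len(tokens) - size, 0) + 1, stride):
        yield tokens[start:start + size]

print(attention_scores_needed(30, 30), attention_scores_needed(60, 60))  # prints 900 3600

words = [f"w{k}" for k in range(25)]
windows = list(sliding_windows(words, size=21))
print(len(windows), windows[0][0], windows[0][-1], windows[1][0])  # prints 5 w0 w20 w1
```

Because each window holds at most 21 words, the cost of attending over one window stays bounded no matter how long the full text is.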


Now that you know the inner workings, it's time to experiment with some code. A demo network and dataset are provided here: link

Have any questions or thoughts? Feel free to comment below.