Vaclav Kosar
Software, Machine Learning, & Business

Transformer's Self-Attention Mechanism Simplified

Quickly understand the successful architecture used in GPT, BERT, and other famous transformer models.

BERT full model diagram

The prototypical example of the Transformer architecture is the Bidirectional Encoder Representations from Transformers (BERT) model. The Transformer architecture itself was introduced in the Attention Is All You Need paper.

While self-attention is the central part of the Transformer architecture, it is not the whole picture. The Transformer architecture is a composite of the following parts:

Self-Attention in Transformer

self-attention calculation visualization
Output and input have the same sequence length and dimension. Weight each value by the similarity of the corresponding query and key. For each output sequence position, sum up the weighted values.

Transformer’s self-attention layer computes a differentiable key-value search and summation over the input sequence.

  • input \( X \in \mathbf{R}^{L \times d} \) is a sequence of length \( L \) of embeddings of dimension \( d \)
  • output \( Y \in \mathbf{R}^{L \times d} \) has the same shape as the input
  • project \( X \) with \( d \times d \) weight matrices into 3 matrices of the same shape as \( X \)
    • query \( X^Q := X W^Q \)
    • key \( X^K := X W^K \)
    • value \( X^V := X W^V \)
  • calculate a “soft sequence-wise nearest neighbor search”
    • “search” all \( L \times L \) combinations of sequence elements of \( X^K \) and \( X^Q \)
    • for each sequence position \( m \): output more of \( X^V_{o} \) the more similar \( X^K_o \) is to \( X^Q_{m} \)
    • this is done by weighting each value with a softmax of the query-key dot products and summing the weighted values
    • in pseudo-code: \( Y = \mathrm{matmul}_L(\mathrm{softmax}_L(\mathrm{matmul}_d(X^Q, (X^K)^\intercal)), X^V) \)
    • in equation form, with \( Q := X^Q \), \( K := X^K \), \( V := X^V \): \( Y = \mathbf{softmax}(QK^\intercal)V \)
  • More details are in the Attention Is All You Need paper, e.g. the dot product is “scaled” by \( 1/\sqrt{d} \), and there are residual connections and layer normalization
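A minimal pure-Python sketch of the single-head computation listed above, including the \( 1/\sqrt{d} \) scaling. It assumes the row-vector convention \( X^Q = X W^Q \) with \( d \times d \) weights, so that each of the \( L \) rows of \( X \) is one sequence element. The helper names (`matmul`, `softmax_rows`, `self_attention`, `random_matrix`) are illustrative, not from any library, and residual connections and layer normalization are left out.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: (n x k) times (k x m) gives (n x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    """Softmax over each row; subtracting the row maximum keeps exp() stable."""
    out = []
    for row in a:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def self_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Single-head scaled dot-product self-attention on X (L x d) with d x d projections."""
    d = len(x[0])
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # X^Q, X^K, X^V: each L x d
    scores = matmul(q, transpose(k))  # L x L: every query against every key
    scores = [[s / math.sqrt(d) for s in row] for row in scores]  # the "scaled" part
    weights = softmax_rows(scores)  # each row sums to 1
    return matmul(weights, v)  # L x d: per position, a weighted sum of values


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    seq_len, dim = 5, 4
    x = random_matrix(seq_len, dim, rng)
    w_q, w_k, w_v = (random_matrix(dim, dim, rng) for _ in range(3))
    y = self_attention(x, w_q, w_k, w_v)
    print(len(y), len(y[0]))  # 5 4: same shape as the input
```

Two quick sanity checks follow from the definition: if all rows of \( X \) are identical, every query matches every key equally and the output repeats that row; if \( W^Q \) is all zeros, the weights are uniform and every output row is the mean of the values.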

Word Embeddings

Input text is split into chunks of characters called tokens. Tokens are mostly words around 4 characters long with a prepended whitespace, but they can also represent special characters. Embedding layers map tokens to vectors, in other words to sequences of numbers. The input and output embedding layers share the same mapping.
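To make the embedding lookup and the shared input/output mapping concrete, here is a toy sketch. The vocabulary, the embedding width, and all function names are hypothetical, and the table is random rather than trained. The output layer scores each vocabulary entry by a dot product with the same table used for the input lookup, which is one common way to share the mapping.

```python
import random

# Hypothetical toy vocabulary; real models learn sub-word tokenizers with far more entries.
VOCAB = ["[MASK]", " the", " cat", " sat", " on", " mat", "."]
TOKEN_ID = {tok: i for i, tok in enumerate(VOCAB)}


def make_embedding_table(vocab_size: int, dim: int, seed: int = 0) -> list[list[float]]:
    """One dim-sized vector per token id (random here, learned during training in practice)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(vocab_size)]


def embed(tokens: list[str], table: list[list[float]]) -> list[list[float]]:
    """Input embedding layer: look up each token's row, giving an L x d matrix."""
    return [table[TOKEN_ID[tok]] for tok in tokens]


def output_logits(hidden: list[float], table: list[list[float]]) -> list[float]:
    """Tied output embedding layer: score every vocabulary entry by its dot product
    with the hidden vector, reusing the same table as the input embeddings."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in table]


if __name__ == "__main__":
    table = make_embedding_table(len(VOCAB), dim=8)
    x = embed([" the", " cat", " sat"], table)
    logits = output_logits(x[1], table)
    print(len(x), len(x[0]), len(logits))  # 3 8 7
```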

Multi-Head Attention

Instead of the basic self-attention above, BERT implements a more complicated layer that:

  1. multiplies each of the key, value, and query by an additional projection weight matrix,
  2. then splits each resulting embedding into \( h \) equal-sized vectors (\( h = 8 \) heads in the original Transformer, 12 in BERT-base),
  3. applies a separate \( d/h \)-dimensional self-attention mechanism to each of them,
  4. concatenates the results.
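The four steps can be sketched in pure Python as below, with the step numbers in the comments. This is an illustrative sketch under simplifying assumptions: the per-head projections are taken as column slices of one \( d \times d \) matrix each for query, key, and value, and the final output projection \( W^O \) that the original Transformer applies after concatenation is omitted, matching the four-step list. All function names are hypothetical.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention; the scale uses this head's own width."""
    d_head = len(q[0])
    scores = [[s / math.sqrt(d_head) for s in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)


def split_heads(m: Matrix, h: int) -> list[Matrix]:
    """Step 2: cut every row (width d) into h consecutive chunks of width d // h."""
    size = len(m[0]) // h
    return [[row[i * size:(i + 1) * size] for row in m] for i in range(h)]


def concat_heads(heads: list[Matrix]) -> Matrix:
    """Step 4: glue the per-head outputs back together, position by position."""
    return [[value for head in heads for value in head[t]] for t in range(len(heads[0]))]


def multi_head_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix, h: int) -> Matrix:
    if len(x[0]) % h != 0:
        raise ValueError("embedding width must be divisible by the number of heads")
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # step 1: projections
    heads = [
        attention(qi, ki, vi)  # step 3: independent attention per head
        for qi, ki, vi in zip(split_heads(q, h), split_heads(k, h), split_heads(v, h))
    ]
    return concat_heads(heads)


if __name__ == "__main__":
    rng = random.Random(0)
    seq_len, dim, heads = 4, 8, 2
    x = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]
    w_q, w_k, w_v = ([[rng.gauss(0, 1) for _ in range(dim)] for _ in range(dim)] for _ in range(3))
    y = multi_head_attention(x, w_q, w_k, w_v, h=heads)
    print(len(y), len(y[0]))  # 4 8
```

With `h=1` the layer reduces to ordinary single-head attention; with larger `h` each head attends using only its own \( d/h \) slice, which is what lets different heads specialize.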

Each of the separate self-attentions above is called a self-attention head, and the layer as a whole is called multi-head attention. Multi-head attention allows each head to focus on a different subspace with a different semantic or syntactic meaning. Splitting the vector representation into subspaces is related to disentangled representation training, where we train a model to give selected subspaces a specific meaning.

Most heads do not attend to the query's own sequence position, probably because the residual connection already adds the embedding at each position to that position's result. Some heads use special tokens to “attend” to nothing.

Adding multiple heads serves more as a computation parallelization trick than as a power expansion trick.

Scaled Dot-Product Attention (source).

Self-Attention Computational Complexity

  • complexity is quadratic in the sequence length: \( O(L^2) \)
  • because we need to calculate the \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \)
  • but context size is crucial for some tasks, e.g. character-level models
  • multiple speedup approaches already exist
  • for example, Performer, Expire-Span, and SRU++ are architectures that reduce the transformer's computational complexity
Attention Complexity (source).
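A back-of-the-envelope sketch of where the \( O(L^2) \) comes from. It only counts entries and multiply-adds of the two \( L \times L \) matrix products, ignoring the projections and the feed-forward layer. The width 768 and the 512-token length are example values, and the 4-characters-per-token ratio is the rough figure from the word embedding section.

```python
def attention_cost(seq_len: int, dim: int) -> dict[str, int]:
    """Rough size and work of one self-attention layer for sequence length L and width d."""
    return {
        # the L x L matrix softmax(QK^T / sqrt(d)) that must be materialized
        "score_entries": seq_len * seq_len,
        # QK^T: L * L dot products, each of length d
        "qk_multiply_adds": seq_len * seq_len * dim,
        # softmax weights (L x L) times V (L x d)
        "av_multiply_adds": seq_len * seq_len * dim,
    }


if __name__ == "__main__":
    dim = 768  # e.g. the BERT-base width; any value works for the ratios below
    tokens = 512
    characters = 4 * tokens  # ~4 characters per token, as in the word embedding section
    for length in (tokens, 2 * tokens, characters):
        print(length, attention_cost(length, dim))
    ratio = attention_cost(characters, dim)["score_entries"] / attention_cost(tokens, dim)["score_entries"]
    print(ratio)  # 16.0: a 4x longer character-level sequence needs a 16x larger attention matrix
```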

Positional Embeddings

positional embeddings in BERT architecture

In BERT, positional embeddings give the first few tens of dimensions of the token embeddings the meaning of relative positional closeness within the input sequence. In Perceiver IO, positional embeddings are instead concatenated to the input embedding sequence. In SRU++, the positional embeddings are a learned feature of the RNN.

Fourier Positional Encodings in the Original Transformer

  • Positional embeddings are added to the word embeddings once, before the first layer.
  • Each position \( t \) within the sequence gets a different embedding \( P_t \in \mathbf{R}^d \) with dimensions \( j \):
    • if \( j = 2i \) is even, then \( P_{t, j} := \sin (t / 10^{\frac{8i}{d}}) \)
    • if \( j = 2i + 1 \) is odd, then \( P_{t, j} := \cos (t / 10^{\frac{8i}{d}}) \)
    • note that \( 10^{\frac{8i}{d}} = 10000^{\frac{2i}{d}} \), the form used in the Attention Is All You Need paper
  • This is similar to a Fourier expansion of the Dirac delta
  • the dot product of two positional encodings decays fast once the positions are more than about 2 words apart
  • an average sentence has around 15 words, so only the first dimensions carry positional information
  • the remaining dimensions can thus function as word embeddings
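The encoding can be computed with a short Python sketch. It assumes the convention of the Attention Is All You Need paper: \( t \) is the position, \( j = 2i \) or \( j = 2i+1 \) is the dimension index, and the even/odd split is over \( j \). Because \( \sin a \sin b + \cos a \cos b = \cos(a - b) \), the dot product \( P_t^\intercal P_s \) depends only on the offset \( t - s \). The printed values let you check how quickly it falls off for a given \( d \) (here an arbitrary 64). Function names are illustrative.

```python
import math


def positional_encoding(num_positions: int, d: int) -> list[list[float]]:
    """Sinusoidal encodings P[t][j] for positions t and dimensions j:
    even j = 2i uses sin(t / 10**(8i/d)), odd j = 2i + 1 uses cos of the same angle."""
    table = []
    for t in range(num_positions):
        row = []
        for j in range(d):
            i = j // 2
            angle = t / 10 ** (8 * i / d)  # 10**(8i/d) == 10000**(2i/d)
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


if __name__ == "__main__":
    d = 64
    p = positional_encoding(50, d)
    # Similarity of position 10 to its neighbours at increasing distance.
    for offset in range(9):
        print(offset, round(dot(p[10], p[10 + offset]), 2))
```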

Training a Transformer

Transformers are usually pre-trained on large datasets with self-supervised tasks like masked language modelling or next-token prediction. Pre-trained models are often very general and are publicly distributed, e.g. on HuggingFace. Big transformer models are typically pre-trained on multiple GPUs. While there are various approaches to speed up the transformer itself, there are also ways to improve its training. For example, the ELECTRA training scheme speeds up training with a GAN-like setup that uses a loss over the entire sequence.
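The two pre-training objectives differ mainly in how training pairs are built from raw token sequences. This sketch shows only that data preparation, not the model or the loss. It uses BERT's 15% masking rate but skips BERT's additional rule of sometimes keeping or randomly replacing a selected token instead of masking it. The `[MASK]` string and the function names are illustrative assumptions.

```python
import random
from typing import Optional

MASK = "[MASK]"


def make_mlm_example(
    tokens: list[str], mask_prob: float = 0.15, rng: Optional[random.Random] = None
) -> tuple[list[str], list[Optional[str]]]:
    """Masked language modelling: hide a random subset of tokens behind [MASK].
    Labels hold the original token at masked positions and None elsewhere,
    so the loss is only computed where something was hidden."""
    rng = rng or random.Random()
    inputs, labels = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            inputs.append(MASK)
            labels.append(tok)
        else:
            inputs.append(tok)
            labels.append(None)
    return inputs, labels


def make_next_token_example(tokens: list[str]) -> tuple[list[str], list[str]]:
    """Next-token prediction: at each position the target is the following token."""
    return tokens[:-1], tokens[1:]


if __name__ == "__main__":
    sentence = [" the", " cat", " sat", " on", " the", " mat", "."]
    print(make_mlm_example(sentence, rng=random.Random(0)))
    print(make_next_token_example(sentence))
```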

Fine-tuning then specializes the model for a specific task using a small labelled dataset. A single GPU is often enough for fine-tuning. For example, models like BART are fine-tuned for summarization tasks. Sometimes we fine-tune twice, as the authors did with a BART model equipped with diminishing self-attention to increase summarization coverage (read my summary).

Beware of the possibility of double descent of the test error, contrary to the bias-variance trade-off hypothesis (read my summary).

Serving Transformer in Kubernetes Cluster in Cloud

While you can train and predict with small transformers on, for example, a Thinkpad P52 graphics card (see my review), to run bigger models or deploy your models to production you will need a bit of MLOps and DevOps, so read:

Example Transformer Models

Transformer vs Word2vec Continuous Bag-of-Words

Word2vec CBOW

Word2vec was used in many state-of-the-art models between 2013 and 2015. It was gradually replaced by more advanced variants like FastText and StarSpace (general-purpose embeddings), and by more sophisticated models like LSTMs and transformers. Word2vec Continuous Bag-of-Words (CBOW) predicts a word from its surrounding 10-word context by:

  1. summing the input embeddings corresponding to the context words,
  2. finding the output embedding with the maximum dot product with that sum.

Note that the average sentence length is about 15 words. Word2vec uses 2 sets of embeddings: input and output (context) embeddings. The Word2vec CBOW (w2v CBOW) model is similar to an extremely simplified single-layer transformer.
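The two CBOW steps can be sketched as a single prediction function. Training (softmax or negative sampling over the output embeddings) is left out, and the embeddings in the usage example are hand-picked, hypothetical numbers. Summing rather than averaging the context vectors does not change the predicted word, because the argmax is unaffected by positive scaling. The function name is illustrative.

```python
def cbow_predict(context_ids: list[int], input_emb: list[list[float]], output_emb: list[list[float]]) -> int:
    """Word2vec CBOW prediction step:
    1. sum the input embeddings of the context words,
    2. return the word id whose output embedding has the largest dot product with that sum."""
    hidden = [0.0] * len(input_emb[0])
    for idx in context_ids:
        hidden = [h + e for h, e in zip(hidden, input_emb[idx])]
    scores = [sum(h * o for h, o in zip(hidden, out)) for out in output_emb]
    return max(range(len(scores)), key=scores.__getitem__)


if __name__ == "__main__":
    # Hypothetical 3-word vocabulary with hand-picked 2-d embeddings.
    input_emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    output_emb = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
    # Context sum is [2, 1]; dot products with the output embeddings are 2, 1, -3.
    print(cbow_predict([0, 0, 1], input_emb, output_emb))  # 0
```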

If we use:

  • a single self-attention layer (removing the feed-forward layer and layer normalization)
  • a single attention head
  • Fourier positional encodings \( p_j \)
    • that behave as if concatenated: for all embeddings and positional encodings \( p_j^\intercal e_w \approx 0 \)
    • that decay fast after a relative distance of 4: \( p_j^\intercal p_i \approx \delta_{|i-j| \le 4} \)
  • identity key and query linear transformations \( W_K = W_Q = I \)
  • a masked word vector that has the same dot product with all embeddings: \( e_{mask}^\intercal e_w \approx C \)

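Then, for a word \( w \) at position \( i \) and the masked word at position \( j \), the cross terms vanish and we approximately get: \( (e_w + p_i)^\intercal (e_{mask} + p_j) \approx \) \( e_w^\intercal e_{mask} + p_i^\intercal p_j \approx \) \( C + \delta_{|i-j| \le 4} \)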

And if we define the output embeddings as the value projection matrix multiplied with the embeddings: \( W_V E \)

Then the Transformer output for a masked word is close to a sum of the surrounding word vectors, as in Word2vec CBOW. However, the positional embeddings probably do not behave as if concatenated, so the score above would also contain relative and absolute positional terms, which are not present in Word2vec. The transformer result would therefore still be more expressive.

If we additionally dropped the positional encodings and used a sliding context window matching Word2vec’s context size, the result would be even closer to Word2vec.
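A small numerical sketch of the attention weights in the approximation above, for a masked word at position 10 in a 20-token sequence. `approx_scores` uses the scores \( C + \delta_{|i-j| \le 4} \); `sliding_window_scores` drops the positional term and masks out everything outside the window. The constant \( C \) cancels in the softmax. The window size of 4, the value of \( C \), and all function names are assumptions for illustration.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]  # exp(-inf) == 0.0, so -inf scores get weight 0
    total = sum(exps)
    return [e / total for e in exps]


def approx_scores(mask_pos: int, seq_len: int, c: float = 0.0, window: int = 4) -> list[float]:
    """Scores C + delta_{|i - j| <= window} between the masked word at mask_pos and every position i."""
    return [c + (1.0 if abs(i - mask_pos) <= window else 0.0) for i in range(seq_len)]


def sliding_window_scores(mask_pos: int, seq_len: int, c: float = 0.0, window: int = 4) -> list[float]:
    """No positional encodings (every score is C) but attention limited to a window:
    positions outside the window are masked out with -inf."""
    return [c if abs(i - mask_pos) <= window else -math.inf for i in range(seq_len)]


def weighted_sum(weights: list[float], values: list[list[float]]) -> list[float]:
    """Attention output: sum_i w_i * v_i, with v_i standing in for W_V e_i."""
    return [sum(w * v[k] for w, v in zip(weights, values)) for k in range(len(values[0]))]


if __name__ == "__main__":
    seq_len, mask_pos = 20, 10
    full = softmax(approx_scores(mask_pos, seq_len, c=3.0))
    windowed = softmax(sliding_window_scores(mask_pos, seq_len, c=3.0))
    print(round(full[mask_pos + 1] / full[0], 4))  # 2.7183: inside vs outside the window is a factor of e
    print(round(windowed[mask_pos + 1], 4), windowed[0])  # 0.1111 0.0: uniform 1/9 inside, nothing outside
```

With positional encodings, positions inside the window get only \( e \approx 2.72 \) times more weight than those outside, so distant words still leak in. With the sliding window, the output is exactly the average of the 9 value vectors \( W_V e \) in the window (the masked position itself included), i.e. the CBOW sum scaled by 1/9.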

Transformer vs FastText

The Transformer architecture cannot really be compared well to FastText in anything other than performance. That is because, apart from whole words, FastText also trains on sub-words (character n-grams), while a Transformer trains only on the tokens produced by its tokenizer.
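For comparison, FastText's sub-word units can be sketched as follows, following the scheme described in the FastText paper: the word is wrapped in `<` and `>` boundary markers, all character n-grams with lengths 3 to 6 are extracted, and the whole word is kept as an extra unit. A word vector is then the sum of the vectors of these units. The function name is hypothetical, and FastText's hashing of n-grams into buckets is omitted.

```python
def char_ngrams(word: str, min_n: int = 3, max_n: int = 6) -> list[str]:
    """FastText-style subword units: wrap the word in '<' and '>' boundary markers,
    take every character n-gram with min_n <= n <= max_n, and also keep the
    whole wrapped word as its own unit."""
    wrapped = f"<{word}>"
    grams = [
        wrapped[start:start + n]
        for n in range(min_n, max_n + 1)
        for start in range(len(wrapped) - n + 1)
    ]
    return grams + [wrapped]


if __name__ == "__main__":
    print(char_ngrams("where", 3, 3))  # ['<wh', 'whe', 'her', 'ere', 're>', '<where>']
```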

Created on 05 Mar 2022. Updated on: 19 Apr 2022.
