Transformer's Self-Attention Mechanism Simplified

How do transformer models like BERT and GPT work?

Transformer Architecture Explained

While the self-attention layer is the central mechanism of the Transformer architecture, it is not the whole picture. The Transformer architecture is a composite of the following parts:

Transformer full model diagram

Self-Attention in Transformer Visualized

Self-attention compares all input sequence members with each other and modifies the corresponding output sequence positions. In other words, the self-attention layer performs a differentiable key-value search over the input sequence for each input and adds the results to the output sequence.

transformer self-attention calculation visualization
Output and input have the same sequence length and dimension. Each value is weighted by the similarity of the corresponding query and key. For each output sequence position, the weighted values are summed.
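To make the "differentiable key-value search" concrete, here is a minimal sketch in plain Python (no libraries assumed) that contrasts a hard nearest-neighbour lookup with the soft, attention-style lookup for a single query. The helper names `hard_lookup` and `soft_lookup` and the toy vectors are illustrative, not from any library.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    # subtracting the max keeps exp() from overflowing and does not change the result
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def hard_lookup(query, keys, values):
    # classic search: return the value of the single most similar key (not differentiable)
    scores = [dot(query, k) for k in keys]
    return values[scores.index(max(scores))]


def soft_lookup(query, keys, values):
    # attention-style search: weight every value by softmax(query . key) and sum them
    weights = softmax([dot(query, k) for k in keys])
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
query = [2.0, 0.0]
print(hard_lookup(query, keys, values))  # [10.0, 0.0]
print(soft_lookup(query, keys, values))  # mostly the first value, a little of the second
```

The soft version returns roughly 88% of the first value and 12% of the second, and because it is built only from smooth operations, gradients can flow back into the query, keys, and values.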

Self-Attention Explained in Detail

  • input \( X \in \mathbf{R}^{L \times d} \) is a sequence of embeddings of dimension \( d \) of length \( L \)
  • output \( Y \in \mathbf{R}^{L \times d} \) has the same shape as input
  • project \( X \) into 3 matrices of the same shape
    • query \( X^Q := X W^Q \),
    • key \( X^K := X W^K \),
    • value \( X^V := X W^V \), where \( W^Q, W^K, W^V \in \mathbf{R}^{d \times d} \) are learned weight matrices
  • calculate “soft sequence-wise nearest neighbor search”
    • “search” all \( L \times L \) combinations of sequence elements of \( X^K \) and \( X^Q \)
    • for each sequence position \( m \): output more of \( X^V_{o} \) the more similar \( X^K_o \) is to \( X^Q_{m} \)
    • this is done by weighting each value with a softmax over the query-key dot-products and summing the weighted values
    • in pseudo-code: \( Y = \mathrm{matmul}_L(\mathrm{softmax}_L(\mathrm{matmul}_d(X^Q, {X^K}^\intercal)), X^V) \)
    • in equation: \( Y = \mathbf{softmax}(QK^\intercal)V \) with \( Q = X^Q \), \( K = X^K \), \( V = X^V \)
  • results are added to the residual connection and normalized
  • more details are in the Attention Is All You Need paper, e.g. the dot-product is “scaled” by \( 1/\sqrt{d} \), the residual connection, and layer normalization
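The steps above can be sketched in plain Python as a single-head self-attention in which the rows of `x` are sequence positions, so the projections are \( X W^Q \), \( X W^K \), \( X W^V \). This is a minimal illustration under those assumptions, including the \( 1/\sqrt{d} \) scaling from the paper; the residual connection and layer normalization are left out, and the function names are hypothetical.

```python
import math
import random


def matmul(a, b):
    # (n x k) @ (k x m) -> (n x m); matrices are lists of rows
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def softmax_rows(a):
    out = []
    for row in a:
        m = max(row)  # for numerical stability
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention; x is L x d, output is L x d."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # X^Q, X^K, X^V
    d = len(q[0])
    scores = matmul(q, transpose(k))  # L x L: every query against every key
    scores = [[s / math.sqrt(d) for s in row] for row in scores]  # the "scaled" part
    weights = softmax_rows(scores)  # each row sums to 1
    return matmul(weights, v)  # per position: weighted sum of the values


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
L, d = 4, 3
x = rand_matrix(L, d)
y = self_attention(x, rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d))
print(len(y), len(y[0]))  # 4 3, same shape as the input
```

Building the full \( L \times L \) `scores` matrix is the step that makes the cost quadratic in sequence length.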

Self-Attention vs Recurrent Layer

  • attention vs recurrence = graph vs sequence = Transformer vs LSTM
  • attention connects across entire sequence as fully connected graph
  • example graph task: dependency parse is a syntactic graph over the word sequence
  • RNNs keep information from previous states in a state vector as a memory
  • RNNs are not parallelizable in the time dimension, as future steps depend on the past
  • RNNs have difficulty accessing information from long ago
  • RNNs handle repetition better and can use CTC loss, e.g. for OCR
  • SRU++ uses both attention and recurrence
Dependency parse tree example from Spacy
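The parallelization point can be illustrated with a toy scalar example in plain Python (an assumption-laden sketch, not real LSTM or Transformer code): the recurrent loop needs the previous state at every step, while each attention output depends only on the inputs, so the outputs could be computed in any order or all at once. The weights `w_h`, `w_x` and the identity projections are hypothetical choices for illustration.

```python
import math


def rnn_states(xs, w_h=0.5, w_x=1.0):
    # toy scalar RNN: h_t = tanh(w_h * h_{t-1} + w_x * x_t)
    # step t must wait for h_{t-1}, so the loop is sequential in time
    h, states = 0.0, []
    for x in xs:
        h = math.tanh(w_h * h + w_x * x)
        states.append(h)
    return states


def attention_output(xs, m):
    # toy scalar self-attention at position m with identity projections:
    # it reads all inputs directly and needs no other outputs
    scores = [xs[m] * xo for xo in xs]
    mx = max(scores)
    ws = [math.exp(s - mx) for s in scores]
    z = sum(ws)
    return sum(w / z * xo for w, xo in zip(ws, xs))


xs = [1.0, -1.0, 0.5, 2.0]
print(rnn_states(xs))
# independent calls: these could run in parallel or, as here, in reverse order
print([attention_output(xs, m) for m in reversed(range(len(xs)))][::-1])
```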

Multi-Head Attention

Instead of the basic self-attention above, the Transformer implements a special, more complicated layer that:

  1. multiplies each key, value, and query by an additional projection weight matrix,
  2. then splits each resulting embedding into 8 equal-sized vectors,
  3. applies a separate self-attention mechanism of 1/8th the dimension to each of them,
  4. concatenates the results and multiplies them by an output projection matrix.

Each separate self-attention above is called a self-attention head. As a whole, this layer is called multi-head attention. Multi-head attention allows each head to focus on a different subspace, with a different semantic or syntactic meaning. Splitting the vector representation into subspaces is related to disentangled representation training, where we train the model to give selected subspaces a specific meaning.

Most heads don’t attend to the identical sequence position, probably because the residual connection always adds each position's embedding to that position's result. Some heads use special tokens to “attend” to nothing.

Adding multiple heads serves more as a computation parallelization trick than as a power expansion trick.
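The four steps can be sketched in plain Python as follows. This is a minimal illustration assuming rows of `x` are positions and `d_model` is divisible by the number of heads; the output projection `w_o` corresponds to the \( W^O \) matrix of the original paper, and all function names are hypothetical.

```python
import math
import random


def matmul(a, b):
    # (n x k) @ (k x m) -> (n x m); matrices are lists of rows
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def attention(q, k, v):
    # scaled dot-product attention for one head; q, k, v are L x d_head
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        ws = [math.exp(s - m) for s in scores]
        z = sum(ws)
        out.append([sum(w / z * vj[c] for w, vj in zip(ws, v)) for c in range(len(v[0]))])
    return out


def multi_head_attention(x, w_q, w_k, w_v, w_o, heads=8):
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # 1. project
    d_model = len(q[0])
    assert d_model % heads == 0, "d_model must split into equal-sized heads"
    d_head = d_model // heads
    head_outputs = []
    for h in range(heads):
        cols = slice(h * d_head, (h + 1) * d_head)  # 2. this head's subspace
        head_outputs.append(attention([r[cols] for r in q],  # 3. per-head attention
                                      [r[cols] for r in k],
                                      [r[cols] for r in v]))
    # 4. concatenate the heads back to d_model per position, then project with W^O
    concat = [sum((ho[i] for ho in head_outputs), []) for i in range(len(x))]
    return matmul(concat, w_o)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
L, d_model = 5, 16
x = rand_matrix(L, d_model)
y = multi_head_attention(x, *(rand_matrix(d_model, d_model) for _ in range(4)))
print(len(y), len(y[0]))  # 5 16
```

With 8 heads and `d_model = 16`, each head works in a 2-dimensional subspace, so the heads together compute roughly as many query-key products as one full-width attention.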

Scaled Dot-Product Attention (source).

Self-Attention Computational Complexity

  • complexity is quadratic in sequence length \( O(L^2) \)
  • because we need to calculate the \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \)
  • but a large context size is crucial for some tasks, e.g. character-level models
  • multiple speedup approaches already exist
  • for example, Performer, Expire-Span, and SRU++ are architectures that reduce the transformer's computational complexity.
Attention Complexity (source).
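A quick back-of-the-envelope calculation shows why the quadratic term matters. The sketch below assumes one attention matrix per head stored in float32 (4 bytes per value); the function name is hypothetical, and real implementations may avoid materializing the full matrix.

```python
def attention_matrix_bytes(seq_len, heads=1, bytes_per_value=4):
    # one L x L score matrix per head; 4 bytes per value assumes float32
    return seq_len * seq_len * heads * bytes_per_value


for L in (512, 2048, 8192):
    print(L, L * L, attention_matrix_bytes(L) / 2**20, "MiB")
# 512 262144 1.0 MiB
# 2048 4194304 16.0 MiB
# 8192 67108864 256.0 MiB
```

Each 4× increase in sequence length multiplies the memory by 16, for a single head in a single layer, which is why long-context tasks such as character-level modelling are hit hardest.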

Training a Transformer

Transformers are usually pre-trained on large datasets with self-supervised tasks like masked language modelling or next-token prediction. Pre-trained models are often very general and publicly distributed, e.g. on HuggingFace. Big transformer models are typically pre-trained on multiple GPUs. While there are various approaches to speed up the transformer itself, there are also ways to improve its training. For example, the ELECTRA training scheme speeds up training with a GAN-like setting that uses a loss over the entire sequence.
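The two self-supervised objectives can be sketched as ways of turning unlabelled text into training examples. This plain-Python sketch masks fixed positions for readability, whereas BERT, for instance, picks about 15% of tokens at random; the `[MASK]` string and the function names are illustrative assumptions.

```python
MASK = "[MASK]"  # illustrative mask token


def next_token_examples(tokens):
    # next-token prediction (GPT-style): each prefix predicts the token that follows it
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def masked_lm_example(tokens, positions):
    # masked language modelling (BERT-style): hide some tokens, the originals become labels
    inputs = [MASK if i in positions else t for i, t in enumerate(tokens)]
    labels = {i: tokens[i] for i in sorted(positions)}
    return inputs, labels


tokens = ["the", "cat", "sat", "down"]
print(next_token_examples(tokens))
# [(['the'], 'cat'), (['the', 'cat'], 'sat'), (['the', 'cat', 'sat'], 'down')]
print(masked_lm_example(tokens, {1}))
# (['the', '[MASK]', 'sat', 'down'], {1: 'cat'})
```

In both cases the targets come from the text itself, so no human labelling is needed, which is what makes pre-training on large datasets practical.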

Fine-tuning is then used to specialize the model for a specific task using a small labelled dataset. A single GPU is often enough for fine-tuning. For example, models like BART are fine-tuned for summarization tasks. Sometimes we fine-tune twice, as the authors did with a BART model equipped with diminishing self-attention to increase summarization coverage (read my summary).

Beware of the possible double descent of test error, contrary to the bias-variance trade-off hypothesis (read my summary).

Serving Transformer in Kubernetes Cluster in Cloud

While you can train and predict with small transformers on, for example, the ThinkPad P52 graphics card Quadro P1000 with 4GB VRAM (see my review), to run bigger models or deploy your models to production you will need a bit of MLOps and DevOps, so read:

Example Transformer Models

Transformer vs Word2vec Continuous Bag-of-Words

Word2vec CBOW

Word2vec was used in many state-of-the-art models between 2013 and 2015. It was gradually replaced by more advanced variants like FastText and StarSpace (general-purpose embeddings), and by more sophisticated models like LSTMs and transformers. Word2vec Continuous Bag-of-Words predicts a word from its surrounding 10-word context by:

  1. summing the input embeddings corresponding to the input context words
  2. finding the output embedding with the maximum dot-product with that sum
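The two steps can be sketched in plain Python with tiny hand-made embeddings. Training is omitted, and the vectors, vocabulary, and 3-word context are made up for illustration (the real model learns the embeddings and uses a wider context). Taking the maximum dot-product picks the same word as taking the maximum of a softmax over those dot-products.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cbow_predict(context, input_emb, output_emb):
    # 1. sum the input embeddings of the context words
    dim = len(next(iter(input_emb.values())))
    h = [0.0] * dim
    for word in context:
        h = [a + b for a, b in zip(h, input_emb[word])]
    # 2. predict the word whose output embedding has the largest dot-product with the sum
    return max(output_emb, key=lambda w: dot(h, output_emb[w]))


# hypothetical toy embeddings (2-dimensional)
input_emb = {"the": [1.0, 0.0], "sat": [0.0, 1.0], "on": [0.5, 0.5],
             "cat": [0.2, 0.1], "mat": [0.1, 0.2]}
output_emb = {"the": [0.1, 0.0], "sat": [0.0, 0.1], "on": [0.1, 0.1],
              "cat": [1.0, 1.0], "mat": [0.9, 0.2]}
print(cbow_predict(["the", "sat", "on"], input_emb, output_emb))  # cat
```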

Note that the average sentence length is about 15 words. Word2vec uses 2 sets of embeddings: input and output (context) embeddings. The Word2vec CBOW (w2v CBOW) model is similar to an extremely simplified single-layer transformer.

If we use:

  • a single self-attention (remove the feed forward layer and layer normalization)
  • single attention head
  • Fourier positional encodings \( p_j \)
    • that behave as if concatenated: \( p_j^\intercal e_w \approx 0 \) for all embeddings and positional encodings
    • that decay fast after a relative distance of 4: \( p_j^\intercal p_i \approx \delta_{|i-j| \le 4} \)
  • identity key and query linear transformations \( W^K = W^Q = I \)
  • a masked word vector that has the same dot-product with all embeddings: \( e_{mask}^\intercal e_w \approx C \)

Then, for a word \( w \) at position \( i \) and the masked word at position \( j \), we approximately get: \( (e_w + p_i)^\intercal (e_{mask} + p_j) \approx e_w^\intercal e_{mask} + p_i^\intercal p_j \approx C + \delta_{|i-j| \le 4} \)
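One step worth making explicit: softmax is unchanged when the same constant is added to every score, so under the assumptions above the constant \( C \) drops out of the attention weights. Writing \( a_{ij} \) for the weight that the masked position \( j \) gives to position \( i \) (notation introduced here, treating the approximations as exact):

```latex
\[
a_{ij}
= \frac{\exp\!\left(C + \delta_{|i-j| \le 4}\right)}{\sum_{k} \exp\!\left(C + \delta_{|k-j| \le 4}\right)}
= \frac{\exp\!\left(\delta_{|i-j| \le 4}\right)}{\sum_{k} \exp\!\left(\delta_{|k-j| \le 4}\right)},
\qquad
\frac{a_{ij}}{a_{i'j}} = e \approx 2.72 \ \text{ for } |i-j| \le 4 < |i'-j|
\]
```

So each word inside the window gets about 2.7 times the weight of a word outside it; the weighting would approach a hard window only if the positional dot-product inside the window were much larger than 1.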

And if we define the output embeddings as the value projections \( W^V \) of the embeddings \( E \),

Then the Transformer output for a masked word is close to the summation of the surrounding word vectors, as in Word2vec CBOW. In practice, the positional embeddings probably do not behave as if concatenated, so the term above would also contain relative and absolute positional terms, which are not present in Word2vec. So the transformer result would still be more expressive.

But if we additionally dropped the positional encodings and used a sliding context window matching Word2vec’s context size, the results would be even closer to Word2vec.
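That limiting case can be checked with a small plain-Python sketch: when the masked query has the same dot-product with every key (no positional encodings) and attention is restricted to a sliding window, the softmax weights are uniform, so the output is the mean of the window's value vectors, i.e. the CBOW sum divided by the window size. The toy vectors are made up, and excluding the masked position itself (as CBOW excludes the centre word) is an assumption of this sketch.

```python
import math


def windowed_attention(query, keys, values, center, window):
    # attend only to positions within `window` of `center`, excluding the center itself
    idx = [i for i in range(len(keys)) if i != center and abs(i - center) <= window]
    scores = [sum(q * k for q, k in zip(query, keys[i])) for i in idx]
    m = max(scores)
    ws = [math.exp(s - m) for s in scores]
    z = sum(ws)
    dim = len(values[0])
    return [sum(w / z * values[i][c] for w, i in zip(ws, idx)) for c in range(dim)]


query = [1.0, 0.0]                            # stands in for e_mask
keys = [[1.0, float(i)] for i in range(5)]    # every key has dot-product 1.0 (= C) with the query
values = [[float(i), 1.0] for i in range(5)]  # stand-ins for the value-projected embeddings
out = windowed_attention(query, keys, values, center=2, window=2)
cbow_sum = [sum(values[i][c] for i in (0, 1, 3, 4)) for c in range(2)]
print(out, [s / 4 for s in cbow_sum])  # [2.0, 1.0] [2.0, 1.0]
```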

Transformer vs FastText

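The Transformer architecture cannot really be compared well to FastText in anything other than performance. That is because, apart from whole words, FastText also trains on sub-words (character n-grams), while a Transformer learns a single embedding per token of its vocabulary (for BERT and GPT these tokens are WordPiece or BPE subwords rather than always whole words).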
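For contrast, FastText represents a word through its character n-grams. This sketch follows the scheme described in the FastText paper (boundary symbols `<` and `>`, n-grams of length 3 to 6, plus the whole word as its own feature); the function name is hypothetical, and FastText's hashing of n-grams into buckets is omitted.

```python
def char_ngrams(word, n_min=3, n_max=6):
    # wrap the word in boundary symbols so prefixes and suffixes are distinguishable
    w = f"<{word}>"
    grams = [w[i:i + n] for n in range(n_min, n_max + 1) for i in range(len(w) - n + 1)]
    return grams + [w]  # the whole (bracketed) word is also kept as a feature


print(char_ngrams("where", 3, 3))
# ['<wh', 'whe', 'her', 'ere', 're>', '<where>']
```

A word vector is then the sum of the vectors of these n-grams, so FastText can build vectors for unseen words, whereas a Transformer looks up one embedding per token.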

Created on 05 Mar 2022. Updated on: 24 Feb 2023.

Copyright © Vaclav Kosar. All rights reserved. Not investment, financial, medical, or any other advice. No guarantee of information accuracy.