- the self-attention (sometimes called KQV-attention) layer is the central mechanism of the Transformer architecture introduced in the Attention Is All You Need paper
- an example of an architecture based on the Transformer is BERT, which contains only the Transformer’s encoder (it is encoder-only).
- a decoder-only example is GPT-2.
- as of 2022, Transformer-based models are often state-of-the-art in various domains (vision, speech, multi-modal classification …)
- compared to LSTMs (RNNs), the Transformer is more parallelizable and thus faster to train
Transformer Architecture Explained
While the self-attention layer is the central mechanism of the Transformer architecture, it is not the whole picture. The Transformer architecture is a composite of the following parts:
- Tokenizers convert text into tokens, which are mapped to embeddings
- Positional encodings inject word-position information into the input
- The self-attention layer contextually encodes the input sequence information
- The feed-forward layer operates a bit like a static key-value memory. The FF layer is similar to self-attention, except that it does not use softmax and one of its input sequences is a constant (the learned weights).
- Cross-attention decodes an output sequence while attending to a different input sequence, possibly of another modality.
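The key-value-memory reading of the feed-forward layer can be checked numerically. The following is a minimal NumPy sketch, assuming the ReLU feed-forward form \( \mathrm{ReLU}(x W_1 + b_1) W_2 + b_2 \) from the original paper; the shapes and function names are illustrative assumptions. Each hidden unit acts as one memory slot: a column of `W1` is a constant key, the matching row of `W2` is its value, and values are weighted by unnormalized ReLU scores instead of a softmax.

```python
import numpy as np

def feed_forward(x, W1, b1, W2, b2):
    """Position-wise feed-forward layer: ReLU(x W1 + b1) W2 + b2, x has shape (L, d)."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def feed_forward_as_memory(x, W1, b1, W2, b2):
    """The same computation written as a lookup in a constant key-value memory."""
    out = np.zeros((x.shape[0], W2.shape[1]))
    for i in range(W1.shape[1]):                         # one memory slot per hidden unit
        key, value = W1[:, i], W2[i]                     # constant (learned) key and value
        score = np.maximum(0.0, x @ key + b1[i])         # (L,) match of every position with the key
        out += np.outer(score, value)                    # add the value, weighted by the score
    return out + b2

rng = np.random.default_rng(0)
L, d, h = 5, 8, 32                                       # sequence length, model width, hidden width
x = rng.normal(size=(L, d))
W1, b1 = rng.normal(size=(d, h)), rng.normal(size=h)
W2, b2 = rng.normal(size=(h, d)), rng.normal(size=d)
print(np.allclose(feed_forward(x, W1, b1, W2, b2), feed_forward_as_memory(x, W1, b1, W2, b2)))  # True
```

Unlike in self-attention, the keys here do not depend on the input sequence, which is why the layer behaves like a static memory.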
Self-Attention in Transformer Visualized
Self-attention compares all members of the input sequence with each other and modifies the corresponding output sequence positions. In other words, the self-attention layer performs a differentiable key-value search over the input sequence for each input and adds the results to the output sequence.
Self-Attention Explained in Detail
- input \( X \in \mathbf{R}^{L \times d} \) is a sequence of embeddings of dimension \( d \) of length \( L \)
- output \( Y \in \mathbf{R}^{L \times d} \) has the same shape as input
- project \( X \) into 3 matrices of the same shape
- query \( X^Q := X W^Q \)
- key \( X^K := X W^K \)
- value \( X^V := X W^V \), where \( W^Q, W^K, W^V \in \mathbf{R}^{d \times d} \) are learned projection matrices
- calculate “soft sequence-wise nearest neighbor search”
- “search” all \( L \times L \) combinations of sequence elements of \( X^K \) and \( X^Q \)
- for each sequence position \( m \): output more of \( X^V_{o} \) the more similar \( X^K_o \) is to \( X^Q_{m} \)
- this is done by weighting each value with a softmax of dot-products and summing the weighted values
- in pseudo-code: \( Y = \mathrm{matmul}_L(\mathrm{softmax}_L(\mathrm{matmul}_d(X^Q, {X^K}^\intercal)), X^V) \)
- in equation form: \( Y = \mathbf{softmax}(QK^\intercal)V \), with \( Q = X^Q \), \( K = X^K \), \( V = X^V \)
- the results are added to the residual connection and normalized
- more details are in the Attention Is All You Need paper, e.g. the dot-product is “scaled” by \( 1/\sqrt{d} \), the residual connection, and layer normalization
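The steps above can be sketched in NumPy as a single-head layer followed by the residual connection and normalization. This is a minimal illustration, not the paper's reference implementation: it assumes the row-vector convention (\( X W^Q \)), scales the dot-products by \( \sqrt{d} \), and uses a layer norm without learned parameters. The `causal` flag is an illustrative addition showing how decoder-only models such as GPT-2 prevent positions from attending to later ones.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v, causal=False):
    """Single-head scaled dot-product self-attention.

    X: (L, d) input embeddings; W_q, W_k, W_v: (d, d) projections. Returns (L, d).
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v              # query, key, value projections
    scores = Q @ K.T / np.sqrt(X.shape[1])           # (L, L): all query-key pairs
    if causal:                                       # decoder-only models: no attending to later positions
        L = X.shape[0]
        scores = np.where(np.tril(np.ones((L, L), dtype=bool)), scores, -np.inf)
    return softmax(scores, axis=-1) @ V              # weighted sum of values per position

def layer_norm(Y, eps=1e-5):
    """Layer normalization without the learned gain and bias (a simplification)."""
    mu = Y.mean(axis=-1, keepdims=True)
    var = Y.var(axis=-1, keepdims=True)
    return (Y - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
L, d = 4, 8
X = rng.normal(size=(L, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
out = layer_norm(X + self_attention(X, W_q, W_k, W_v))  # residual connection, then normalization
print(out.shape)  # (4, 8)
```

With `causal=True`, the first position can only attend to itself, so its output is exactly its own value vector.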
Self-Attention vs Recurrent Layer
- attention vs recurrence = graph vs sequence = Transformer vs LSTM
- attention connects across the entire sequence as a fully connected graph
- example graph task: a dependency parse is a syntactic graph over the word sequence
- RNNs keep information from previous states in a state vector that acts as memory
- RNNs are not parallelizable in the time dimension, as future steps depend on past ones
- RNNs have difficulty accessing information from long ago
- RNNs handle repetition better and can use CTC loss, e.g. for OCR
- SRU++ uses both attention and recurrence
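The parallelism difference can be illustrated with a toy comparison. This sketch assumes a plain Elman-style RNN standing in for an LSTM (the gating is omitted) and uses the raw embeddings directly as queries and keys; the function names are hypothetical. The RNN needs a loop in which each step waits for the previous one, while the attention graph comes out of a single matrix product.

```python
import numpy as np

def rnn_states(X, W_x, W_h):
    """Simple Elman-style RNN: state t needs state t-1, so the time loop is sequential."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in X:                                    # step t cannot start before step t-1 finishes
        h = np.tanh(x_t @ W_x + h @ W_h)
        states.append(h)
    return np.stack(states)

def attention_graph(X):
    """Dense (L, L) attention weights: an edge between every pair of positions.

    All edges come from one matrix product, with no dependency between time steps.
    """
    scores = X @ X.T / np.sqrt(X.shape[1])
    scores -= scores.max(axis=1, keepdims=True)
    A = np.exp(scores)
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 4))
W_x, W_h = rng.normal(size=(4, 3)), rng.normal(size=(3, 3))
X_changed = X.copy()
X_changed[-1] += X[0]                                # modify only the last input

# The RNN's earlier states do not see the change; they are computed strictly left to right.
print(np.allclose(rnn_states(X, W_x, W_h)[:-1], rnn_states(X_changed, W_x, W_h)[:-1]))  # True
# The first position's attention row changes, because it is directly connected to the last one.
print(np.allclose(attention_graph(X)[0], attention_graph(X_changed)[0]))  # False
```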
Multi-Head Attention
Instead of the basic self-attention above, the Transformer implements a special, more complicated layer:
- it multiplies each key, value, and query by an additional projection weight matrix,
- then splits each resulting embedding into 8 equal-sized vectors,
- applies a separate self-attention mechanism of 1/8th the dimension to each of them,
- and concatenates the results (followed by a final output projection).
Each separate self-attention above is called a self-attention head. As a whole, this layer is called multi-head attention. Multi-head attention allows each head to focus on a different subspace, with a different semantic or syntactic meaning. Splitting the vector representation into subspaces is related to disentangled representation learning, where we train the model to give selected subspaces a specific meaning.
Most heads don’t attend to the identical sequence position, probably because the residual connection always adds the embedding at each position to that position’s result. Some heads use special tokens to “attend” to nothing.
Adding multiple heads serves more as a computation parallelization trick than as a way to expand the model’s expressive power.
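The multi-head layer can be sketched in NumPy as follows. As an assumption for brevity, the per-head projections are stored as column blocks of one \( d \times d \) matrix each for queries, keys, and values, which is equivalent to separate smaller matrices per head; `W_o` is the output projection applied after concatenation in the original paper. Names and sizes are illustrative.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads=8):
    """Multi-head self-attention: project, split into heads, attend per head, concatenate.

    X: (L, d); W_q, W_k, W_v, W_o: (d, d); d must be divisible by n_heads.
    """
    L, d = X.shape
    assert d % n_heads == 0, "d must be divisible by n_heads"
    d_head = d // n_heads

    def split(M):                                        # (L, d) -> (n_heads, L, d_head)
        return M.reshape(L, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (n_heads, L, L): one attention map per head
    heads = softmax(scores, axis=-1) @ V                 # (n_heads, L, d_head)
    concat = heads.transpose(1, 0, 2).reshape(L, d)      # concatenate the heads back to (L, d)
    return concat @ W_o                                  # final output projection

rng = np.random.default_rng(0)
L, d = 5, 16
X = rng.normal(size=(L, d))
W_q, W_k, W_v, W_o = (rng.normal(size=(d, d)) / 4 for _ in range(4))
Y = multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads=8)
print(Y.shape)  # (5, 16)
```

With `n_heads=1` and an identity `W_o`, this reduces to the single-head scaled dot-product attention; with 8 heads, each head only sees its own 2-dimensional slice of the projections, which is what makes the heads independent and easy to compute in parallel.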
Self-Attention Computational Complexity
- complexity is quadratic in sequence length \( O(L^2) \)
- because we need to calculate the \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \)
- but a large context size is crucial for some tasks, e.g. character-level models
- multiple speedup approaches already exist
- for example, Performer, Expire-Span, and SRU++ are architectures that reduce the Transformer’s computational complexity.
- in Perceiver IO, cross-attention to a smaller latent array is used to reduce dimensionality and thus the complexity.
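The quadratic growth, and how attending from a small set of latents avoids it, can be sketched as below. The latent count `N = 256` and the function names are hypothetical; this is a simplified illustration of the Perceiver idea, not its actual architecture. Doubling \( L \) quadruples the self-attention matrix but only doubles the latent cross-attention matrix.

```python
import numpy as np

def attention_entries(n_queries, n_keys):
    """Size of the attention matrix: one score per (query, key) pair."""
    return n_queries * n_keys

def latent_cross_attention(X, latents, W_q, W_k, W_v):
    """Perceiver-style cross-attention: N latent queries attend to L input positions.

    X: (L, d) long input; latents: (N, d) small array with N << L.
    The score matrix is (N, L), so the cost grows linearly in L rather than quadratically.
    """
    Q, K, V = latents @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[1])               # (N, L)
    scores -= scores.max(axis=1, keepdims=True)
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)
    return A @ V                                         # (N, d): compressed summary of X

N = 256                                                  # hypothetical number of latents
for L in (512, 1024, 2048):
    print(L, attention_entries(L, L), attention_entries(N, L))
# 512 262144 131072
# 1024 1048576 262144
# 2048 4194304 524288
```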
Training a Transformer
Transformers are usually pre-trained with self-supervised tasks like masked language modelling or next-token prediction on large datasets. Pre-trained models are often very general and publicly distributed, e.g. on HuggingFace. Big transformer models are typically pre-trained on multiple GPUs. While there are various approaches to speed up the transformer itself, there are also ways to improve its training. For example, the ELECTRA training scheme speeds up training with a GAN-like setup that uses a loss over the entire sequence.
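The two pre-training objectives differ mainly in how inputs and targets are built from unlabelled text. Below is a minimal sketch on whitespace tokens, assuming a 15% masking rate; it omits BERT’s refinement in which some selected tokens are replaced by random tokens or left unchanged instead of masked. The `[MASK]` string and the function names are illustrative.

```python
import random

MASK = "[MASK]"

def masked_lm_example(tokens, mask_prob=0.15, seed=0):
    """Masked language modelling: hide random tokens; the loss is computed only at hidden positions.

    Targets are None where no prediction is required.
    """
    rng = random.Random(seed)
    inputs, targets = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            inputs.append(MASK)
            targets.append(tok)                          # the model must recover the original token
        else:
            inputs.append(tok)
            targets.append(None)                         # no loss at this position
    return inputs, targets

def next_token_example(tokens):
    """Next-token prediction: the input at position t is paired with the token at t + 1."""
    return tokens[:-1], tokens[1:]

tokens = "the cat sat on the mat".split()
print(masked_lm_example(tokens, mask_prob=0.3))
print(next_token_example(tokens))
# -> (['the', 'cat', 'sat', 'on', 'the'], ['cat', 'sat', 'on', 'the', 'mat'])
```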
Fine-tuning is then used to specialize the model for a specific task using a small labelled dataset. A single GPU is often enough for fine-tuning. For example, models like BART are fine-tuned for summarization tasks. Sometimes we fine-tune twice, as the authors did with a BART model equipped with diminishing self-attention to increase summarization coverage (read my summary).
Beware of the possibility of double descent in test error, contrary to the classical bias-variance trade-off hypothesis (read my summary).
Serving Transformer in Kubernetes Cluster in Cloud
While you can train and run predictions with small transformers on, for example, a Thinkpad P52 with a Quadro P1000 graphics card with 4GB VRAM (see my review), to run bigger models or deploy your models to production you will need a bit of MLOps and DevOps, so read:
- store your trained models e.g. using Quilt Data in S3 (read more here)
- deploy to Kubernetes (read more here on Cortex, BentoML, and Helm)
Example Transformer Models
- Google’s Pathways Language Model outperforms GPT-3, and even humans on some tasks
- Wav2vec uses a Transformer with quantization to predict phonemes
- Diminishing self-attention improves summarization coverage
- DeepMind’s RETRO Transformer uses cross-attention to incorporate sequences retrieved from a database
- Expire-Span uses attention with forgetting
- SRU++ fuses RNN and self-attention
- Performer uses random kernel features to speed up attention
- Lambda Networks introduce a modification of self-attention
- for similarity tasks, you may also consider WM embedding, a lightweight approximation of Word Mover’s Distance
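The random-feature idea behind Performer can be sketched as follows. This sketch assumes positive random features \( \phi(x) = \exp(W x - \|x\|^2/2)/\sqrt{m} \) (applied elementwise) with Gaussian rows in \( W \), whose dot products approximate the softmax kernel in expectation, \( \mathbb{E}[\phi(q)^\intercal \phi(k)] = \exp(q^\intercal k) \). It omits refinements from the Performer paper such as orthogonal features, and all names are hypothetical. Because keys and values are aggregated first, the \( L \times L \) matrix is never formed.

```python
import numpy as np

def positive_random_features(X, W):
    """phi(x) = exp(W x - |x|^2 / 2) / sqrt(m), so that E[phi(q) . phi(k)] = exp(q . k)."""
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X**2, axis=1, keepdims=True)) / np.sqrt(m)

def softmax_attention(Q, K, V):
    """Exact softmax attention; builds the full (L, L) matrix."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(s - s.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V

def random_feature_attention(Q, K, V, n_features=20000, seed=0):
    """Approximate attention in time linear in L; the (L, L) matrix is never formed."""
    d = Q.shape[1]
    W = np.random.default_rng(seed).normal(size=(n_features, d))
    scale = d ** -0.25                                   # split the 1/sqrt(d) between q and k
    phi_q = positive_random_features(Q * scale, W)       # (L, m)
    phi_k = positive_random_features(K * scale, W)       # (L, m)
    kv = phi_k.T @ V                                     # (m, d_v): keys and values aggregated first
    normalizer = phi_q @ phi_k.sum(axis=0)               # (L,): approximate softmax denominators
    return (phi_q @ kv) / normalizer[:, None]

rng = np.random.default_rng(42)
L, d = 6, 4
Q, K, V = (0.5 * rng.normal(size=(L, d)) for _ in range(3))
err = np.max(np.abs(random_feature_attention(Q, K, V) - softmax_attention(Q, K, V)))
print(f"max abs error: {err:.4f}")
```

The approximation error depends on the number of random features and grows with the norms of the queries and keys, which is why small-norm inputs are used here.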
Transformer vs Word2vec Continuous Bag-of-Words
Word2vec was used in many state-of-the-art models between 2013 and 2015. It was gradually replaced by more advanced variants like FastText and StarSpace (general-purpose embeddings), and by more sophisticated models like LSTMs and Transformers. Word2vec Continuous Bag-of-Words (CBOW) predicts a word from its surrounding 10-word context by:
- summing the input embeddings corresponding to the context words
- finding the output embedding with the maximum dot-product with this sum
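The two CBOW steps can be sketched with NumPy on toy embeddings. This is a simplified illustration under stated assumptions: the embeddings are hand-made rather than trained, and the softmax, negative sampling, or hierarchical softmax used during real Word2vec training is left out, since only the argmax matters for prediction.

```python
import numpy as np

def cbow_predict(context_ids, E_in, E_out):
    """Word2vec CBOW prediction step (training omitted).

    context_ids: vocabulary ids of the surrounding words
    E_in: (V, d) input embeddings; E_out: (V, d) output (context) embeddings
    """
    h = E_in[context_ids].sum(axis=0)          # sum of the context words' input embeddings
    scores = E_out @ h                         # dot-product with every output embedding
    return int(np.argmax(scores))              # id of the best-matching word

# Toy hand-made embeddings: 5-word vocabulary, 5 dimensions
E_in = np.eye(5)
E_out = 0.5 * np.eye(5)
E_out[3] = [1.0, 1.0, 0.0, 0.0, 0.0]           # word 3 matches contexts containing words 0 and 1
print(cbow_predict([0, 1], E_in, E_out))       # 3
```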
Note that the average sentence length is about 15 words. Word2vec uses two sets of embeddings: input and output (context) embeddings. The Word2vec CBOW (w2v CBOW) model is similar to an extremely simplified single-layer transformer.
If we use:
- a single self-attention layer (removing the feed-forward layer and layer normalization)
- a single attention head
- Fourier positional encodings \( p_j \)
- that behave as if concatenated: for all embeddings and positional encodings \( p_j^\intercal e_w \approx 0 \)
- that decay fast after a relative distance of 4: \( p_j^\intercal p_i \approx \delta_{|i-j| \leq 4} \)
- identity key and query linear transformations \( W^K = W^Q = I \)
- a masked word vector that has approximately the same dot-product with all embeddings: \( e_{mask}^\intercal e_w \approx C \)
Then, for a word \( w \) at position \( i \) and the masked word at position \( j \), we approximately get: \( (e_w + p_i)^\intercal (e_{mask} + p_j) \approx e_w^\intercal e_{mask} + p_i^\intercal p_j \approx C + \delta_{|i-j| \leq 4} \)
And if we define the output embeddings via the value projection matrix multiplied with the embeddings, \( E W^V \), where the rows of \( E \) are the word embeddings,
Then the Transformer output for a masked word is close to the sum of the surrounding word vectors, as in Word2vec CBOW. However, the positional embeddings probably do not behave as if concatenated, so the term above would also contain relative and absolute positional terms that are not present in Word2vec. The Transformer result would therefore still be more expressive.
But if we additionally dropped the positional encodings and used a sliding context window matching Word2vec’s context size, the results would be even closer to Word2vec.
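The approximation can be checked at the level of attention scores rather than actual embeddings. This sketch assumes the derivation’s score \( C + \delta_{|i-j| \leq 4} \) holds exactly, scales the window term by a hypothetical sharpness factor `beta` (with `beta = 1`, the far words keep noticeable weight), and assumes the masked word’s own value vector is zero. Under these assumptions, the attention output equals the CBOW sum of the eight neighbours divided by 9, i.e. it points in the same direction as CBOW.

```python
import numpy as np

def window_attention(values, mask_pos, C=0.0, beta=20.0, window=4):
    """Softmax attention of the masked position with scores C + beta * delta(|i - j| <= window).

    beta is a hypothetical sharpness factor; beta = 1 corresponds to the derivation literally.
    """
    pos = np.arange(len(values))
    scores = C + beta * (np.abs(pos - mask_pos) <= window)
    weights = np.exp(scores - scores.max())              # the constant C cancels in the softmax
    weights /= weights.sum()
    return weights @ values

def cbow_sum(values, mask_pos, window=4):
    """CBOW-style sum of the surrounding word vectors, excluding the masked word itself."""
    lo, hi = max(0, mask_pos - window), min(len(values), mask_pos + window + 1)
    return sum(values[j] for j in range(lo, hi) if j != mask_pos)

rng = np.random.default_rng(0)
values = rng.normal(size=(15, 8))                        # value vectors E W^V for a 15-word sentence
mask_pos = 7
values[mask_pos] = 0.0                                   # assumption: the masked word contributes no value
out = window_attention(values, mask_pos)
# The window holds 9 positions (8 neighbours plus the mask), so attention averages over 9
print(np.allclose(out, cbow_sum(values, mask_pos) / 9, atol=1e-6))  # True
```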
Transformer vs FastText
The Transformer architecture cannot really be compared well to FastText in anything other than performance. That is because, apart from whole words, FastText also trains on sub-words (character n-grams), while a Transformer trains only on the tokens produced by its tokenizer.
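The sub-word features that FastText adds can be sketched as below, following the boundary-symbol scheme from the FastText paper with n-grams of length 3 to 6. This is a simplified illustration: real FastText also hashes the n-grams into a fixed number of buckets, which is omitted here, and the function name is hypothetical. A Transformer’s tokenizer would instead map a word to one fixed sequence of tokens.

```python
def fasttext_subwords(word, min_n=3, max_n=6):
    """Character n-grams in the style of FastText.

    The word is wrapped in '<' and '>' boundary symbols, all n-grams with
    min_n <= n <= max_n are extracted, and the wrapped word itself is added
    as a separate feature (FastText gives the whole word its own vector).
    """
    wrapped = f"<{word}>"
    grams = [wrapped[i:i + n]
             for n in range(min_n, max_n + 1)
             for i in range(len(wrapped) - n + 1)]
    return grams + [wrapped]

print(fasttext_subwords("where", min_n=3, max_n=3))
# ['<wh', 'whe', 'her', 'ere', 're>', '<where>']
```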