If you need a recap of the Transformer, read: Transformer’s Self-Attention Mechanism Simplified.
Self-Attention Complexity
- complexity is quadratic in sequence length \( O(L^2) \)
- because we need to calculate an \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \) (see the sketch after this list)
- but context size is crucial for some tasks, e.g. character-level models
- multiple approaches already exist
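A minimal sketch of vanilla single-head self-attention makes the quadratic cost concrete: the `scores` tensor is the full \( L \times L \) attention matrix, so compute and memory grow as \( O(L^2) \). All names and sizes here are illustrative, not from any particular library.

```python
import torch

def full_self_attention(q, k, v):
    """q, k, v: (L, d) tensors for a single attention head."""
    d = q.size(-1)
    scores = q @ k.transpose(0, 1) / d ** 0.5  # (L, L) attention matrix -> O(L^2)
    weights = torch.softmax(scores, dim=-1)    # row-wise softmax
    return weights @ v                         # (L, d) output

L, d = 1024, 64
q, k, v = (torch.randn(L, d) for _ in range(3))
out = full_self_attention(q, k, v)  # the scores matrix alone holds L * L floats
```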
Previous Approaches
- approximate softmax e.g. Performer
- sparsify attention e.g. BigBird
- sliding span (window attention) e.g. Multi-passage BERT
- a combination of the above plus global attention, e.g. Longformer
Auto-Regressive Transformers
- predict the next token based on the previous ones
- uni-directional (backward-only) attention works best on LM tasks (see the masking sketch below)
- all models below are auto-regressive
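As a sketch of what backward-only attention means in code, the causal mask below blocks attention from position \( t \) to any later position; this is the standard trick, though the exact implementation varies by model.

```python
import torch

def causal_attention(q, k, v):
    """q, k, v: (L, d); each position may attend only to itself and the past."""
    L, d = q.shape
    scores = q @ k.transpose(0, 1) / d ** 0.5
    future = torch.triu(torch.ones(L, L), diagonal=1).bool()  # strictly-upper triangle = future
    scores = scores.masked_fill(future, float("-inf"))        # block future positions
    return torch.softmax(scores, dim=-1) @ v
```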
Transformer-XL
- Transformer-XL (Extra Long): Attentive Language Models Beyond a Fixed-Length Context
- the first self-attention model to beat RNNs on both character- and word-level LM
- auto-regressive: attention is backward-only, not bi-directional like the BERT transformer model
- instead of recalculating embeddings for each fixed-length span
- it memorizes (caches) the previous results
- because those previous results saw context that is no longer available in the next step
- this effectively increases the context (see the memory sketch below)
- positional embeddings must be relative
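The sketch below illustrates the caching idea, assuming a single head and omitting the relative positional terms and causal masking for brevity; `w_q`, `w_k`, `w_v` are hypothetical projection matrices, not the paper’s exact implementation.

```python
import torch

def xl_attention(h_curr, memory, w_q, w_k, w_v):
    """h_curr: (S, d) current segment; memory: (M, d) cached states of the previous segment."""
    h_ext = torch.cat([memory.detach(), h_curr], dim=0)  # (M + S, d): no gradient into the cache
    q = h_curr @ w_q                                     # queries come from the current segment only
    k, v = h_ext @ w_k, h_ext @ w_v                      # keys/values also see the cached context
    scores = q @ k.transpose(0, 1) / q.size(-1) ** 0.5   # (S, M + S): attention reaches past the segment
    return torch.softmax(scores, dim=-1) @ v
```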
Compressive Transformer
- Compressive Transformers for Long-Range Sequence Modelling
- modifies the Transformer-XL memory with an additional compression function
- which maps several past embeddings into one
- the compressed embeddings are appended to the context
- less flexible due to the fixed compression window size (see the compression sketch below)
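A minimal sketch of the compression step, assuming mean-pooling as the compression function (the paper also considers convolutional and attention-based variants); the rate value is illustrative.

```python
import torch

def compress_memories(old_mem, rate=3):
    """Map every `rate` oldest hidden states (M, d) into one compressed state."""
    M, d = old_mem.shape
    usable = (M // rate) * rate                               # drop the remainder for simplicity
    return old_mem[:usable].reshape(-1, rate, d).mean(dim=1)  # (M // rate, d)
```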
Adaptive Span
- learns to increase context length when needed
- similar to Expire-Span
- except it predicts the span length instead of forgetting memories (see the mask sketch below)
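For comparison, a sketch of the Adaptive Span soft mask: each head learns a span \( z \), and attention to positions farther away than \( z \) is ramped down over a window of length \( R \) so the span stays differentiable. Values are illustrative.

```python
import torch

def adaptive_span_mask(distances, z, R=32.0):
    """distances: (L,) distance of each key position from the query; z: learned span."""
    return torch.clamp((R + z - distances) / R, min=0.0, max=1.0)  # 1 inside the span, 0 far outside
```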
Expire-Span Attention
- Facebook AI Paper: Not All Memories are Created Equal: Learning to Forget by Expiring
- Source
- uses memory à la Transformer-XL
- has default minimal span of size \( K \)
- \( L \) is maximum span
- for each input (memory) \( h_i \) to each layer, compute once a scalar \( e_i \in [0, L] \) (see the sketch after this list)
- \( e_i \) is called expire-span (expiration time span)
- \( e_i = L\,\sigma(w^\intercal h_i + b) \)
- the model slides over the text in time steps
- \( t \) denotes the current time step
- if \( r_{ti} := e_i - (t-i) < 0 \), then memory input \( h_i \) is forgotten
- for differentiability, linearly phase out the attention output:
- \( m_{ti} := \max(0, \min(1, 1 + \frac{r_{ti}}{R})) \)
- \( a^\prime_{ti} := \frac{ m_{ti} a_{ti} }{ \sum_j m_{tj} a_{tj} } \)
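Putting the formulas above together, here is a hedged sketch of one Expire-Span attention step at time \( t \); `w`, `b`, and the shapes are illustrative, and a small epsilon is added to the renormalization for numerical safety.

```python
import torch

def expire_span_attention(h_mem, idx, t, q, k, v, w, b, L=1024.0, R=128.0):
    """h_mem: (M, d) memory states; idx: (M,) their time indices; q: (1, d) current query."""
    e = L * torch.sigmoid(h_mem @ w + b)          # e_i = L * sigmoid(w^T h_i + b)
    r = e - (t - idx)                             # r_ti = e_i - (t - i)
    m = torch.clamp(1 + r / R, min=0.0, max=1.0)  # m_ti: linear phase-out over R steps
    a = torch.softmax(q @ k.transpose(0, 1) / q.size(-1) ** 0.5, dim=-1)  # a_ti
    a = a * m                                     # fully expired memories contribute nothing
    return (a / (a.sum(dim=-1, keepdim=True) + 1e-8)) @ v  # renormalize: a'_ti
```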
Expire-Span Loss
- penalize higher memory usage with an auxiliary term (sketched below)
- \( \alpha > 0 \) is a compression parameter
- \( L_{total} = L_{task} + \alpha \sum_{i} e_i / T \), where the sum runs over all memories and \( T \) is the sequence length
- randomly shorten memory for regularization
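A one-line sketch of the total loss, assuming `expire_spans` holds the predicted \( e_i \) of all memories and `T` is the sequence length.

```python
def expire_span_loss(task_loss, expire_spans, alpha, T):
    """L_total = L_task + alpha * sum_i e_i / T; larger alpha compresses memory harder."""
    return task_loss + alpha * expire_spans.sum() / T
```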
Results on Enwik8
- better performance, less memory, faster
- the LM metric, bits per byte (or per character), is the average negative log base-2 probability of the target label (see the sketch below)
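For concreteness, a small sketch of the metric: bits per byte is just the mean of \( -\log_2 p \) over the probabilities the model assigns to the true next bytes.

```python
import math

def bits_per_byte(target_probs):
    """target_probs: probabilities assigned to the true next byte/character at each step."""
    return -sum(math.log2(p) for p in target_probs) / len(target_probs)

bits_per_byte([0.5, 0.25])  # = 1.5 bits per byte
```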