
Expire-Span: Scaling Transformer by Forgetting

Reducing computational costs by differentiably dropping memorized embeddings from self-attention context.

Self-Attention Simplified Recap

  • input \( X \in \mathbf{R}^{L \times d} \) is a sequence of embeddings of dimension \( d \) of length \( L \)
  • output \( Y \in \mathbf{R}^{L \times d} \) has the same shape as input
  • project \( X \) into 3 matrices of the same shape
    • query \( X^Q := X W^Q \)
    • key \( X^K := X W^K \)
    • value \( X^V := X W^V \)
  • calculate “soft sequence-wise nearest neighbor search”
    • “search” all \( L \times L \) combinations of sequence elements of \( X^K \) and \( X^Q \)
    • for each sequence position \( m \): weight the value \( X^V_o \) more when the key \( X^K_o \) is more similar to the query \( X^Q_m \)
    • in pseudo-code: \( Y = \mathrm{matmul}_L(\mathrm{softmax}_L(\mathrm{matmul}_d(X^Q, (X^K)^\intercal)), X^V) \)
    • in equation: \( Y = \mathbf{softmax}\left(\frac{QK^\intercal}{\sqrt{d}}\right)V \)
  • more details in the Attention Is All You Need paper; a minimal sketch follows the figure below
Scaled Dot-Product Attention
Scaled Dot-Product Attention (source).
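
A minimal NumPy sketch of the recap above: single-head scaled dot-product self-attention with toy shapes. The function and variable names are illustrative, not from any particular library.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over X of shape (L, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v       # project the input into queries, keys, values
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # (L, L) similarity of every query with every key
    weights = softmax(scores, axis=-1)        # soft "nearest neighbor search" over the sequence
    return weights @ V                        # (L, d) output, same shape as the input

# toy example
L, d = 8, 16
rng = np.random.default_rng(0)
X = rng.normal(size=(L, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
Y = self_attention(X, W_q, W_k, W_v)
print(Y.shape)  # (8, 16)
```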

Self-Attention Complexity

  • complexity is quadratic in sequence length, \( O(L^2) \)
  • because we need to calculate the \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \) (see the back-of-the-envelope sketch after the figure)
  • but context size is crucial for some tasks, e.g. character-level models
  • multiple approaches to reduce this cost already exist
Attention Complexity
Attention Complexity (source).
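
A back-of-the-envelope sketch of the quadratic cost: storing one \( L \times L \) attention matrix in float32. The 4-bytes-per-entry assumption and the example lengths are my own.

```python
# memory needed just to store one L x L attention matrix in float32 (4 bytes per entry)
def attention_matrix_megabytes(L, bytes_per_entry=4):
    return L * L * bytes_per_entry / 1e6

for L in (1_024, 2_048, 4_096):
    print(L, round(attention_matrix_megabytes(L), 1), "MB")
# 1024 ->  ~4.2 MB
# 2048 -> ~16.8 MB
# 4096 -> ~67.1 MB: doubling L quadruples the cost
```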

Previous Approaches

Longformer self-attention patterns comparison
Longformer self-attention patterns comparison (source).

Auto-Regressive Transformers

  • predict the next token based on the previous ones
  • uni-directional (causal) attention works best on LM tasks
  • all models below are auto-regressive

Transformer-XL

  • Transformer-XL (Extra Long): Attentive Language Models Beyond a Fixed-Length Context
  • first self-attention model better than RNN on both char & word level LM
  • auto-regressive: attention is backward-only, not bi-directional like in BERT
  • instead of recalculating embeddings for each fixed-length span
  • it memorizes (caches) the previously calculated results
  • because those previous results saw context that is not available in the next step
  • this effectively increases the context length (see the caching sketch after the figure)
  • positional embeddings must be relative
Transformer-XL uses memory of previously calculated results to increase span
Transformer-XL uses memory of previously calculated results to increase span (source)
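
A minimal sketch of the caching idea under simplifying assumptions (single layer, single head, relative positional embeddings and gradient stopping omitted); all names and shapes are illustrative, not taken from the Transformer-XL code.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def segment_attention_with_memory(h_current, h_memory, W_q, W_k, W_v):
    """Attend from the current segment over [cached memory; current segment]."""
    h_all = np.concatenate([h_memory, h_current], axis=0)  # cached states are reused, not recomputed
    Q = h_current @ W_q                                    # queries only for the new segment
    K, V = h_all @ W_k, h_all @ W_v                        # keys/values also cover the memory
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores, axis=-1) @ V

# slide over segments, carrying the previous segment as memory
d, seg_len = 16, 8
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
memory = np.zeros((0, d))                                  # empty memory at the start
for _ in range(3):
    segment = rng.normal(size=(seg_len, d))
    out = segment_attention_with_memory(segment, memory, W_q, W_k, W_v)
    memory = segment                                       # cache this segment for the next step
print(out.shape)  # (8, 16)
```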

Compressive Transformer

  • instead of discarding old memories à la Transformer-XL, compresses them into a coarser compressed memory

Compressive transformer
Compressive transformer (source).

Adaptive Span

  • learns to increase context length when needed
  • similar to Expire-Span
  • except it predicts the attention span length instead of forgetting individual memories

Expire-Span Attention

  • Facebook AI Paper: Not All Memories are Created Equal: Learning to Forget by Expiring
  • Source
  • uses memory à la Transformer-XL
  • has default minimal span of size \( K \)
  • \( L \) is maximum span
  • for each input (memory) \( h_i \) of each layer, compute once a scalar \( e_i \in [0, L] \)
  • \( e_i \) is called expire-span (expiration time span)
  • \( e_i = L \mathbf{\sigma}(w^\intercal h_i + b) \)
  • The model slides over the text with time steps.
  • \( t \) denotes time-step.
  • if \( r_{ti} := e_i - (t-i) < 0 \), then the memory \( h_i \) is forgotten
  • For differentiability, the attention weights are linearly phased out over a ramp of length \( R \) (a hyperparameter); see the sketch after the figure:
    • \( m_{ti} := \max(0, \min(1, 1 + \frac{r_{ti}}{R})) \)
    • \( a^\prime_{ti} := \frac{ m_{ti} a_{ti} }{ \sum_j m_{tj} a_{tj} } \)
Expire-span attention: For every sequence input h_i it calculates expiration time span e_i.
Expire-span attention: For every sequence input h_i it calculates expiration time span e_i. (source)
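
A minimal NumPy sketch of the expiration masking above for a single attention row at time step \( t \); the variable names follow the equations, while the shapes and toy inputs are my own illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def expire_span_weights(a_t, h, w, b, t, L_max, R):
    """Re-weight one attention row a_t (over memories i = 0..t) by their expire-spans."""
    i = np.arange(len(a_t))
    e = L_max * sigmoid(h @ w + b)        # expire-span e_i in [0, L_max], one scalar per memory h_i
    r = e - (t - i)                       # remaining time r_ti; negative means "expired"
    m = np.clip(1.0 + r / R, 0.0, 1.0)    # linear ramp of length R keeps the mask differentiable
    masked = m * a_t
    return masked / masked.sum()          # renormalize so the weights still sum to 1

# toy usage
rng = np.random.default_rng(0)
t, d = 12, 16
h = rng.normal(size=(t + 1, d))           # memories h_0 .. h_t
a_t = np.full(t + 1, 1.0 / (t + 1))       # some attention row over those memories
w, b = rng.normal(size=d), 0.0
a_prime = expire_span_weights(a_t, h, w, b, t, L_max=1024, R=16)
print(a_prime.round(3))
```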

Expire-Span Loss

  • penalize higher memory usage with an auxiliary loss term (sketched below)
  • \( \alpha > 0 \) is a compression parameter
  • \( L_{total} = L_{task} + \alpha \sum_{i \in \lbrace 1, 2, …, L \rbrace} e_i / T \)
  • randomly shorten memory for regularization
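
A one-line sketch of the total loss, assuming task_loss is the usual LM loss and T is the normalization constant from the formula above (all names here are illustrative).

```python
def expire_span_total_loss(task_loss, expire_spans, alpha, T):
    """L_total = L_task + alpha * sum_i(e_i) / T: longer spans cost more, so the model learns to forget."""
    return task_loss + alpha * sum(expire_spans) / T

# alpha trades off task performance against how much memory is kept alive
print(expire_span_total_loss(task_loss=1.23, expire_spans=[512.0, 3.0, 40.0], alpha=1e-6, T=1024))
```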

Results on Enwik8

  • better performance, less memory, faster
  • LM metric bits per byte (or character): the average negative log base-2 probability of the target byte (computed in the sketch below)
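
A tiny sketch of the bits-per-byte metric as defined above; the probabilities are made up just to show the arithmetic.

```python
import math

def bits_per_byte(target_probabilities):
    """Average negative log base-2 probability assigned to each target byte/character."""
    return -sum(math.log2(p) for p in target_probabilities) / len(target_probabilities)

# a model assigning probability 0.5 to every target byte scores exactly 1.0 bpb
print(bits_per_byte([0.5, 0.5, 0.5, 0.5]))  # 1.0
```
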
Expire-span, Trans-XL, Adapt-Span, Compressive transformer bpb, memory, speed on Enwik8
Expire-span, Trans-XL, Adapt-Span, Compressive transformer bpb, memory, speed on Enwik8 (source)
Expire-span, Trans-XL, Adapt-Span, Compressive transformer performance in bits-per-byte (bpb) vs memory size on Enwik8
Expire-span, Trans-XL, Adapt-Span, Compressive transformer performance in bits-per-byte (bpb) vs memory size on Enwik8 (source)
Expire-span, Trans-XL, Adapt-Span, Compressive transformer parameters count and bpb on Enwik8
Expire-span, Trans-XL, Adapt-Span, Compressive transformer parameters count and bpb on Enwik8 (source)

24 Aug 2021