
Expire-Span: Scaling Transformer by Forgetting

Reducing computational costs by differentiably dropping memorized embeddings from self-attention context.

If you need a recap of the transformer, read: Transformer’s Self-Attention Mechanism Simplified.

Self-Attention Complexity

  • complexity is quadratic in sequence length \( O(L^2) \)
  • because we need to calculate an \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \) (see the sketch below the figure)
  • but context size is crucial for some tasks, e.g. character-level models
  • multiple approaches to reducing this cost already exist
Attention Complexity
Attention Complexity (source).
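To make the quadratic cost concrete, here is a minimal PyTorch sketch of the plain scaled dot-product attention from the formula above; the \( L \times L \) score matrix is what makes memory and compute grow as \( O(L^2) \). Shapes and names are illustrative, not taken from any particular implementation.

```python
import math
import torch

def self_attention(q, k, v):
    # q, k, v: (L, d). The score matrix is (L, L), so memory and
    # compute grow quadratically with the sequence length L.
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)   # (L, L)
    weights = torch.softmax(scores, dim=-1)
    return weights @ v                                # (L, d)

L, d = 1024, 64
q, k, v = (torch.randn(L, d) for _ in range(3))
out = self_attention(q, k, v)  # the (L, L) attention matrix alone holds ~1M floats
```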

Previous Approaches

Longformer self-attention patterns comparison
Longformer self-attention patterns comparison (source).

Auto-Regressive Transformers

  • predict the next token based on the previous ones
  • one-directional (causal) attention works best on language modeling tasks (see the masking sketch below)
  • all models below are auto-regressive
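A minimal sketch of the one-directional (causal) masking mentioned above: position \( t \) may only attend to positions up to \( t \), enforced by masking the upper triangle of the score matrix. Names are illustrative.

```python
import math
import torch

def causal_self_attention(q, k, v):
    # Auto-regressive attention: token t attends only to tokens 0..t.
    L, d = q.shape
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)       # (L, L)
    future = torch.ones(L, L).triu(diagonal=1).bool()     # strictly upper triangle
    scores = scores.masked_fill(future, float("-inf"))    # hide the future
    return torch.softmax(scores, dim=-1) @ v
```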

Transformer-XL

Transformer-XL uses memory of previously calculated results to increase span
Transformer-XL uses memory of previously calculated results to increase span (source)
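Conceptually, Transformer-XL extends the attention span by concatenating cached hidden states from previous segments to the current keys and values, without back-propagating into them. A rough sketch under those assumptions (relative positional encodings and causal masking omitted for brevity; all names illustrative):

```python
import math
import torch

def attention_with_memory(h, memory, w_q, w_k, w_v):
    # h: current segment hidden states (L, d); memory: cached states
    # from previous segments (M, d), kept without gradients.
    h_ext = torch.cat([memory.detach(), h], dim=0)               # (M + L, d)
    q = h @ w_q                                                  # queries only from the current segment
    k, v = h_ext @ w_k, h_ext @ w_v                              # keys/values also cover the memory
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])    # (L, M + L)
    out = torch.softmax(scores, dim=-1) @ v
    new_memory = h_ext[-memory.shape[0]:]                        # cache the most recent M states
    return out, new_memory
```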

Compressive Transformer

Compressive transformer
Compressive transformer (source).

Adaptive Span

  • learns to increase the context length when needed
  • similar to Expire-Span
  • except it predicts a span length instead of forgetting individual memories (see the sketch below)
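As a rough illustration of the difference, Adaptive Span learns a span \( z \) per attention head and applies a soft mask over the query-to-key distance. This is a sketch of my reading of that masking function; the ramp width `R` and the names are chosen for illustration only.

```python
import torch

def adaptive_span_mask(z, distances, R=32):
    # z: learned span (scalar per head); distances: query-to-key distances t - i.
    # Soft mask ramps linearly from 1 (inside the span) down to 0 (beyond z + R).
    return torch.clamp((R + z - distances) / R, min=0.0, max=1.0)
```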

Expire-Span Attention

  • Facebook AI Paper: Not All Memories are Created Equal: Learning to Forget by Expiring
  • Source
  • uses a memory of previously computed hidden states, à la Transformer-XL
  • has default minimal span of size \( K \)
  • \( L \) is maximum span
  • for each memory input \( h_i \) at each layer, compute once a scalar \( e_i \in [0, L] \)
  • \( e_i \) is called expire-span (expiration time span)
  • \( e_i = L \mathbf{\sigma}(w^\intercal h_i + b) \)
  • The model slides over the text with time steps.
  • \( t \) denotes time-step.
  • if \( r_{ti} := e_i - (t-i) < 0 \) then memory input is forgotten
  • For differentiability, linearly phase out and renormalize the attention weights (see the sketch below the figure):
    • \( m_{ti} := \max(0, \min(1, 1 + \frac{r_{ti}}{R})) \)
    • \( a^\prime_{ti} := \frac{ m_{ti} a_{ti} }{ \sum_j m_{tj} a_{tj} } \)
Expire-span attention: For every sequence input h_i it calculates expiration time span e_i.
Expire-span attention: For every sequence input h_i it calculates expiration time span e_i. (source)
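A minimal sketch stringing the formulas above together for a single query at time step \( t \): predict \( e_i \) from each memory \( h_i \) with the learned parameters \( w, b \), compute the remaining span \( r_{ti} \), the soft mask \( m_{ti} \), and renormalize the attention weights. Shapes, default values, and names are illustrative, not the paper's implementation.

```python
import math
import torch

def expire_span_attention(h, q, t, positions, w, b, L_max=1024, R=128):
    # h: memory hidden states (N, d) at positions i; q: query at time step t, shape (d,).
    e = L_max * torch.sigmoid(h @ w + b)          # e_i = L * sigmoid(w . h_i + b), in [0, L_max]
    r = e - (t - positions)                       # r_ti = e_i - (t - i)
    m = torch.clamp(1 + r / R, min=0.0, max=1.0)  # m_ti = max(0, min(1, 1 + r_ti / R))
    scores = h @ q / math.sqrt(h.shape[-1])       # (N,) raw attention logits
    a = torch.softmax(scores, dim=-1)             # a_ti
    a_masked = m * a
    a_prime = a_masked / a_masked.sum()           # a'_ti, renormalized
    keep = m > 0                                  # memories with m_ti = 0 (r_ti <= -R) can be dropped
    return a_prime @ h, keep
```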

Expire-Span Loss

  • Penalize higher memory usage with auxiliary term
  • \( \alpha > 0 \) is a compression parameter
  • \( L_{total} = L_{task} + \alpha \sum_{i \in \lbrace 1, 2, …, L \rbrace} e_i / T \)
  • randomly shorten memory for regularization
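A sketch of the total loss from the formula above, assuming `expire_spans` gathers the predicted \( e_i \) values of the current block; the `alpha` and `T` values here are placeholders.

```python
import torch

def expire_span_loss(task_loss, expire_spans, alpha=1e-6, T=512):
    # Penalize large expire-spans so the model keeps only useful memories:
    # L_total = L_task + alpha * sum_i e_i / T
    return task_loss + alpha * expire_spans.sum() / T
```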

Results on Enwik8

  • better performance, less memory, faster
  • LM metric: bits per byte (or per character) = average negative log base-2 probability of the target byte (see the conversion sketch below the figures)
Expire-span, Trans-XL, Adapt-Span, Compressive transformer bpb, memory, and speed on Enwik8
Expire-span, Trans-XL, Adapt-Span, Compressive transformer bpb, memory, and speed on Enwik8 (source)
Expire-span, Trans-XL, Adapt-Span, Compressive transformer performance in bits-per-byte (bpb) vs memory size on Enwik8
Expire-span, Trans-XL, Adapt-Span, Compressive transformer performance in bits-per-byte (bpb) vs memory size on Enwik8 (source)
Expire-span, Trans-XL, Adapt-Span, Compressive transformer parameters count and bits-per-byte on Enwik8
Expire-span, Trans-XL, Adapt-Span, Compressive transformer parameters count and bpb on Enwik8 (source)
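For reference, bits-per-byte is just the cross-entropy of the target byte expressed in base 2; a minimal conversion sketch (names illustrative):

```python
import math
import torch.nn.functional as F

def bits_per_byte(logits, targets):
    # logits: (N, 256) scores over byte values; targets: (N,) true bytes.
    nll_nats = F.cross_entropy(logits, targets)  # mean negative log-likelihood in nats
    return nll_nats / math.log(2)                # convert nats to bits per byte
```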

Created on 24 Aug 2021.
