Expire-Span: Scaling Transformer by Forgetting

Reducing computational costs by differentiably dropping memorized embeddings from self-attention context.

If you need a recap of the transformer, read: Transformer’s Self-Attention Mechanism Simplified.

Self-Attention Complexity

  • complexity is quadratic in the sequence length: \( O(L^2) \)
  • because we need to calculate an \( L \times L \) attention matrix \( \mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}}) \) (see the sketch after the figure below)
  • but context size is crucial for some tasks, e.g. character-level models
  • multiple approaches already exist
Attention Complexity (source).
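
As a minimal sketch (not the paper’s code), the quadratic cost comes from materializing the \( L \times L \) score matrix before the softmax; the names and sizes below are illustrative:

```python
import torch

# Single-head self-attention sketch showing where the O(L^2) cost comes from.
# L is the sequence length, d the head dimension (illustrative values).
L, d = 1024, 64
Q, K, V = (torch.randn(L, d) for _ in range(3))

scores = Q @ K.T / d ** 0.5            # (L, L) matrix: quadratic in L
attn = torch.softmax(scores, dim=-1)   # row-wise softmax over all L keys
out = attn @ V                         # (L, d) attended values
print(scores.shape)                    # torch.Size([1024, 1024])
```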

Previous Approaches

Longformer self-attention patterns comparison (source).

Auto-Regressive Transformers

  • predict the next token based on the previous ones
  • one-directional (causal) attention works best on LM tasks (see the mask sketch below)
  • all models below are auto-regressive
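
A small illustrative sketch of the one-directional (causal) mask used by auto-regressive models; shapes and values are placeholders:

```python
import torch

# Causal (one-directional) mask: position t may only attend to positions <= t,
# so the model predicts the next token from the previous ones only.
L = 5
scores = torch.randn(L, L)                           # raw attention scores
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))  # block future positions
attn = torch.softmax(scores, dim=-1)                 # each row sums to 1 over the past
```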

Transformer-XL

Transformer-XL uses memory of previously calculated results to increase span (source).

Compressive Transformer

Compressive transformer (source).

Adaptive Span

  • learns to increase context length when needed
  • similar to Expire-Span
  • except it predicts the attention span length instead of forgetting individual memories

Expire-Span Attention

  • Facebook AI Paper: Not All Memories are Created Equal: Learning to Forget by Expiring
  • Source
  • uses memory à la Transformer-XL
  • has a default minimal span of size \( K \)
  • \( L \) is the maximum span
  • for each input (memory) \( h_i \) into each layer, compute once a scalar \( e_i \in [0, L] \)
  • \( e_i \) is called expire-span (expiration time span)
  • \( e_i = L \mathbf{\sigma}(w^\intercal h_i + b) \)
  • The model slides over the text in time steps.
  • \( t \) denotes the time step.
  • if \( r_{ti} := e_i - (t - i) < 0 \), then the memory \( h_i \) is forgotten
  • For differentiability, linearly phase out the attention output over a ramp of length \( R \) (see the sketch after the figure below):
    • \( m_{ti} := \max(0, \min(1, 1 + \frac{r_{ti}}{R})) \)
    • \( a^\prime_{ti} := \frac{ m_{ti} a_{ti} }{ \sum_j m_{tj} a_{tj} } \)
Expire-span attention: for every sequence input \( h_i \) it calculates an expiration time span \( e_i \) (source).
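
A rough sketch of the masking above, assuming one memory slot per past position; tensor shapes, names, and hyperparameter values are illustrative, not the authors' implementation:

```python
import torch

# Expire-span soft masking sketch: e_i = L * sigmoid(w^T h_i + b),
# r_ti = e_i - (t - i), m_ti = clamp(1 + r_ti / R, 0, 1), then renormalize.
L_max, R = 1024.0, 128.0            # maximum span L and ramp length R (illustrative)
T, d = 16, 64                       # number of memories / time steps, hidden size
h = torch.randn(T, d)               # memory hidden states h_i
w, b = torch.randn(d), torch.zeros(1)

e = L_max * torch.sigmoid(h @ w + b)                     # expire-span e_i in [0, L_max]

t = torch.arange(T, dtype=torch.float).unsqueeze(1)      # query time step t (rows)
i = torch.arange(T, dtype=torch.float).unsqueeze(0)      # memory position i (columns)
r = e.unsqueeze(0) - (t - i)                             # r_ti = e_i - (t - i)
m = torch.clamp(1 + r / R, min=0.0, max=1.0)             # linear phase-out; 0 = forgotten

a = torch.softmax(torch.randn(T, T), dim=-1)             # placeholder attention a_ti
a_prime = (m * a) / (m * a).sum(dim=-1, keepdim=True)    # renormalized a'_ti
```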

Expire-Span Loss

  • Penalize higher memory usage with an auxiliary term (see the sketch after this list)
  • \( \alpha > 0 \) is a compression parameter
  • \( L_{total} = L_{task} + \alpha \sum_{i \in \lbrace 1, 2, …, L \rbrace} e_i / T \)
  • randomly shorten the memory for regularization
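
A minimal sketch of the total loss with the auxiliary span penalty; the numbers below (alpha, T, the task loss) are placeholders, not the paper's hyperparameters:

```python
import torch

# Auxiliary span penalty: L_total = L_task + alpha * sum_i e_i / T.
alpha = 1e-6                              # compression parameter alpha > 0 (illustrative)
T = 16                                    # sequence / block length (illustrative)
task_loss = torch.tensor(1.23)            # e.g. cross-entropy of the LM task (placeholder)
expire_spans = torch.rand(T) * 1024.0     # e_i values predicted by the model (placeholder)

aux_loss = alpha * expire_spans.sum() / T
total_loss = task_loss + aux_loss         # penalizes keeping memories for long spans
```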

Results on Enwik8

  • better performance, less memory, faster
  • LM metric: bits per byte (or character) = the average negative log base-2 probability of the target byte (see the worked example after the figures below)
Expire-span, Trans-XL, Adapt-Span, Compressive Transformer bits-per-byte, memory, and speed on Enwik8 (source).
Expire-span, Trans-XL, Adapt-Span, Compressive Transformer performance in bits-per-byte (bpb) vs memory size on Enwik8 (source).
Expire-span, Trans-XL, Adapt-Span, Compressive Transformer parameter count and bits-per-byte on Enwik8 (source).
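
A tiny worked example of the bits-per-byte metric used in the plots above; the probabilities are made up:

```python
import math

# Bits per byte (bpb): average negative log2 probability the model assigns
# to each correct target byte. Lower is better.
probs = [0.5, 0.25, 0.9, 0.7]                         # placeholder model probabilities
bpb = -sum(math.log2(p) for p in probs) / len(probs)
print(round(bpb, 3))                                  # 0.917 for these values
```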

Created on 24 Aug 2021.
Thank you