# Expire-Span: Scaling Transformer by Forgetting

Reducing computational cost by differentiably dropping memorized embeddings from the self-attention context. A video walkthrough, Expire-Span: Scaling Transformer by Forgetting, is available on YouTube.

If you need a recap of the transformer architecture, read: Transformer’s Self-Attention Mechanism Simplified.

## Self-Attention Complexity

• complexity is quadratic in the sequence length: $$O(L^2)$$
• because we need to compute an $$L \times L$$ attention matrix $$\mathbf{softmax}(\frac{QK^\intercal}{\sqrt{d}})$$
• but a large context size is crucial for some tasks, e.g. character-level language models (see the sketch below)
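A minimal sketch of single-head self-attention in PyTorch, with toy dimensions of my own choosing, just to make the $$L \times L$$ cost visible:

```python
import torch
import torch.nn.functional as F

L, d = 512, 64                    # sequence length, head dimension (toy values)
Q = torch.randn(L, d)
K = torch.randn(L, d)
V = torch.randn(L, d)

scores = Q @ K.T / d ** 0.5       # (L, L) matrix -- quadratic in sequence length
attn = F.softmax(scores, dim=-1)  # row-wise softmax over the L keys
out = attn @ V                    # (L, d) attention output
print(scores.shape)               # torch.Size([512, 512])
```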

## Auto-Regressive Transformers

• predict the next token from the previous ones
• uni-directional (causal) attention works best on LM tasks (see the mask sketch after this list)
• all models below are auto-regressive
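A sketch of the causal mask, again with made-up toy sizes: each position may only attend to itself and earlier positions.

```python
import torch
import torch.nn.functional as F

L = 8
scores = torch.randn(L, L)        # raw attention scores for a toy sequence
# The upper-triangular mask blocks attention to future positions (column > row).
mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
attn = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(attn[0])                    # first position attends only to itself
```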

## Transformer-XL

• uses memory of previously calculated results to increase the attention span (see the sketch below) (source)
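A minimal sketch of the memory idea, assuming a toy segment size: hidden states of the previous segment are cached without gradients and prepended to the keys/values of the current one.

```python
import torch

d, seg = 16, 4
memory = torch.randn(seg, d)      # cached hidden states of the previous segment
h = torch.randn(seg, d)           # hidden states of the current segment

# Queries come from h only; keys/values see the cached memory as well.
kv_context = torch.cat([memory.detach(), h], dim=0)
print(kv_context.shape)           # torch.Size([8, 16]) -- doubled span
new_memory = h.detach()           # becomes the cache for the next segment
```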

## Compressive Transformer

• extends the context by compressing old memories into coarser summaries instead of discarding them
• similar goal to Expire-Span
• except it compresses all memories at a fixed rate, while Expire-Span learns which memories to forget

## Expire-Span Attention

• Facebook AI Paper: Not All Memories are Created Equal: Learning to Forget by Expiring
• Source
• uses memory à la Transformer-XL
• has a default minimal span of size $$K$$
• $$L$$ is the maximum span
• for each input (memory) $$h_i$$ in each layer, compute once a scalar $$e_i \in [0, L]$$
• $$e_i$$ is called the expire-span (expiration time span)
• $$e_i = L \, \sigma(w^\intercal h_i + b)$$
• the model slides over the text in time steps; $$t$$ denotes the time step
• if $$r_{ti} := e_i - (t-i) < 0$$, then memory $$h_i$$ is expired and forgotten
• for differentiability, linearly phase out the attention output over a ramp of length $$R$$ (sketched below):
• $$m_{ti} := \max(0, \min(1, 1 + \frac{r_{ti}}{R}))$$
• $$a^\prime_{ti} := \frac{ m_{ti} a_{ti} }{ \sum_j m_{tj} a_{tj} }$$

Expire-span attention: for every sequence input $$h_i$$ it calculates an expiration time span $$e_i$$. (source)
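A minimal sketch of the expire-span re-weighting at a single time step, using toy dimensions and random stand-ins for the memories $$h_i$$ and the attention weights $$a_{ti}$$; $$w$$, $$b$$ and the ramp length $$R$$ are the learned span parameters from the paper:

```python
import torch

L_max, R = 100.0, 16.0                    # maximum span L, ramp length R
d, n_mem, t = 16, 10, 9                   # head dim, number of memories, step t
h = torch.randn(n_mem, d)                 # stand-in memories h_0 ... h_9
w, b = torch.randn(d), torch.tensor(0.0)  # span predictor parameters

e = L_max * torch.sigmoid(h @ w + b)      # e_i in [0, L_max], one per memory
i = torch.arange(n_mem, dtype=torch.float)
r = e - (t - i)                           # remaining life r_ti = e_i - (t - i)
m = (1 + r / R).clamp(0.0, 1.0)           # linear phase-out keeps it differentiable

a = torch.softmax(torch.randn(n_mem), dim=-1)  # stand-in attention weights a_ti
a_prime = m * a / (m * a).sum()                # renormalized weights a'_ti
# Memories with r_ti <= -R have m_ti == 0 and can be dropped from the cache.
```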

## Expire-Span Loss

• penalize higher memory usage with an auxiliary loss term
• $$\alpha > 0$$ is a compression parameter
• $$L_{total} = L_{task} + \alpha \sum_{i} e_i / T$$ (sketched below)
• randomly shorten the memory during training as a regularizer
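A sketch of this total loss, where the task loss, the spans e, and the values of alpha and T are stand-ins of my own choosing:

```python
import torch

alpha, T = 1e-6, 512                 # compression strength, tokens per block
task_loss = torch.tensor(0.9)        # stand-in LM loss
e = torch.rand(2048) * 100.0         # stand-in expire-spans e_i of the block

total_loss = task_loss + alpha * e.sum() / T  # longer spans cost more
```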

## Results on Enwik8

• better performance, less memory, and faster than the baselines
• LM metric: bits per byte (or character) = average negative log base-2 probability of the target label (see the sketch below)
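A tiny worked example of bits per byte, with made-up probabilities assigned to the target bytes:

```python
import torch

p = torch.tensor([0.5, 0.25, 0.8])   # stand-in probabilities of the target bytes
bpb = (-torch.log2(p)).mean()        # average negative log2 probability
print(bpb)                           # tensor(1.1073) -- lower is better
```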

Created on 24 Aug 2021. 