DeepMind's RETRO Retrieval-Enhanced Transformer

Retrieval-Enhanced Language Model cross-attends trillions of tokens for SoTA on Wikitext103 and The Pile with 25x fewer parameters.

Other Retrieval Architectures

retrieval transformer comparison

RETRO’s Training Dataset

MassiveText dataset composition table

RETRO’s Architecture

  • Frozen BERT kNN retriever on chunk level
  • differentiable encoder conditioned on query
  • chunked cross-attention with previous chunk retrieval set
  • ablations show retrieval helps

retriever transformer architecture

RETRO’s Retriever

  • database is key-value memory of chunks
  • each value is two consecutive chunks (128 tokens)
  • each key is the first chunk from its value (first 64 tokens)
  • each key is time-averaged BERT embedding of the first chunk
  • key vectors are stored in a ScaNN database for k-nearest neighbors (kNN) similarity search
  • the database stores the entire MassiveText train set during evaluation
    • training uses a 600B-token subset of the train set
    • test set leakage into train set is controlled via a 13-gram Jaccard similarity
  • 1.7T token database queried in 10ms
  • retrieval is part of the input dataset pipeline
  • optimum number of neighbors between 2 and 40
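The key-value layout above can be illustrated with a minimal sketch. Everything here is a toy assumption: `embed` is a hypothetical stand-in for the frozen, time-averaged BERT embedding, the chunk length is 4 instead of 64, and exact L2 search replaces the approximate search that ScaNN would perform.

```python
import math

CHUNK = 4  # toy chunk length; the post states 64 tokens per chunk


def embed(tokens, dim=8):
    """Toy stand-in for the frozen, time-averaged BERT embedding (NOT BERT).

    Each token is mapped to a bucket and the one-hot buckets are averaged,
    so every chunk gets a fixed-size vector.
    """
    vec = [0.0] * dim
    for tok in tokens:
        vec[sum(map(ord, tok)) % dim] += 1.0 / len(tokens)
    return vec


def build_database(tokens, chunk=CHUNK):
    """Return (key, value) pairs.

    key   = embedding of the first chunk
    value = that chunk plus its continuation (two consecutive chunks)
    """
    db = []
    for start in range(0, len(tokens) - 2 * chunk + 1, chunk):
        first = tokens[start:start + chunk]
        value = tokens[start:start + 2 * chunk]
        db.append((embed(first), value))
    return db


def retrieve(db, query_chunk, k=2):
    """Exact kNN by L2 distance (assumption); ScaNN does this approximately."""
    q = embed(query_chunk)
    ranked = sorted(db, key=lambda kv: math.dist(q, kv[0]))
    return [value for _, value in ranked[:k]]


corpus = "the cat sat on the mat and the dog lay on the rug by the door".split()
db = build_database(corpus)
neighbours = retrieve(db, "the cat sat on".split(), k=2)
```

Because each value carries the continuation of its key chunk, a retrieved neighbour shows the model what followed similar text elsewhere in the corpus.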

RETRO’s Encoding Retrieved Neighbours

  • every retrieved value (128 consecutive tokens) is first passed through a bi-directional transformer encoder
  • the encoder differentiably modulates the retrieved chunks
  • it is conditioned on the query chunks via cross-attention
    • hidden representations of the query chunk serve as keys and values
    • encoded neighbour representations serve as queries
    • the query-chunk representations are taken at the last layer before the first cross-attention
  • output is called retrieval set
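The direction of this conditioning (neighbour tokens as queries, query-chunk hidden states as keys and values) can be sketched with a single attention head. This is a simplified illustration, not the paper's implementation: learned projections, multiple heads, and the residual connection are omitted, and `cross_attend` is a hypothetical helper name.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(queries, keys_values):
    """Single-head dot-product cross-attention without learned projections.

    queries:     list of vectors (here: encoded neighbour tokens)
    keys_values: list of vectors used as both keys and values
                 (here: hidden states of the query chunk)
    Returns one output vector per query.
    """
    dim = len(queries[0])
    out = []
    for q in queries:
        scores = [
            sum(a * b for a, b in zip(q, k)) / math.sqrt(dim)
            for k in keys_values
        ]
        weights = softmax(scores)
        out.append([
            sum(w * v[j] for w, v in zip(weights, keys_values))
            for j in range(dim)
        ])
    return out


neighbour_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 neighbour tokens
chunk_states = [[2.0, 0.0], [0.0, 2.0]]                   # 2 query-chunk tokens
conditioned = cross_attend(neighbour_states, chunk_states)
```

The output has one vector per neighbour token, each a weighted mix of the query-chunk states, which is how the neighbour encoding becomes dependent on the chunk that triggered the retrieval.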

RETRO’s Chunked Cross-Attention

  • use the retrieval set of the previous chunk to stay autoregressive
  • add relative positional encodings to each retrieved neighbour
  • concatenate the neighbours along the time dimension
  • use the hidden representation at the current layer as the query
  • cross-attend to the concatenated retrieval set
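The autoregressive constraint in the first step can be made concrete with a small index calculation. This sketch assumes 0-indexed token positions and the offset as I read it in the RETRO paper: a retrieval set becomes visible at the last token of the chunk that produced it, because by then the whole chunk has been observed. The function name `retrieval_set_index` is hypothetical.

```python
def retrieval_set_index(position, chunk_len=64):
    """Return the index of the chunk whose retrieval set a token may attend to.

    Returns None for the first chunk_len - 1 positions, where no complete
    chunk has been observed yet and so nothing has been retrieved.
    """
    index = (position + 1) // chunk_len - 1
    return index if index >= 0 else None


# The last token of chunk 0 (position 63) is the first to see retrieved text.
first_visible = retrieval_set_index(63)
# Position 126 is still served by chunk 0; position 127 switches to chunk 1.
before_switch = retrieval_set_index(126)
after_switch = retrieval_set_index(127)
```

From the point of view of the token being predicted, the neighbours always come from the preceding chunk, so no future text influences the retrieval.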

retrieval transformer

RETRO’s Results

  • SoTA on Wikitext103 and Pile
  • on Pile with 7B params outperforms Jurassic-1 and Gopher
    • strongly outperforms on Github - repetitive dataset?
    • weakly outperforms on HackerNews
    • underperforms on Math - not in MassiveText, poor search?
  • comparable to GPT-3 with 25x fewer parameters
  • generates on-topic and coherent text likely thanks to long memories
  • underperforms specialized QA models

RETRO on Pile

RETRO generated text keeps on topic thanks to longer sequences

RETRO question answering results

RETRO Wikitext103

How Can You Use RETRO Ideas?

  • freeze any pre-trained transformer
  • add and train chunked cross-attention and the encoder
  • tune number of neighbours between 2 and 40 to your model size
  • results should get close to training the whole model from scratch
  • see “Retro-fitting baseline models” section
  • Retro source code not published yet
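The freeze-and-train recipe above amounts to updating only the newly added parameters. A framework-free sketch, where the parameter names, the prefixes `cca.` and `neighbour_encoder.`, and the scalar "weights" are all hypothetical placeholders rather than anything from a real codebase:

```python
def sgd_step(params, grads, trainable_prefixes, lr=0.1):
    """Apply one SGD update only to parameters with a trainable name prefix.

    All other parameters (the pre-trained transformer) are left frozen.
    """
    updated = {}
    for name, value in params.items():
        if name.startswith(trainable_prefixes):
            updated[name] = value - lr * grads[name]
        else:
            updated[name] = value
    return updated


# Hypothetical parameter names: one frozen decoder weight, two new weights.
params = {
    "decoder.layer0.w": 1.0,
    "cca.layer0.w": 1.0,
    "neighbour_encoder.w": 1.0,
}
grads = {name: 1.0 for name in params}
new_params = sgd_step(params, grads, ("cca.", "neighbour_encoder."))
```

In a real framework the same effect is usually achieved by passing only the new modules' parameters to the optimizer.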

Created on 29 Dec 2021. Updated on: 17 Sep 2022.