DeepMind's RETRO Retrieval-Enhanced Transformer

The Retrieval-Enhanced Transformer cross-attends over a retrieval database of trillions of tokens, reaching SoTA on Wikitext103 and The Pile with 25x fewer parameters.

Other Retrieval Architectures

Figure: retrieval transformer comparison

RETRO’s Training Dataset

Figure: MassiveText dataset composition table

RETRO’s Architecture

  • frozen BERT kNN retriever operating at the chunk level
  • differentiable encoder conditioned on the query chunk
  • chunked cross-attention over the previous chunk's retrieval set
  • ablations show retrieval helps (see the overview sketch below)
Figure: retriever transformer architecture
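
To tie these pieces together, here is a minimal, hypothetical sketch of one RETRO layer step in Python. The callables retrieve, encode, and cca stand for components described in the sections below; names, shapes, and interfaces are illustrative rather than the paper's.

```python
import torch

def retro_step(token_chunks, hidden_chunks, retrieve, encode, cca):
    """Illustrative RETRO step at one decoder layer. Hypothetical callables:
      retrieve(chunk_tokens)      -> k neighbours for that chunk (frozen BERT kNN)
      encode(neighbours, hidden)  -> "retrieval set" conditioned on the query chunk
      cca(hidden, retrieval_set)  -> chunked cross-attention into the decoder
    hidden_chunks are the decoder activations split into 64-token chunks.
    Chunk i attends to the retrieval set of chunk i-1, preserving autoregression."""
    retrieval_sets = [encode(retrieve(c), h)
                      for c, h in zip(token_chunks, hidden_chunks)]
    out = [hidden_chunks[0]]                       # first chunk: nothing retrieved before it
    for i in range(1, len(hidden_chunks)):
        out.append(cca(hidden_chunks[i], retrieval_sets[i - 1]))
    return torch.cat(out, dim=1)                   # reassemble the full sequence
```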

RETRO’s Retriever

  • the database is a key-value memory of chunks
  • each value is two consecutive chunks (128 tokens)
  • each key is the time-averaged BERT embedding of the first chunk of its value (the first 64 tokens)
  • key vectors are stored in a ScaNN database for approximate k-nearest-neighbour (kNN) similarity search
  • the database holds the entire MassiveText train set during evaluation
    • training uses a 600B-token subset of the train set
    • test-set leakage into the train set is controlled via 13-gram Jaccard similarity
  • the 1.7T-token database is queried in about 10 ms
  • retrieval is part of the input dataset pipeline
  • the optimal number of neighbours is between 2 and 40 (see the sketch below)
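
A minimal sketch of the chunk database and lookup, assuming a toy random-projection embedder in place of the frozen, time-averaged BERT encoder and brute-force L2 search in place of ScaNN; embed_chunk, build_database, and retrieve are hypothetical helper names.

```python
import numpy as np

CHUNK = 64  # tokens per chunk, as in the paper

def embed_chunk(token_ids: np.ndarray) -> np.ndarray:
    """Stand-in for the frozen, time-averaged BERT embedding of a chunk:
    a fixed random projection keeps the sketch self-contained."""
    table = np.random.default_rng(0).standard_normal((50_000, 128))
    return table[token_ids].mean(axis=0)            # time average -> one key vector

def build_database(corpus_tokens: np.ndarray):
    """Key = embedding of chunk i; value = chunk i plus its continuation chunk i+1."""
    n_chunks = len(corpus_tokens) // CHUNK
    keys, values = [], []
    for i in range(n_chunks - 1):
        chunk = corpus_tokens[i * CHUNK:(i + 1) * CHUNK]
        continuation = corpus_tokens[(i + 1) * CHUNK:(i + 2) * CHUNK]
        keys.append(embed_chunk(chunk))
        values.append(np.concatenate([chunk, continuation]))   # 128 tokens
    return np.stack(keys), np.stack(values)

def retrieve(query_chunk: np.ndarray, keys: np.ndarray, values: np.ndarray, k: int = 2):
    """Brute-force kNN by L2 distance; RETRO uses approximate search with ScaNN."""
    q = embed_chunk(query_chunk)
    nearest = np.argsort(np.linalg.norm(keys - q, axis=1))[:k]
    return values[nearest]                          # k neighbours, each 128 tokens
```

For example, `retrieve(tokens[:64], *build_database(tokens))` returns the two nearest 128-token neighbours for the first chunk of a token-id array `tokens`.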

RETRO’s Encoding of Retrieved Neighbours

  • all retrieved values (128 consecutive tokens each)
  • are first passed through a bi-directional transformer encoder
  • which differentiably modulates the retrieved chunks
  • by conditioning on the query chunk via cross-attention
    • the query chunk's hidden representations serve as keys and values
    • the encoded neighbour representations serve as queries
    • the query-chunk activations are taken at the last layer before the first cross-attention
  • the output is called the retrieval set (see the sketch below)
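
A minimal PyTorch sketch of encoding a single retrieved neighbour conditioned on the query chunk's decoder activations. NeighbourEncoder is a hypothetical name; dimensions, layer counts, and omitted layer norms are simplifications of the paper's encoder.

```python
import torch
import torch.nn as nn

class NeighbourEncoder(nn.Module):
    """Bi-directional encoding of one retrieved neighbour (128 tokens),
    conditioned on the query chunk via cross-attention."""
    def __init__(self, d_model: int = 64, n_heads: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                nn.Linear(4 * d_model, d_model))

    def forward(self, neighbour, query_chunk_hidden):
        # Bi-directional self-attention over the 128 retrieved tokens (no causal mask).
        h, _ = self.self_attn(neighbour, neighbour, neighbour)
        # Cross-attention: the encoded neighbour provides the queries,
        # the query chunk's hidden states provide the keys and values.
        h, _ = self.cross_attn(h, query_chunk_hidden, query_chunk_hidden)
        return h + self.ff(h)       # this neighbour's slice of the retrieval set
```

For example, `NeighbourEncoder()(torch.randn(1, 128, 64), torch.randn(1, 64, 64))` yields a `(1, 128, 64)` tensor; stacking it across the k neighbours gives the chunk's retrieval set.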

RETRO’s Chunked Cross-Attention

  • attend to the previous chunk's retrieval set to stay autoregressive
  • add relative positional encodings to each retrieved chunk
  • concatenate the encoded neighbours along the time dimension
  • use the decoder's hidden representations at that layer as queries
  • cross-attend (see the sketch after the figure below)
Figure: retrieval transformer
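
A minimal sketch of chunked cross-attention, assuming 64-token chunks, pre-computed retrieval sets, and a plain `nn.MultiheadAttention` module in place of RETRO's cross-attention; the relative positional encodings and the exact one-token alignment shift from the paper are omitted.

```python
import torch
import torch.nn as nn

def chunked_cross_attention(decoder_hidden, retrieval_sets, cca):
    """Assumed shapes:
      decoder_hidden:  (batch, n_chunks * 64, d)     decoder activations at this layer
      retrieval_sets:  (batch, n_chunks, k, 128, d)  encoded neighbours per chunk
      cca:             nn.MultiheadAttention(d, n_heads, batch_first=True)
    Chunk i cross-attends to the neighbours retrieved for chunk i-1,
    which keeps the model autoregressive."""
    b, t, d = decoder_hidden.shape
    out = decoder_hidden.clone()
    for i in range(1, t // 64):                         # chunk 0 has no previous retrieval
        q = decoder_hidden[:, i * 64:(i + 1) * 64]      # queries: this chunk's hidden states
        kv = retrieval_sets[:, i - 1].reshape(b, -1, d) # previous chunk's neighbours,
                                                        # concatenated along time
        attended, _ = cca(q, kv, kv)
        out[:, i * 64:(i + 1) * 64] = q + attended      # residual add
    return out
```

Here `cca` would be created once per RETRO layer that uses retrieval, e.g. `nn.MultiheadAttention(d, 4, batch_first=True)`.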

RETRO’s Results

  • SoTA on Wikitext103 and The Pile
  • on The Pile it outperforms Jurassic-1 and Gopher with 7B params
    • strongly outperforms on GitHub - a repetitive dataset?
    • weakly outperforms on HackerNews
    • underperforms on Math - not in MassiveText, poor search?
  • comparable with GPT-3 with 25x fewer params
  • generates on-topic, coherent text, likely thanks to its long memories
  • underperforms specialized QA models
Figure: RETRO on The Pile
Figure: RETRO generated text keeps on topic thanks to longer sequences
Figure: RETRO question answering results
Figure: RETRO on Wikitext103

How You Can Use RETRO's Ideas

  • freeze any pre-trained transformer
  • add and train the chunked cross-attention and the neighbour encoder (see the sketch below)
  • tune the number of neighbours between 2 and 40 to your model size
  • results should get close to training the whole model from scratch
  • see the “Retro-fitting baseline models” section of the paper
  • RETRO source code is not published yet
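
A minimal sketch of that retro-fitting recipe: freeze the pre-trained decoder and let the optimiser see only the new neighbour-encoder and chunked cross-attention parameters. The function and argument names here are hypothetical.

```python
import torch
import torch.nn as nn

def retrofit(pretrained_decoder: nn.Module,
             neighbour_encoder: nn.Module,
             cca_layers: nn.ModuleList) -> torch.optim.Optimizer:
    """Freeze the base model; train only the newly added retrieval components."""
    for p in pretrained_decoder.parameters():
        p.requires_grad = False                        # pre-trained weights stay frozen
    new_params = list(neighbour_encoder.parameters()) + list(cca_layers.parameters())
    return torch.optim.Adam(new_params, lr=1e-4)       # gradients reach only the new weights
```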


Created on 29 Dec 2021. Updated on 17 Sep 2022.
Thank you









