- RETRO is an autoregressive language model: it predicts the next token (roughly a word) from the preceding tokens (see the sketch after this list)
- full name = Retrieval-Enhanced Transformer (RETRO)
- introduced in DeepMind’s Improving Language Models by Retrieving from Trillions of Tokens (2021) and the accompanying DeepMind blog post
- retrieves chunks from a kNN database whose BERT embeddings are similar to the current text chunk
- conditions on each retrieved chunk and its continuation chunk
- so it effectively attends to previously encountered “future” text
- SoTA on the WikiText-103 and The Pile datasets
- competitive on question answering; matches GPT-3-level performance with 25x fewer parameters
- the model still performs well when train-test overlap is low
- retrieval reduces hallucinations and increases interpretability
- merges symbolic retrieval with deep learning, similar in spirit to DreamCoder’s program learning
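An autoregressive LM factorizes the sequence probability as P(x₁…x_T) = ∏_t P(x_t | x_<t) and generates by repeatedly predicting the next token from what came before. A minimal sketch of that loop, where `next_token_logits` is a hypothetical stand-in for a trained model:

```python
import numpy as np

def next_token_logits(prefix, vocab_size=8):
    # Hypothetical stand-in for a trained language model:
    # a real model would return learned logits for P(next token | prefix).
    rng = np.random.default_rng(seed=len(prefix))
    return rng.normal(size=vocab_size)

def generate(prefix, steps=5):
    tokens = list(prefix)
    for _ in range(steps):
        logits = next_token_logits(tokens)
        tokens.append(int(np.argmax(logits)))  # greedy: take the most likely next token
    return tokens

print(generate([1, 2, 3]))
```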
Other Retrieval Architectures
- historically: inverted-index matching with TF-IDF and BM25 scoring
- latent topic modelling e.g. LDA (2003)
- edit-distance search for translation (2018)
- kNN-LM (2020)
- looks up the LM embedding of the current context in a database
- linearly interpolates the retrieved next-token distribution with the LM’s predictions (sketched after this list)
- DPR (2020)
- trains two BERT (2018) encoders: one for queries (questions) and one for documents (passages)
- uses contrastive loss
- in contrast, DeepMind’s RETRO uses
- longer sequences
- chunked cross-attention, which allows attending to multiple retrieved neighbours
- a much bigger retrieval database
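The kNN-LM interpolation is simple enough to sketch: build a next-token distribution from the retrieved neighbours’ stored next tokens and mix it with the LM’s own distribution using a weight λ. The datastore contents, distance kernel, and λ = 0.25 below are toy assumptions, not the paper’s exact setup:

```python
import numpy as np

def knn_lm_probs(lm_probs, context_vec, keys, next_tokens, k=4, lam=0.25, temp=1.0):
    # keys: (N, d) stored context embeddings; next_tokens: (N,) token that followed each context
    dists = np.linalg.norm(keys - context_vec, axis=1)   # distance to every stored key
    nn = np.argsort(dists)[:k]                           # indices of the k nearest neighbours
    weights = np.exp(-dists[nn] / temp)                  # closer neighbours weigh more
    weights /= weights.sum()
    knn_probs = np.zeros_like(lm_probs)
    for w, tok in zip(weights, next_tokens[nn]):
        knn_probs[tok] += w                              # mass on the neighbours' next tokens
    return (1 - lam) * lm_probs + lam * knn_probs        # linear interpolation with the LM

vocab = 10
lm_probs = np.full(vocab, 1.0 / vocab)
keys = np.random.randn(100, 16)
next_tokens = np.random.randint(0, vocab, size=100)
print(knn_lm_probs(lm_probs, keys[0], keys, next_tokens))
```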
RETRO’s Training Dataset
- the multilingual MassiveText dataset (10 languages)
- SentencePiece tokenizer vocabulary of 128k tokens
- Retrieval database 1.75T tokens
- 1 token ~ 4 characters ~ 1 word
- chunks are 64-token sequences (see the chunking sketch below)
- the database therefore holds roughly 1.75T / 128 ≈ 13.7B records (assuming non-overlapping 128-token values)
- retrieval from the same source document is disallowed during training
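A toy sketch of that chunk layout, assuming non-overlapping 64-token chunks where each record’s value is a chunk plus its continuation:

```python
# Split a token stream into 64-token chunks; a database record pairs a chunk
# with its continuation chunk (128 tokens in total).
CHUNK = 64

def make_records(tokens):
    chunks = [tokens[i:i + CHUNK] for i in range(0, len(tokens) - CHUNK + 1, CHUNK)]
    records = []
    for chunk, continuation in zip(chunks, chunks[1:]):
        records.append({"key_chunk": chunk, "value": chunk + continuation})
    return records

tokens = list(range(64 * 4))                     # pretend token ids
records = make_records(tokens)
print(len(records), len(records[0]["value"]))    # 3 records, 128-token values
```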
RETRO’s Architecture
- a frozen-BERT kNN retriever working at the chunk level
- a differentiable encoder for the retrieved neighbours, conditioned on the query chunk
- chunked cross-attention over the previous chunk’s retrieval set
- ablations show retrieval helps
RETRO’s Retriever
- the database is a key-value memory of chunks
- each value is two consecutive chunks (128 tokens)
- each key is the value’s first chunk (its first 64 tokens)
- keys are stored as time-averaged (mean-pooled) BERT embeddings of that first chunk
- the key vectors live in a ScaNN database for fast approximate k-nearest-neighbour (similarity) search (see the retriever sketch below)
- during evaluation, the database holds the entire MassiveText training set
- during training, retrieval uses a 600B-token subset of the training set
- test set leakage into train set is controlled via a 13-gram Jaccard similarity
- the 1.75T-token database can be queried in ~10 ms
- retrieval is part of the input dataset pipeline
- optimum number of neighbors between 2 and 40
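A brute-force stand-in for the retriever: keys are the time-averaged (mean-pooled) per-token embeddings of each value’s first chunk, and exact numpy search replaces ScaNN’s approximate search. `embed_tokens` is a random stand-in for the frozen BERT encoder, and the record layout follows the chunking sketch above:

```python
import numpy as np

D = 32  # toy embedding width

def embed_tokens(chunk):
    # Stand-in for frozen BERT: deterministic pseudo-embeddings per chunk.
    rng = np.random.default_rng(abs(hash(tuple(chunk))) % (2**32))
    return rng.normal(size=(len(chunk), D))        # (chunk_len, D) per-token embeddings

def chunk_key(chunk):
    return embed_tokens(chunk).mean(axis=0)        # time-average over the 64 tokens

def build_index(records):
    return np.stack([chunk_key(r["key_chunk"]) for r in records])   # (N, D) key matrix

def retrieve(query_chunk, keys, records, k=2):
    q = chunk_key(query_chunk)
    dists = np.linalg.norm(keys - q, axis=1)
    nn = np.argsort(dists)[:k]                     # k nearest neighbours
    return [records[i]["value"] for i in nn]       # each value is 128 tokens

records = [{"key_chunk": list(range(i, i + 64)),
            "value": list(range(i, i + 128))} for i in range(0, 256, 64)]
keys = build_index(records)
print(len(retrieve(list(range(64)), keys, records)[0]))   # 128
```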
RETRO’s Encoding of Retrieved Neighbours
- each retrieved value (128 consecutive tokens) is first passed through a bi-directional transformer encoder
- the encoder differentiably modulates the retrieved chunks, conditioning them on the query chunk via cross-attention (sketched below)
- in that cross-attention, the query chunk’s hidden representations serve as keys and values
- and the encoded neighbour representations serve as the attention queries
- the query-chunk activations are taken from the decoder layer just before its first cross-attention
- the encoder’s output is called the retrieval set
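A single-head sketch of that cross-attention direction, with the learned projection matrices omitted and toy shapes assumed: the encoded neighbour tokens are the attention queries, while the query chunk’s decoder activations supply the keys and values.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, keys_values, d_k):
    # queries:     (retrieved_len, d) -- encoded neighbour tokens
    # keys_values: (chunk_len, d)     -- query-chunk hidden states from the decoder
    scores = queries @ keys_values.T / np.sqrt(d_k)   # (retrieved_len, chunk_len)
    attn = softmax(scores, axis=-1)
    return attn @ keys_values                         # neighbour representations conditioned on the query chunk

d = 32
neighbour = np.random.randn(128, d)     # 128 retrieved tokens (chunk + continuation)
chunk_hidden = np.random.randn(64, d)   # 64-token query chunk
print(cross_attention(neighbour, chunk_hidden, d).shape)   # (128, 32)
```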
RETRO’s Chunked Cross-Attention
- to stay autoregressive, each chunk attends to the previous chunk’s retrieval set
- relative positional encodings are added to each retrieved neighbour
- the neighbours are concatenated along the time dimension
- the decoder’s hidden representations at that layer serve as the attention queries
- and cross-attend to the concatenated retrieval set (sketched below)
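A simplified sketch of just the causal alignment, which is the non-obvious part: the tokens of chunk i cross-attend to the retrieval set encoded for chunk i-1, with the neighbours already concatenated along the time dimension. The attention computation, relative positional encodings, and the paper’s exact token alignment are omitted here.

```python
import numpy as np

CHUNK = 64

def pair_chunks_with_retrievals(hidden, retrieval_sets):
    # hidden:         (seq_len, d) decoder hidden states at this layer
    # retrieval_sets: one encoded retrieval set per chunk (neighbours concatenated in time)
    pairs = []
    n_chunks = len(hidden) // CHUNK
    for i in range(1, n_chunks):                           # chunk 0 has no earlier retrieval to use
        chunk_hidden = hidden[i * CHUNK:(i + 1) * CHUNK]   # attention queries
        retrieved = retrieval_sets[i - 1]                  # keys/values: previous chunk's neighbours
        pairs.append((chunk_hidden, retrieved))
    return pairs

hidden = np.random.randn(3 * CHUNK, 32)
retrieval_sets = [np.random.randn(2 * 128, 32) for _ in range(3)]   # 2 neighbours x 128 tokens each
pairs = pair_chunks_with_retrievals(hidden, retrieval_sets)
print(len(pairs), pairs[0][0].shape, pairs[0][1].shape)             # 2 (64, 32) (256, 32)
```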
RETRO’s Results
- SoTA on WikiText-103 and The Pile
- on The Pile, RETRO with 7B parameters outperforms Jurassic-1 and Gopher
- strongly outperforms on the GitHub subset, perhaps because the data is repetitive and retrieval pays off
- weakly outperforms on HackerNews
- underperforms on the math subset, perhaps because such data is scarce in MassiveText or retrieval works poorly there
- comparable with GPT-3 despite having 25x fewer parameters
- generates on-topic, coherent text, likely thanks to its long retrieval memories
- underperforms specialized QA models
How Can You Use RETRO’s Ideas?
- freeze any pre-trained transformer
- add and train the chunked cross-attention layers and the neighbour encoder (see the sketch after this list)
- tune the number of neighbours (between 2 and 40) to your model size
- results should get close to those of training the whole model from scratch
- see the “Retro-fitting baseline models” section of the paper
- RETRO’s source code has not been published yet
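A hedged PyTorch sketch of that retro-fitting recipe, not DeepMind’s implementation: freeze the pre-trained backbone and make only the newly added retrieval components trainable. Module names and sizes below are placeholders.

```python
import torch
import torch.nn as nn

class RetroFit(nn.Module):
    def __init__(self, pretrained: nn.Module, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.backbone = pretrained
        for p in self.backbone.parameters():
            p.requires_grad = False                     # keep the pre-trained transformer frozen
        # Newly added, trainable retrieval components (placeholder architecture):
        self.neighbour_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True), num_layers=2)
        self.chunked_cross_attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

model = RetroFit(pretrained=nn.Linear(512, 512))        # any pre-trained module would go here
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)        # only the new components get updated
```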
Read Next: Melting the Recurrence with Attention
SRU++ Model Speeds Up Transformer with Simple Recurrent Unit