
ELECTRA - How to Train BERT 4x Cheaper

Reducing training FLOPs 4x with a GAN-like discriminative pre-training task, matching the RoBERTa-500K transformer model.

Can you afford to fully train and retrain your own BERT language model? Training cost is an important part of machine learning production, as transformer language models keep getting bigger. ELECTRA is also available on HuggingFace, including a model for pre-training.

Why Is BERT Training Inefficient?

BERT's masked language modeling (MLM) computes its loss only over the roughly 15% of input tokens that are masked, so most of the compute spent on each example produces no training signal (see the code sketch after the diagram below).

BERT model pre-training and fine-tuning
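To make the inefficiency concrete, here is a minimal PyTorch sketch, illustrative only: random tensors stand in for BERT's MLM output logits. The MLM cross-entropy ignores every position except the ~15% that were masked, so the remaining ~85% of tokens yield no gradient.

```python
import torch
import torch.nn.functional as F

# Illustration only: random tensors stand in for BERT's MLM output logits.
vocab_size, seq_len, batch = 30522, 128, 8
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Only ~15% of positions are masked; everything else gets label -100,
# which cross_entropy ignores, so ~85% of tokens yield no training signal.
masked = torch.rand(batch, seq_len) < 0.15
labels[~masked] = -100

mlm_loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1),
                           ignore_index=-100)
print(f"loss computed over {masked.float().mean():.0%} of tokens")
```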

How To Improve?

  • How do we get a sufficiently difficult task over all tokens, not just the masked ones?
  • ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
    • ELECTRA = Efficiently Learning an Encoder that Classifies Token Replacements Accurately
    • Stanford & Google Brain
    • ICLR 2020, not SoTA
  • A smaller generator and a bigger discriminator
  • The generator and discriminator are trained jointly
  • The generator is trained with masked language modeling (MLM)
  • For each masked position, the generator samples one replacement token
  • The big discriminator classifies every token as original or replaced (a minimal code sketch follows the diagram below)
  • Not exactly a GAN setup: the generator is trained with MLM, not adversarially

ELECTRA model generator discriminator pre-training diagram
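Below is a minimal PyTorch sketch of the replaced-token-detection setup in the diagram. It is not the paper's code: the tiny encoders, sizes, token ids, and variable names are illustrative, embedding sharing between the two models is omitted, and only the loss weight lambda = 50 is taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes; the real models are full-size transformers.
vocab_size, mask_id = 1000, 0          # assume token id 0 plays the role of [MASK]
d_gen, d_disc, seq_len, batch = 64, 128, 32, 4

class TinyEncoder(nn.Module):
    """Stand-in for a transformer encoder (embedding sharing omitted)."""
    def __init__(self, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)

    def forward(self, tokens):
        return self.encoder(self.embed(tokens))

generator = TinyEncoder(d_gen)            # the smaller model, trained with MLM
gen_head = nn.Linear(d_gen, vocab_size)   # predicts tokens at masked positions
discriminator = TinyEncoder(d_disc)       # the bigger model
disc_head = nn.Linear(d_disc, 1)          # per-token original-vs-replaced logit

tokens = torch.randint(1, vocab_size, (batch, seq_len))
mask = torch.rand(batch, seq_len) < 0.15              # mask ~15% of positions
masked_tokens = tokens.masked_fill(mask, mask_id)

# 1) Generator: masked language modeling loss on the masked positions only.
gen_logits = gen_head(generator(masked_tokens))
mlm_labels = tokens.masked_fill(~mask, -100)          # -100 = ignored position
mlm_loss = F.cross_entropy(gen_logits.view(-1, vocab_size),
                           mlm_labels.view(-1), ignore_index=-100)

# 2) Sample one replacement token per masked position; no gradient flows from
#    the discriminator back into the generator through these samples.
with torch.no_grad():
    samples = torch.distributions.Categorical(logits=gen_logits).sample()
corrupted = torch.where(mask, samples, tokens)

# 3) Discriminator: classify every token as original or replaced. Sampled
#    tokens that happen to equal the original count as "original".
is_replaced = (corrupted != tokens).float()
disc_logits = disc_head(discriminator(corrupted)).squeeze(-1)
disc_loss = F.binary_cross_entropy_with_logits(disc_logits, is_replaced)

# 4) Joint loss; the paper weights the discriminator term with lambda = 50.
loss = mlm_loss + 50.0 * disc_loss
loss.backward()
```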

The Architecture and Methods

  • Generator and discriminator have the same architecture
    • only the token and positional embeddings are shared
    • sharing more weights did not help
  • Generator is 2x to 4x smaller than the discriminator
    • bigger generators did not help
    • and cost more compute
    • perhaps a bigger generator makes the discrimination task too difficult
  • The two models are trained jointly
    • otherwise, the discriminator fails to learn
    • as the generator improves, it produces harder cases
    • but it must not get too far ahead of the discriminator

ELECTRA model loss is sum of generator masked language modeling and discriminator loss
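Written out, the caption above corresponds to the paper's joint objective over the corpus, where the discriminator term is weighted by lambda (set to 50 in the paper):

```latex
% ELECTRA joint pre-training objective over the corpus X
% (lambda weights the discriminator loss; the paper uses lambda = 50)
\min_{\theta_G,\, \theta_D} \; \sum_{x \in \mathcal{X}}
    \mathcal{L}_{\mathrm{MLM}}(x, \theta_G)
    + \lambda \, \mathcal{L}_{\mathrm{Disc}}(x, \theta_D)
```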

ELECTRA model generator size and GLUE benchmark performance

Results

  • Datasets:
    • GLUE: natural language understanding benchmark
    • SQuAD: question answering benchmark
  • RoBERTa = BERT with better training and a bigger dataset
    • longer training, bigger batches, more data
    • removes the next sentence prediction objective
    • trains on longer sequences
    • dynamically changes the masking pattern
  • XLNet = BERT with permutation language modeling
    • maximizes the expected likelihood of the sequence
    • over permutations of the factorization order
    • using an autoregressive next-token prediction task
  • ELECTRA-400K is on par with RoBERTa-500K with 4x fewer pre-training FLOPs

ELECTRA model performance on GLUE benchmark

ELECTRA model performance on SQuAD benchmark

Source of the Improvement

  • The authors compared alternative pre-training tasks by GLUE score
  • results:
    • a loss over all input tokens is important (see the snippet after the table below)
    • masking tokens is worse than replacing them
| Task | Description | GLUE score |
|------|-------------|------------|
| BERT | MLM with [MASK] token | 82.2 |
| Replace MLM | Masked tokens replaced with generated tokens + LM | 82.4 |
| ELECTRA 15% | Discriminator over 15% of the tokens | 82.4 |
| All-Tokens MLM | Replace MLM on all tokens + copy mechanism | 84.3 |
| ELECTRA | Discriminator over all tokens | 85.0 |
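As a rough illustration of the two ELECTRA rows, the snippet below contrasts the "ELECTRA 15%" discriminator loss (masked positions only) with the full per-token loss. The names echo the earlier pre-training sketch, but random stand-ins are used so it runs on its own.

```python
import torch
import torch.nn.functional as F

# Random stand-ins for disc_logits, is_replaced and mask from the earlier sketch.
batch, seq_len = 4, 32
disc_logits = torch.randn(batch, seq_len)
is_replaced = (torch.rand(batch, seq_len) < 0.15).float()
mask = torch.rand(batch, seq_len) < 0.15

# "ELECTRA 15%": discriminator loss only over the ~15% masked-out positions.
loss_15_percent = F.binary_cross_entropy_with_logits(disc_logits[mask],
                                                     is_replaced[mask])

# Full ELECTRA: loss over every token, worth ~2.6 GLUE points in the table above.
loss_all_tokens = F.binary_cross_entropy_with_logits(disc_logits, is_replaced)
```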

Personal Speculations:

  • ELECTRA could be suitable for low-resource settings
  • ELECTRA training acts like data augmentation:
    • the generator samples new replacements on each pass over the data

Follow Up - MC-BERT

MC-BERT model extension of ELECTRA diagram

Follow Up - TEAMS

  • also contrastive
  • shares more weights

TEAMS model extension of ELECTRA diagram

04 Oct 2021