Vaclav Kosar
Software, Machine Learning, & Business

SRU++ Model Speeds Up Transformer with Simple Recurrent Unit

Reducing compute by combining RNN with self-attention from Transformer architecture.

Here are my notes on the SRU paper. Thanks to the paper authors and to the discussions at Yannic Kilcher’s Discord paper meetup.

Summary:

Attention and Recurrence

  • attention vs recurrence = graph vs sequence
  • attention connects across entire sequence as fully connected graph
  • recurrence keeps information from previous states in a state vector
  • the original recurrent LSTM is less parallelizable than the Transformer
    • future steps in an LSTM depend on past states, so they cannot be computed in parallel

Dependency parsing and sequence figure from Stanford’s Speech and Language Processing by Daniel Jurafsky & James H. Martin

How does SRU help parallelization?

  • while the state computation of SRU is time-dependent, each state dimension is independent
  • time step: \( t \), input vector: \( x_t \), (inner) state \( c_t \)
  • (inner) forget gate \( f_t := \sigma(W_f x_t + V_f c_{t-1} + b_f) \)
    • problem: both \( c_t, f_t \) depend on all dimensions of \( c_{t-1} \)
    • due to the matrix multiplication \( V_f c_{t-1} \)
    • solution: pointwise (Hadamard) multiplication \( v_f \odot c_{t-1} \)
    • gives parallel computation of \( c_t, f_t \) across state dimensions
  • state \( c_t := f_t \odot c_{t-1} + (1 - f_t) \odot W x_t \)
  • all \( W, V, b \) are trained
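A minimal NumPy sketch (toy dimensions and variable names of my choosing) of why the pointwise replacement matters: with a dense matrix \( V_f \), every output dimension mixes all of \( c_{t-1} \); with a vector \( v_f \), each dimension depends only on its own entry and can run in its own thread.

```python
import numpy as np

# Hypothetical toy setup illustrating the dependence argument above.
d = 4
rng = np.random.default_rng(0)
c_prev = rng.normal(size=d)
V_f = rng.normal(size=(d, d))   # dense matrix: mixes all state dimensions
v_f = rng.normal(size=d)        # vector: acts on each dimension separately

dense = V_f @ c_prev            # dense[i] depends on every entry of c_prev
pointwise = v_f * c_prev        # pointwise[i] depends only on c_prev[i]

# Perturb a single state dimension:
c_perturbed = c_prev.copy()
c_perturbed[0] += 1.0
dense_p = V_f @ c_perturbed      # generally changes in every dimension
pointwise_p = v_f * c_perturbed  # changes only in dimension 0
```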

Highway Network Component

  • a highway network is more dynamic than a plain skip connection
    • provides regulated gradient flow
  • the reset gate weights the output skip connection
    • defined as \( r_t := \sigma( W_r x_t + v_r \odot c_{t-1} + b_r ) \)
    • combines the state with the input
    • then used for output \( h_t \) that allows gradient flow
  • output (hidden) vector: \( h_t := r_t \odot c_t + (1 - r_t) \odot x_t \)

All Equations

  • \( f_t := \sigma( W_f x_t + v_f \odot c_{t-1} + b_f) \)
  • \( r_t := \sigma( W_r x_t + v_r \odot c_{t-1} + b_r ) \)
  • \( c_t := f_t \odot c_{t-1} + (1-f_t) \odot (W x_t) \)
  • \( h_t := r_t \odot c_t + (1-r_t) \odot x_t \)
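The four equations translate directly into a NumPy sketch (parameter names are mine; all \( W \) are square \( d \times d \), all \( v \) and \( b \) are \( d \)-vectors):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru_step(x_t, c_prev, W, W_f, W_r, v_f, v_r, b_f, b_r):
    """One SRU time step implementing the four equations above."""
    f_t = sigmoid(W_f @ x_t + v_f * c_prev + b_f)  # forget gate
    r_t = sigmoid(W_r @ x_t + v_r * c_prev + b_r)  # reset gate
    c_t = f_t * c_prev + (1.0 - f_t) * (W @ x_t)   # inner state
    h_t = r_t * c_t + (1.0 - r_t) * x_t            # highway output
    return c_t, h_t
```

Note that the only cross-dimension mixing happens in the three matrix products with \( x_t \); everything that touches \( c_{t-1} \) is elementwise.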

Can also decompose into primitives:

  • \( \mathrm{Way}(a, b, g, W) := g \odot a + (1 - g) \odot (W b) \)
  • \( \mathrm{Gate}(a, b, W, v, w) := \sigma(W b + v \odot a + w) \)
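The same primitives in NumPy (the function names mirror the decomposition above); the four SRU equations then reduce to two Gate calls and two Way calls, with the output Way using the identity matrix:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def way(a, b, g, W):
    # Way(a, b, g, W) = g ⊙ a + (1 - g) ⊙ (W b)
    return g * a + (1.0 - g) * (W @ b)

def gate(a, b, W, v, w):
    # Gate(a, b, W, v, w) = σ(W b + v ⊙ a + w)
    return sigmoid(W @ b + v * a + w)

def sru_step_primitives(x_t, c_prev, W, W_f, W_r, v_f, v_r, b_f, b_r):
    # Two Gates and two Ways; the output Way uses the identity so W b = x_t.
    f_t = gate(c_prev, x_t, W_f, v_f, b_f)
    r_t = gate(c_prev, x_t, W_r, v_r, b_r)
    c_t = way(c_prev, x_t, f_t, W)
    h_t = way(c_t, x_t, r_t, np.eye(len(x_t)))
    return c_t, h_t
```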

Simple Recurrent Unit diagram

Similarity to LSTM

  • the equations are similar to LSTM’s
  • but the output and input gates are replaced with a single reset gate
    • forming a highway network
  • SRU equations:
    • \( f_t := \sigma( W_f x_t + v_f \odot c_{t-1} + b_f) \)
    • \( r_t := \sigma( W_r x_t + v_r \odot c_{t-1} + b_r ) \)
    • \( c_t := f_t \odot c_{t-1} + (1-f_t) \odot (W x_t) \)
    • \( h_t := r_t \odot c_t + (1-r_t) \odot x_t \)
  • LSTM equations:
    • \( f_t = \sigma_g (W_f x_t + U_f h_{t-1} + b_f ) \)
    • \( i_t = \sigma_g (W_i x_t + U_i h_{t-1} + b_i ) \)
    • \( o_t = \sigma_g (W_o x_t + U_o h_{t-1} + b_o ) \)
    • \( c_t = f_t \odot c_{t-1} + i_t \odot \sigma_c (W_c x_t + b_c) \)
    • \( h_t = o_t \odot \sigma_h(c_t) \)

GPU vs CPU

From Nvidia: GPU vs CPU in CUDA documentation

CUDA kernels

  • CUDA kernels are C++ functions that are executed N times in parallel by N different CUDA threads
// Kernel definition
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = threadIdx.x;
    C[i] = A[i] + B[i];
}

int main()
{
    ...
    // Kernel invocation with N threads
    VecAdd<<<1, N>>>(A, B, C);
    ...
}

Parallel Implementation

  • the three input projections are batched into a single matrix multiplication \( U = (W, W_f, W_r) x_t \) over all time steps
  • the point-wise operations run in a single fused CUDA kernel
  • and parallelize across each hidden state dimension
  • the computation is still sequential in the time dimension
  • complexity O(L · B · d) for sequence length L, batch size B, hidden dimension d
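A NumPy sketch of this batching trick (function and variable names are mine): the three projections for the whole sequence are fused into one big matmul up front, leaving only cheap elementwise work inside the sequential time loop, which is the part the paper fuses into a single CUDA kernel.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru_forward(X, W, W_f, W_r, v_f, v_r, b_f, b_r):
    """Run SRU over a whole (L, d) input sequence X."""
    L, d = X.shape
    # One fused matmul for all time steps and all three projections.
    U = X @ np.concatenate([W, W_f, W_r], axis=0).T  # (L, 3d)
    Wx, Fx, Rx = U[:, :d], U[:, d:2 * d], U[:, 2 * d:]
    c = np.zeros(d)
    H = np.empty_like(X)
    for t in range(L):  # sequential in time, parallel across the d dimensions
        f = sigmoid(Fx[t] + v_f * c + b_f)
        r = sigmoid(Rx[t] + v_r * c + b_r)
        c = f * c + (1.0 - f) * Wx[t]
        H[t] = r * c + (1.0 - r) * X[t]
    return H
```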

SRU Results

  • on its own, SRU slightly outperforms QRNN (Quasi-RNN)
    • SRU “replaces convolutions” in QRNN and KNN with more recurrent connections
  • SRU and QRNN run at similar speed
  • 5–9x speed-up over the cuDNN-optimized LSTM on classification and question answering datasets

SRU results on enwik8

SRU++: Attention with SRU

SRU++ Simple Recurrent Unit on Enwik8 bits per character

SRU++ Layer

  • SRU++ is SRU with self-attention replacing the projection \( (W, W_f, W_r) x \)
  • attention
    • no positional encodings
    • projects down to dimension 512 instead of 2048 (the “projection trick”)
    • residual connections around both the attention and the SRU
    • layer normalization after the attention block
  • attention helps significantly
    • but is needed only in every k-th layer, e.g. every 5th

SRU++ diagram - Simple Recurrent Unit with attention
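A rough NumPy sketch of the SRU++ idea (single-head, my weight names; residual connection, layer normalization, and causal masking omitted for brevity): a small self-attention block computes the \( U \) that feeds the elementwise SRU recurrence, with the query/key/value projections living in a smaller dimension, i.e. the “projection trick”.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def srupp_projection(X, W_q, W_k, W_v, W_o):
    """Compute U with self-attention instead of (W, W_f, W_r) x.

    X: (L, d) layer input; W_q, W_k, W_v: (d, d_attn) with d_attn < d
    (the "projection trick"); W_o: (d_attn, 3d) maps back up to the
    three SRU projections."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v          # (L, d_attn) each
    A = softmax(Q @ K.T / np.sqrt(K.shape[1]))   # (L, L) attention weights
    return (A @ V) @ W_o                         # (L, 3d), feeds the SRU gates
```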

Datasets

ENWIK8 (Hutter, 2006)

  • a character-level language modeling dataset consisting of 100M tokens taken from Wikipedia
  • the vocabulary size of this dataset is about 200 characters
  • BPC = bits per character

WIKI-103 (Merity et al., 2017)

  • a word-level language modeling dataset
  • 100M tokens extracted from Wikipedia
  • vocabulary of 260K tokens

Results

  • PPL = perplexity
  • attention helps the most in the later layers
    • perhaps the first layers learn local features
    • which attention then combines
  • outperforms the Transformer-XL baseline by roughly 3% lower BPC
  • a larger context length yields even lower BPC

Fair Comparison to Transformer-XL

SRU++ comparison to Trans-XL

How Often To Include Attention?

  • one attention-SRU layer every 10 layers is enough

SRU++ attention every k layers

Max Performance Enwik8

  • maximum performance comparison
  • base model d = 3072, larger model d = 4096
  • context length 1024 for training, 3072 for evaluation
  • SoTA on enwik8, but not on WIKI-103

Comparison with top-performing models on the enwik8 dataset

Max Performance Wiki-103

  • on par with the Compressive Transformer, worse than kNN-LM and Routing Transformer

SRU++ WIKI-103 results Routing Transformer

Speed Comparison

SRU++ inference speed

Terraformer

  • uses SRU, but is not covered here

Created on 26 Feb 2022. Updated on: 19 Apr 2022.
