# SRU++ Model Speeds Up Transformer with Simple Recurrent Unit

Reducing compute by combining an RNN with self-attention from the Transformer architecture.

Here are my notes on SRU and SRU++. Thanks to the paper authors and to the discussions at Yannic's Discord meetup.

## Summary:

### Self-Attention vs Recurrent Layer

• attention vs recurrence = graph vs sequence = Transformer vs LSTM
• attention connects across the entire sequence as a fully connected graph
• example graph task: dependency parse is a syntactic graph over the word sequence
• RNNs keep information from previous steps in a state vector that acts as memory
• RNNs are not parallelizable in the time dimension, as future steps depend on the past
• RNNs have difficulty accessing information from long ago
• RNNs handle repetition better and can use CTC loss, e.g. for OCR
• SRU++ uses both attention and recurrence

### How Does SRU Help Parallelization?

• while the state computation of SRU is sequential in time, each state dimension is independent of the others
• time step: $$t$$, input vector: $$x_t$$, (inner) state $$c_t$$
• (inner) forget gate $$f_t := \sigma(W_f x_t + V_f c_{t-1} + b_f)$$
• problem: both $$c_t, f_t$$ depend on all dimensions of $$c_{t-1}$$
• due to the matrix multiplication $$V_f c_{t-1}$$
• solution: pointwise (Hadamard) multiplication $$v_f \odot c_{t-1}$$
• gives parallel computation of $$c_t, f_t$$ across state dimensions (see the sketch after this list)
• state $$c_t := f_t \odot c_{t-1} + (1 - f_t) \odot W x_t$$
• all $$W, v, b$$ are trained parameters
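
To make the dimension independence concrete, here is a toy NumPy sketch (mine, not the authors' code) contrasting the matrix product $$V_f c_{t-1}$$, which couples all state dimensions, with the Hadamard product $$v_f \odot c_{t-1}$$, where each dimension depends only on itself:

```python
# Toy example: a matrix multiply couples state dimensions, a Hadamard
# (element-wise) multiply keeps them independent, enabling parallelism over d.
import numpy as np

d = 4
rng = np.random.default_rng(0)
c_prev = rng.normal(size=d)       # previous state c_{t-1}
V_f = rng.normal(size=(d, d))     # matrix weights: every output mixes all of c_{t-1}
v_f = rng.normal(size=d)          # vector weights: output i uses only c_{t-1}[i]

coupled = V_f @ c_prev            # needs the full c_{t-1} for every entry
independent = v_f * c_prev        # entry i can be computed in isolation
print(coupled, independent)
```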

### Highway Network Component

• a highway network is more dynamic than a plain skip connection
• the reset gate weights the skip connection in the output
• defined as $$r_t := \sigma( W_r x_t + v_r \odot c_{t-1} + b_r )$$
• it combines the state with the input
• and is then used to form the output $$h_t$$, which allows gradient flow
• output (hidden) vector: $$h_t := r_t \odot c_t + (1 - r_t) \odot x_t$$

### All Equations

• $$f_t := \sigma( W_f x_t + v_f \odot c_{t-1} + b_f)$$
• $$r_t := \sigma( W_r x_t + v_r \odot c_{t-1} + b_r )$$
• $$c_t := f_t \odot c_{t-1} + (1-f_t) \odot (W x_t)$$
• $$h_t := r_t \odot c_t + (1-r_t) \odot x_t$$

Can also decompose into primitives (a single-step sketch follows the list):

• $$\mathrm{Way}(a, b, g, W) := g \odot a + (1 - g) \odot (W b)$$
• $$\mathrm{Gate}(a, b, W, v, w) := \sigma(W b + v \odot a + w)$$
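
A minimal NumPy sketch of one SRU step written with the two primitives above; the function names (`gate`, `way`, `sru_step`) and the assumption that input and hidden dimensions match are mine, not the paper's or the official implementation's:

```python
# One SRU step expressed with the Gate/Way primitives; a sketch, not the
# authors' code. Assumes the input dimension equals the hidden dimension d.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gate(a, b, W, v, w):
    # Gate(a, b, W, v, w) := sigma(W b + v ⊙ a + w)
    return sigmoid(W @ b + v * a + w)

def way(a, b, g, W):
    # Way(a, b, g, W) := g ⊙ a + (1 - g) ⊙ (W b)
    return g * a + (1.0 - g) * (W @ b)

def sru_step(x_t, c_prev, W, W_f, W_r, v_f, v_r, b_f, b_r):
    f_t = gate(c_prev, x_t, W_f, v_f, b_f)    # forget gate
    r_t = gate(c_prev, x_t, W_r, v_r, b_r)    # reset gate
    c_t = way(c_prev, x_t, f_t, W)            # state update
    h_t = r_t * c_t + (1.0 - r_t) * x_t       # highway output
    return h_t, c_t
```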

### Similarity to LSTM

• equations are similar to LSTM
• but the output gate and input gate are replaced with the reset gate
• and a highway network connection (see the shape comparison after the equations)
• SRU equations:
• $$f_t := \sigma( W_f x_t + v_f \odot c_{t-1} + b_f)$$
• $$r_t := \sigma( W_r x_t + v_r \odot c_{t-1} + b_r )$$
• $$c_t := f_t \odot c_{t-1} + (1-f_t) \odot (W x_t)$$
• $$h_t := r_t \odot c_t + (1-r_t) \odot x_t$$
• LSTM equations:
• $$f_t = \sigma_g (W_f x_t + U_f c_{t-1} + b_f )$$
• $$i_t = \sigma_g (W_i x_t + U_i c_{t-1} + b_i )$$
• $$o_t = \sigma_g (W_o x_t + U_o c_{t-1} + b_o )$$
• $$c_t = f_t \odot c_{t-1} + i_t \odot \sigma_c (W_c x_t + b_c)$$
• $$h_t = o_t \odot \sigma_h(c_t)$$
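
A tiny shape comparison (my own illustration): per gate, the LSTM recurrent term $$U_f c_{t-1}$$ uses a $$d \times d$$ matrix, while SRU's $$v_f \odot c_{t-1}$$ uses only a length-$$d$$ vector, which is exactly what removes the cross-dimension dependency:

```python
# Per-gate recurrent weights: LSTM needs a d x d matrix, SRU only a
# length-d vector, so SRU's recurrence has no cross-dimension mixing.
import numpy as np

d = 512
U_f = np.zeros((d, d))       # LSTM: d * d recurrent weights per gate
v_f = np.zeros(d)            # SRU: d recurrent weights per gate
print(U_f.size, v_f.size)    # 262144 vs 512
```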

### CUDA kernels

• CUDA kernels are C++ functions that are executed N times in parallel by N different CUDA threads, e.g.:
```cpp
// Kernel definition: each of N threads computes one element of C
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = threadIdx.x;   // index of this thread within its block
    C[i] = A[i] + B[i];
}

int main()
{
    ...
    // Kernel invocation with N threads (a single block of N threads)
    VecAdd<<<1, N>>>(A, B, C);
    ...
}
```


### Parallel Implementation

• the projections $$W x_t, W_f x_t, W_r x_t$$ for all time steps are batched into a single matrix multiplication $$U = (W, W_f, W_r)\, X$$
• the point-wise operations go into a single fused CUDA kernel
• and are parallelized across each hidden state dimension
• the computation is still sequential in the time dimension
• the element-wise kernel has complexity $$O(L \cdot B \cdot d)$$ for sequence length $$L$$, batch size $$B$$, hidden dimension $$d$$ (see the sketch below)
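
A minimal NumPy sketch (mine, not the official `sru` library) of this split: one large matrix multiplication produces all projections for every time step, and only a cheap element-wise loop stays sequential. It assumes a single sequence (no batch dimension) and that input and hidden dimensions match:

```python
# One big matmul over all time steps, then an element-wise recurrence in time.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sru_layer(X, W, W_f, W_r, v_f, v_r, b_f, b_r):
    """X: (L, d) inputs of one sequence; all weight matrices are (d, d)."""
    L, d = X.shape
    # Parallel part: every projection for every time step in one matmul.
    U = X @ np.concatenate([W, W_f, W_r], axis=1)          # (L, 3d)
    Wx, Wfx, Wrx = U[:, :d], U[:, d:2 * d], U[:, 2 * d:]

    # Sequential part: element-wise recurrence over time, O(L * d) here
    # (O(L * B * d) with a batch dimension B).
    c = np.zeros(d)
    H = np.empty_like(X)
    for t in range(L):
        f = sigmoid(Wfx[t] + v_f * c + b_f)
        r = sigmoid(Wrx[t] + v_r * c + b_r)
        c = f * c + (1.0 - f) * Wx[t]
        H[t] = r * c + (1.0 - r) * X[t]
    return H, c
```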

### SRU Results

• on its own, SRU slightly outperforms QRNN (Quasi-RNN)
• SRU “replaces the convolutions” in QRNN and KNN with more recurrent connections
• both SRU and QRNN have similar speed
• 5 - 9x speed-up over cuDNN-optimized LSTM on classification and question answering datasets

## SRU++: Attention with SRU

### SRU++ Layer

• SRU++ is SRU with self-attention producing the projections instead of $$(W, W_f, W_r)\, x$$ (sketched after this list)
• attention:
• uses no positional encodings
• operates on dimension 512 instead of 2048 (the “projection trick”)
• residual connection both on attention and SRU
• layer normalization after attention block
• attention helps significantly
• but it is needed only in every k-th layer, e.g. every 5th
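
A minimal single-head NumPy sketch (mine) of the core SRU++ change: self-attention, operating in a smaller dimension, produces the $$U$$ matrix that plain SRU obtained from the linear map $$(W, W_f, W_r)\, x$$. The residual connection and layer normalization from the bullets above are omitted, and all names are illustrative:

```python
# Self-attention in a reduced dimension d_attn replaces the linear projection
# that produces U; the rest of the SRU recurrence is unchanged.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_projection(X, Wq, Wk, Wv, Wo):
    # X: (L, d); Wq, Wk, Wv: (d, d_attn); Wo: (d_attn, 3d)
    Q, K, V = X @ Wq, X @ Wk, X @ Wv              # project down to d_attn
    scores = Q @ K.T / np.sqrt(Q.shape[-1])       # (L, L) attention scores
    A = softmax(scores, axis=-1) @ V              # (L, d_attn)
    return A @ Wo                                 # (L, 3d): plays the role of U

# The (L, 3d) output is split into Wx, Wfx, Wrx and fed into the same
# element-wise SRU recurrence as in the sketch above.
```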

### Datasets

#### ENWIK8 (Hutter, 2006)

• is a character-level language modeling dataset consisting of 100M tokens taken from Wikipedia.
• The vocabulary size of this dataset is about 200 characters.
• BPC is bits-per-character

#### WIKI-103 (Merity et al., 2017)

• is a word-level language modeling dataset.
• 100M tokens extracted from Wikipedia
• vocabulary of 260K tokens

### Results

• PPL = perplexity
• attention helps the most in the last layers
• maybe first layers learn local features
• which attention then uses
• outperforms the Transformer-XL baseline by about 3% lower BPC
• with a larger context, the BPC is even lower

#### How Often To Include Attention?

• 1 attention-SRU every 10 layers

#### Max Performance Enwik8

• maximum performance comparison
• base model d = 3072, larger model d = 4096
• context length train = 1024, eval 3072
• SoTA on enwik8, but not on Wiki-103

#### Max Performance Wiki-103

• on par with the Compressive Transformer, worse than kNN-LM and Routing Transformer

## Terraformer

• Uses SRU, but not covered here

Created on 26 Feb 2022. Updated on 05 Jun 2022.