Cross-attention is:
- an attention mechanism in the Transformer architecture that mixes two different embedding sequences
- the two sequences must have the same dimension
- the two sequences can be of different modalities (e.g. text, image, sound)
- one of the sequences defines the output length, as it plays the role of the query input
- the other sequence then produces the key and value inputs
Cross-attention Applications
- image-text classification with Perceiver
- machine translation: cross-attention helps the decoder predict the next token of the translated text
Cross-attention vs Self-attention
Except for the inputs, the cross-attention calculation is the same as for self-attention. Cross-attention asymmetrically combines two separate embedding sequences of the same dimension, whereas the input to self-attention is a single embedding sequence. One of the sequences serves as the query input, while the other provides the key and value inputs. An alternative cross-attention in SelfDoc uses the query and value from one sequence and the key from the other.
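To make the contrast concrete, here is a minimal sketch using PyTorch's nn.MultiheadAttention; the shapes and variable names are illustrative, and only the wiring of the inputs differs between the two cases.

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, 6, 32)   # sequence A, e.g. decoder states
y = torch.randn(2, 9, 32)   # sequence B, e.g. encoder output

self_out, _ = attn(x, x, x)    # self-attention: query, key, and value from the same sequence
cross_out, _ = attn(x, y, y)   # cross-attention: query from A, key and value from B

print(self_out.shape, cross_out.shape)  # both (2, 6, 32): output length follows the query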
The feed-forward layer is related to cross-attention, except that the feed-forward layer does not use softmax and one of its input sequences (the trained weights) is static. The Augmenting Self-attention with Persistent Memory paper shows that the feed-forward layer calculation can be written in the same form as self-attention.
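A small sketch of that reinterpretation (our own toy example, not the paper's code): the rows of the first feed-forward weight matrix act as static keys, the columns of the second as static values, and ReLU takes the place of softmax.

import torch

d_model, d_ff = 8, 32
x = torch.randn(d_model)           # one token embedding
W1 = torch.randn(d_ff, d_model)    # first feed-forward weight (biases omitted)
W2 = torch.randn(d_model, d_ff)    # second feed-forward weight

ffn_out = W2 @ torch.relu(W1 @ x)  # standard feed-forward layer

keys = W1                          # d_ff static "memory" keys of dimension d_model
values = W2.T                      # d_ff static "memory" values of dimension d_model
weights = torch.relu(keys @ x)     # ReLU in place of softmax
attn_out = weights @ values        # weighted sum over the static memory

print(torch.allclose(ffn_out, attn_out))  # True: the same computation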
Cross-attention Algorithm
- Let us have two embedding (token) sequences S1 and S2
- Calculate the keys and values from sequence S1
- Calculate the queries from sequence S2
- Calculate the attention matrix from the keys and queries
- Apply the values to the attention matrix
- The output sequence has the dimension and length of sequence S2
As an equation: \( \mathrm{softmax}\!\left( (S_2 W_Q)(S_1 W_K)^\intercal \right) S_1 W_V \)
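Here is a minimal single-head sketch of this calculation in PyTorch, assuming row-vector sequences and adding the usual 1/sqrt(d_k) scaling (the function and variable names are ours, not from any library):

import torch

def cross_attention(s1, s2, w_q, w_k, w_v):
    # s1: (len1, d_model) provides keys and values
    # s2: (len2, d_model) provides queries and defines the output length
    q = s2 @ w_q                              # (len2, d_k)
    k = s1 @ w_k                              # (len1, d_k)
    v = s1 @ w_v                              # (len1, d_v)
    scores = q @ k.T / k.shape[-1] ** 0.5     # (len2, len1) attention matrix
    return torch.softmax(scores, dim=-1) @ v  # (len2, d_v)

d_model, d_k = 16, 8
s1 = torch.randn(10, d_model)   # e.g. image tokens
s2 = torch.randn(4, d_model)    # e.g. text tokens
w_q, w_k, w_v = (torch.randn(d_model, d_k) for _ in range(3))
print(cross_attention(s1, s2, w_q, w_k, w_v).shape)  # torch.Size([4, 8]) -- length of S2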
Cross-attention Alternatives
The Feature-wise Linear Modulation (FiLM) layer is a simpler alternative that does not require the input to be a sequence and has linear computational complexity. Q-Transformer applies FiLM to a visual EfficientNet to condition it on embeddings of textual instructions in order to predict Q-values.
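For intuition, here is a minimal FiLM sketch (illustrative only, not the Q-Transformer implementation; the class name and dimensions are made up): the conditioning vector produces a per-channel scale and shift that modulate the feature maps.

import torch
import torch.nn as nn

class FiLM(nn.Module):
    def __init__(self, cond_dim, num_channels):
        super().__init__()
        self.to_gamma_beta = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, features, cond):
        # features: (batch, channels, height, width), cond: (batch, cond_dim)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
        return gamma[:, :, None, None] * features + beta[:, :, None, None]

film = FiLM(cond_dim=32, num_channels=64)
features = torch.randn(2, 64, 8, 8)   # e.g. convolutional feature maps
cond = torch.randn(2, 32)             # e.g. text instruction embedding
print(film(features, cond).shape)     # torch.Size([2, 64, 8, 8])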
Cross-attention Implementation
Have a look at the CrossAttention implementation in the Diffusers library, which can generate images with Stable Diffusion. In this case, cross-attention is used to condition the transformers inside a UNet layer with a text prompt for image generation. The constructor shows how we can also have different dimensions, and if you step through with a debugger, you will also see the different sequence lengths of the two modalities.
class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """
In particular, look at this part, where you can see how the query, key, and value interact. This follows the encoder-decoder architecture, so:
- keys and values are created from the encoder hidden states,
- while queries are created from the decoder hidden states.
# queries come from the current (decoder / U-Net) hidden states
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)

# keys and values come from the encoder hidden states (e.g. the text embeddings);
# if no encoder states are given, this falls back to self-attention
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

# attention matrix from queries and keys, then applied to the values
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
Cross-Attention in Popular Architectures
Cross-attention is widely used in encoder-decoder or multi-modal use cases.
Cross-Attention in Transformer Decoder
Cross-attention was already described in the original Transformer paper, but it was not yet given this name there. Transformer decoding starts with the full input sequence but an empty decoding sequence. Cross-attention introduces information from the input sequence into the layers of the decoder, so that it can predict the next output token. The decoder then appends the token to the output sequence and repeats this autoregressive process until the EOS token is generated.
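Below is a toy sketch of that loop using PyTorch's nn.Transformer (assumptions: embeddings stand in for tokens, no vocabulary projection, no causal mask, and a fixed number of steps instead of stopping at EOS):

import torch
import torch.nn as nn

model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, batch_first=True)

src = torch.randn(1, 10, 32)          # full input sequence (already embedded)
memory = model.encoder(src)           # encoded once, reused at every decoding step

tgt = torch.zeros(1, 1, 32)           # decoding starts from a single start-token embedding
for _ in range(5):                    # a real decoder would stop at the EOS token
    out = model.decoder(tgt, memory)  # cross-attention reads keys and values from memory
    next_token = out[:, -1:, :]       # the newest position stands in for the predicted token
    tgt = torch.cat([tgt, next_token], dim=1)

print(tgt.shape)                      # torch.Size([1, 6, 32])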
Cross-Attention in Stable Diffusion
Stable Diffusion uses cross-attention for image generation to condition the transformers inside the denoising U-Net with a text prompt.
Cross-Attention in Perceiver IO
Perceiver IO is a general-purpose, multi-modal architecture that can handle a wide variety of inputs as well as outputs. Perceiver can be applied, for example, to image-text classification. Perceiver IO uses cross-attention for merging:
- multimodal input sequences (e.g. image, text, audio) into a low-dimensional latent sequence
- an “output query” or “command” to decode the output value, e.g. predict this masked word
An advantage of the Perceiver architecture is that it can, in general, work with very large inputs. The Hierarchical Perceiver architecture can process even longer input sequences by splitting them into subsequences and then merging them. Hierarchical Perceiver also learns the positional encodings in a separate training step with a reconstruction loss.
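The latent bottleneck from the first bullet above can be sketched with a plain multi-head cross-attention (a toy illustration, not the actual Perceiver IO code; the shapes are made up): a small latent array supplies the queries, so the long input is compressed to the latent length.

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

inputs = torch.randn(1, 10000, 64)   # very long (e.g. multimodal) input sequence
latents = torch.randn(1, 128, 64)    # small learned latent array supplies the queries

out, _ = attn(latents, inputs, inputs)  # cross-attention into the latent bottleneck
print(out.shape)                        # torch.Size([1, 128, 64]) -- length follows the latents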
Cross-Attention in SelfDoc
In SelfDoc, cross-attention is integrated in a special way. The first step of their Cross-Modality Encoder instead uses the value and query from sequence A and the key from sequence B.
Other Cross-Attention Examples
- DeepMind’s RETRO Transformer uses cross-attention to incorporate the sequences retrieved from a database
- Code example: HuggingFace BERT (the key and value come from the encoder, while the query comes from the decoder)
- CrossViT - here only a simplified cross-attention is used
- On the Strengths of Cross-Attention in Pretrained Transformers for Machine Translation