Cross-Attention in Transformer Architecture

Merge two embedding sequences regardless of modality, e.g., image with text in Stable Diffusion U-Net with encoder-decoder attention.

Cross-attention is:

  • an attention mechanism in Transformer architecture that mixes two different embedding sequences
  • the two sequences must have the same dimension, or be projected to a common dimension inside the layer
  • the two sequences can be of different modalities (e.g. text, image, sound)
  • one of the sequences serves as the query input and therefore defines the output length
  • the other sequence then provides the key and value inputs

Cross-attention Applications

Cross-attention vs Self-attention

Except for the inputs, the cross-attention calculation is the same as self-attention. Cross-attention asymmetrically combines two separate embedding sequences, projected to the same dimension, whereas self-attention takes a single embedding sequence as input. One of the sequences serves as the query input, while the other provides the key and value inputs. An alternative form of cross-attention, used in SelfDoc, takes the query and value from one sequence and the key from the other.

The feed-forward layer is related to cross-attention, except that the feed-forward layer does not use softmax (it applies an activation such as ReLU instead) and one of its input sequences is static: the layer's weights play the role of fixed keys and values. The paper Augmenting Self-attention with Persistent Memory shows that the feed-forward layer calculation can be made the same as self-attention.
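
To make the analogy concrete, here is a small NumPy sketch (biases omitted, sizes arbitrary) that computes a ReLU feed-forward layer twice: once as usual and once as "attention" whose keys and values are the layer's own weight matrices. It only illustrates the analogy and is not the exact formulation used in the Persistent Memory paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, n = 8, 32, 5            # illustrative sizes (assumed)

x = rng.normal(size=(n, d_model))      # input token sequence, one token per row
W1 = rng.normal(size=(d_model, d_ff))  # first feed-forward weight matrix
W2 = rng.normal(size=(d_ff, d_model))  # second feed-forward weight matrix

def relu(z):
    return np.maximum(z, 0.0)

# Standard position-wise feed-forward layer (biases omitted for brevity).
ffn_out = relu(x @ W1) @ W2

# The same computation read as attention over d_ff static (learned) slots:
# keys are the columns of W1, values are the rows of W2, and ReLU takes
# the place of softmax on the query-key scores.
keys = W1.T                  # (d_ff, d_model): one static key per hidden unit
values = W2                  # (d_ff, d_model): one static value per hidden unit
scores = x @ keys.T          # (n, d_ff): query-key dot products
attn_like_out = relu(scores) @ values

print(np.allclose(ffn_out, attn_like_out))  # True
```

The two results are identical because `x @ keys.T` is just `x @ W1`; the "static" sequence is the set of weight vectors, which never changes with the input.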

cross-attention perceiver io detail

Cross-attention Algorithm

  • Let us have two embedding (token) sequences S1 and S2
  • Calculate Keys and Values from sequence S1
  • Calculate Queries from sequence S2
  • Calculate the attention matrix from the Queries and Keys
  • Multiply the attention matrix by the Values
  • The output sequence has the length and dimension of sequence S2

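In an equation, with the tokens of each sequence stored as rows and \( d_k \) the dimension of the keys: \( \mathrm{softmax}\left(\frac{(S_2 W_Q)(S_1 W_K)^\intercal}{\sqrt{d_k}}\right) S_1 W_V \)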
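
The steps and the equation above can be sketched in NumPy as follows. This is a minimal illustration, assuming one token per row and random matrices standing in for the learned projections; the names `cross_attention`, `w_q`, `w_k`, and `w_v` are chosen here for illustration only.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(s1, s2, w_q, w_k, w_v):
    """Queries from s2, keys and values from s1 (one token per row)."""
    q = s2 @ w_q                              # (n2, d_k)
    k = s1 @ w_k                              # (n1, d_k)
    v = s1 @ w_v                              # (n1, d_v)
    d_k = q.shape[-1]
    attn = softmax(q @ k.T / np.sqrt(d_k))    # (n2, n1), each row sums to 1
    return attn @ v                           # (n2, d_v): one output per query token

rng = np.random.default_rng(0)
d, d_k = 16, 8                                # illustrative sizes (assumed)
s1 = rng.normal(size=(10, d))                 # e.g. 10 text tokens -> keys and values
s2 = rng.normal(size=(4, d))                  # e.g. 4 image patches -> queries
w_q = rng.normal(size=(d, d_k))
w_k = rng.normal(size=(d, d_k))
w_v = rng.normal(size=(d, d))                 # d_v = d, so the output keeps the dimension of S2

out = cross_attention(s1, s2, w_q, w_k, w_v)
print(out.shape)       # (4, 16): the length of S2

# Self-attention is the special case where both inputs are the same sequence.
self_out = cross_attention(s2, s2, w_q, w_k, w_v)
print(self_out.shape)  # (4, 16)
```

The only difference between the two calls is which sequence is passed as `s1`, which matches the point that cross-attention and self-attention differ only in their inputs.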

Cross-attention Alternatives

The Feature-wise Linear Modulation (FiLM) layer is a simpler alternative, which does not require the input to be a sequence and has linear complexity. Q-Transformer applies FiLM to a visual EfficientNet, conditioning it on embeddings of textual instructions to predict Q-values.
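
FiLM predicts a per-channel scale (gamma) and shift (beta) from the conditioning vector and applies them at every position of the input. The NumPy sketch below assumes a single feature map and a single conditioning vector, with random matrices standing in for the learned layers that predict gamma and beta; it is an illustration, not the Q-Transformer code.

```python
import numpy as np

rng = np.random.default_rng(0)
channels, height, width, cond_dim = 8, 4, 4, 12  # illustrative sizes (assumed)

features = rng.normal(size=(channels, height, width))  # e.g. a CNN feature map
condition = rng.normal(size=(cond_dim,))               # e.g. an instruction embedding

# Random matrices stand in for the learned layers that predict scale and shift.
w_gamma = rng.normal(size=(channels, cond_dim))
w_beta = rng.normal(size=(channels, cond_dim))

gamma = w_gamma @ condition  # (channels,): one scale per channel
beta = w_beta @ condition    # (channels,): one shift per channel

# Apply the same affine transform at every spatial position of each channel.
# No attention matrix is formed, so the cost is linear in the feature map size.
modulated = gamma[:, None, None] * features + beta[:, None, None]
print(modulated.shape)  # (8, 4, 4)
```

Compared with cross-attention, the condition here is a single vector rather than a sequence, and each position is modulated independently instead of attending over other tokens.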

Cross-attention Implementation

Have a look at the CrossAttention implementation in the Diffusers library, which can generate images with Stable Diffusion. Here, cross-attention conditions the transformers inside the U-Net layers on a text prompt for image generation. The constructor shows that the two sequences can also have different dimensions (query_dim and cross_attention_dim), and if you step through with a debugger, you will also see that the two modalities have different sequence lengths.

from torch import nn


class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    ...  # constructor and forward pass omitted from this excerpt

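In particular, look at this part, where you can see how query, key, and value interact. This is an encoder-decoder architecture, so the query is created from the hidden states (the image side inside the U-Net), while the key and value are created from the encoder hidden states (the text prompt embeddings).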

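        # Queries come from the U-Net hidden states (the image side).
        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        # Keys and values come from the text encoder hidden states;
        # without them, the layer falls back to self-attention.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Attention probabilities from the query-key scores, then a weighted sum of the values.
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)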
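
The excerpt depends on the rest of the Diffusers attention class, so it does not run on its own. The NumPy sketch below mirrors its steps with random matrices standing in for to_q, to_k, and to_v, and with its own head_to_batch_dim and batch_to_head_dim helpers written here for illustration (they follow the same reshaping idea but are not the library code). The sizes are made up; the point is that the query side and the text side differ in both channel count and sequence length. The final output projection is omitted.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def head_to_batch_dim(x, heads):
    # (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head)
    b, n, inner = x.shape
    x = x.reshape(b, n, heads, inner // heads).transpose(0, 2, 1, 3)
    return x.reshape(b * heads, n, inner // heads)

def batch_to_head_dim(x, heads):
    # (batch * heads, seq, dim_head) -> (batch, seq, heads * dim_head)
    bh, n, d = x.shape
    x = x.reshape(bh // heads, heads, n, d).transpose(0, 2, 1, 3)
    return x.reshape(bh // heads, n, heads * d)

rng = np.random.default_rng(0)
query_dim, cross_attention_dim = 32, 48  # different channel counts (illustrative)
heads, dim_head = 4, 8
inner_dim = heads * dim_head

to_q = rng.normal(size=(query_dim, inner_dim))
to_k = rng.normal(size=(cross_attention_dim, inner_dim))
to_v = rng.normal(size=(cross_attention_dim, inner_dim))

hidden_states = rng.normal(size=(2, 64, query_dim))                   # image side: 64 positions
encoder_hidden_states = rng.normal(size=(2, 7, cross_attention_dim))  # text side: 7 tokens

query = head_to_batch_dim(hidden_states @ to_q, heads)          # (8, 64, 8)
key = head_to_batch_dim(encoder_hidden_states @ to_k, heads)    # (8, 7, 8)
value = head_to_batch_dim(encoder_hidden_states @ to_v, heads)  # (8, 7, 8)

scale = dim_head ** -0.5
attention_probs = softmax(query @ key.transpose(0, 2, 1) * scale)  # (8, 64, 7)
out = batch_to_head_dim(attention_probs @ value, heads)             # (2, 64, 32)
print(out.shape)
```

Folding the heads into the batch dimension lets a single batched matrix multiplication compute all heads at once, which is what the torch.bmm call in the excerpt relies on.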

Cross-attention is widely used in encoder-decoder or multi-modality use cases.

Cross-Attention in Transformer Decoder

Cross-attention was described in the Transformer paper, though not yet under this name: the paper calls it encoder-decoder attention. Transformer decoding starts with the full input sequence but an empty output sequence (in practice, only a start token). Cross-attention introduces information from the input sequence into the layers of the decoder, so that it can predict the next token of the output sequence. The decoder then appends the token to the output sequence and repeats this autoregressive process until the EOS token is generated.
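
The decoding loop can be sketched as a deliberately stripped-down toy: one cross-attention layer with random weights, no masked self-attention, and no feed-forward layers, so the generated tokens are meaningless. The token ids BOS and EOS, the vocabulary size, and the length cap are hypothetical values chosen for illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, memory, w_q, w_k, w_v):
    q, k, v = queries @ w_q, memory @ w_k, memory @ w_v
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
vocab, d = 6, 16              # toy vocabulary and model width (assumed)
BOS, EOS, max_len = 0, 1, 10  # hypothetical special token ids and length cap

embed = rng.normal(size=(vocab, d))  # toy token embeddings
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
w_out = rng.normal(size=(d, vocab))  # projection to vocabulary logits

memory = rng.normal(size=(5, d))     # encoder output for the full input sequence

output = [BOS]                       # the output starts with only a start token
while len(output) < max_len:
    dec = embed[output]                                # (len(output), d)
    # Queries from the decoder side, keys and values from the encoder output.
    ctx = cross_attention(dec, memory, w_q, w_k, w_v)
    assert ctx.shape == (len(output), d)               # one row per output position
    next_token = int(np.argmax(ctx[-1] @ w_out))       # greedy choice at the last position
    output.append(next_token)
    if next_token == EOS:
        break
print(output)
```

The encoder output `memory` stays fixed for the whole loop; only the query side grows by one token per step.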

Cross-Attention in the Transformer decoder of Attention is All You Need paper

Cross-Attention in Stable Diffusion

Stable Diffusion uses cross-attention to condition the transformers inside the layers of the denoising U-Net on a text prompt during image generation.
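
To attend over the prompt, the U-Net feature map first has to become a sequence. The NumPy sketch below assumes a single feature map, flattens its spatial grid into height * width "image tokens", lets them cross-attend to the text embeddings, and reshapes the result back. The prompt length of 77 matches the CLIP text encoder used by Stable Diffusion v1; all other sizes, and the random weights, are illustrative. The residual addition follows common Transformer-block practice.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
channels, height, width = 32, 8, 8  # illustrative U-Net feature map size (assumed)
text_len, text_dim = 77, 64         # prompt length and illustrative text-embedding width
d_k = 32

feature_map = rng.normal(size=(channels, height, width))  # image latents inside a U-Net block
text_emb = rng.normal(size=(text_len, text_dim))          # text-encoder output for the prompt

# Flatten the spatial grid into a sequence of height * width image tokens.
image_tokens = feature_map.reshape(channels, height * width).T  # (64, 32)

w_q = rng.normal(size=(channels, d_k))
w_k = rng.normal(size=(text_dim, d_k))
w_v = rng.normal(size=(text_dim, channels))  # values projected to the image channel count

q, k, v = image_tokens @ w_q, text_emb @ w_k, text_emb @ w_v
attn = softmax(q @ k.T / np.sqrt(d_k))  # (64, 77): each position attends over prompt tokens
conditioned = attn @ v                  # (64, 32)

# Residual addition, then reshape back to the feature-map layout.
out = (image_tokens + conditioned).T.reshape(channels, height, width)
print(out.shape)  # (32, 8, 8)
```

Because the image side supplies the queries, the output keeps the spatial size of the feature map no matter how long the prompt is.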

stable diffusion architecture with cross-attention

Cross-Attention in Perceiver IO

Perceiver IO is a general-purpose multimodal architecture that can handle a wide variety of inputs and outputs. Perceiver can be applied, for example, to image-text classification. Perceiver IO uses cross-attention for merging:

  • multimodal input sequences (e.g. image, text, audio) into a latent sequence that is much shorter than the input
  • an “output query” or “command” with the latents to decode the output value, e.g. predict this masked word
Perceiver IO architecture

An advantage of the Perceiver architecture is that it can work with very large inputs: because the latent sequence has a small, fixed length, the cost of the input cross-attention grows only linearly with the input length. The Hierarchical Perceiver architecture can process even longer input sequences by splitting them into subsequences and then merging them. Hierarchical Perceiver also learns its positional encodings in a separate training step with a reconstruction loss.
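
The two Perceiver IO cross-attentions can be sketched in NumPy as below. This is a simplified illustration, assuming random weights, random latents, and random output queries (all learned or task-derived in practice), and omitting the latent self-attention stack between encoding and decoding. The sizes are made up to show that the encode step forms an N x M attention matrix, linear in the input length M.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, inputs, w_q, w_k, w_v):
    q, k, v = queries @ w_q, inputs @ w_k, inputs @ w_v
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
M, C = 5000, 12  # long input, e.g. 5000 pixels or audio samples with 12 features (assumed)
N, D = 64, 32    # small latent array, N much smaller than M (assumed)
O, E = 3, 16     # 3 output queries, e.g. "predict these masked positions" (assumed)

inputs = rng.normal(size=(M, C))
latents = rng.normal(size=(N, D))         # learned in practice; random here
output_queries = rng.normal(size=(O, E))  # derived from the task in practice; random here

# Encode: latents are the queries, the long input provides keys and values.
enc_w = (rng.normal(size=(D, D)), rng.normal(size=(C, D)), rng.normal(size=(C, D)))
z = cross_attention(latents, inputs, *enc_w)        # (64, 32)

# The latent self-attention layers that would process z are omitted here.

# Decode: output queries attend to the latents, giving one output per query.
dec_w = (rng.normal(size=(E, D)), rng.normal(size=(D, D)), rng.normal(size=(D, E)))
y = cross_attention(output_queries, z, *dec_w)      # (3, 16)
print(z.shape, y.shape)
```

Doubling M doubles the size of the encode attention matrix but leaves the latent and output shapes unchanged.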

Cross-Attention in SelfDoc

selfdoc cross-attention

In SelfDoc, cross-attention is integrated in a special way. The first step of their Cross-Modality Encoder uses the query and value from sequence A and the key from sequence B.
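
A minimal NumPy sketch of this variant follows, assuming random weights and the hypothetical name `selfdoc_style_cross_attention`; it is not the SelfDoc code. Because the attention matrix has one row per query in A and one column per key in B, and it then multiplies values that have one row per element of A, the variant only works when A and B have the same length. The sketch assumes this holds because both sequences describe the same document blocks (one text and one visual embedding per block).

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def selfdoc_style_cross_attention(a, b, w_q, w_k, w_v):
    """Query and value from sequence A, key from sequence B."""
    # Values from A must have one row per key from B, so the lengths must match.
    assert a.shape[0] == b.shape[0], "A and B must have the same length"
    q = a @ w_q
    k = b @ w_k
    v = a @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (n, n)
    return attn @ v

rng = np.random.default_rng(0)
n, d = 6, 16  # 6 document blocks with text and visual embeddings (assumed)
text_blocks = rng.normal(size=(n, d))
visual_blocks = rng.normal(size=(n, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

out = selfdoc_style_cross_attention(text_blocks, visual_blocks, w_q, w_k, w_v)
print(out.shape)  # (6, 16)
```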

Other Cross-Attention Examples

Created on 28 Dec 2021. Updated on: 30 Dec 2022.
Thank you

Copyright © Vaclav Kosar. All rights reserved. Not investment, financial, medical, or any other advice. No guarantee of information accuracy.