Efficient Transformers
A survey on efficient transformers for long documents.
Dear Reader, There are around eighteen transformer variants summarized in this writeup. Kindly note that the figures and tables are directly taken from the original papers (sometimes their slides presented at conferences) unless specified otherwise. I have followed the survey paper below and the original papers to compile these notes. These are only to keep track of what I am reading, and no form of plagiarism is intended. I have tried my best to cite all resources, but in case I have forgotten something, please feel free to write to me at singh_shruti@iitgn.ac.in
Notes from:
- Efficient Transformers: A Survey by Tay et al. 2020
- Beyond Paragraphs: NLP for Long Sequences (NAACL 2021 Tutorial)
Outline of the writeup
Transformer Components:
- The input to a transformer model is a tensor of shape $\mathbb{R}^{B \times N}$, where $B$ is the batch size and $N$ is the sequence length.
The input passes through an embedding layer that converts tokens into $d$-dimensional embeddings, giving a tensor of shape $\mathbb{R}^{B \times N \times d}$. Positional encodings are added to this tensor, which is then passed through a multi-headed self-attention module. The input and output of the self-attention module are connected by a residual (skip) connection, followed by layer normalization. The output is then passed through a two-layered feed-forward network, which is again wrapped with a residual connection and layer normalization.
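As a quick illustration of the block described above, here is a minimal post-norm encoder block in PyTorch; the layer sizes and names are illustrative, not from any particular paper.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Minimal post-norm encoder block: self-attention + two-layer FFN,
    each wrapped with a residual (skip) connection and layer normalization."""
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):                       # x: (B, N, d)
        attn_out, _ = self.attn(x, x, x)        # multi-headed self-attention
        x = self.norm1(x + attn_out)            # residual + layer norm
        x = self.norm2(x + self.ffn(x))         # residual + layer norm
        return x

x = torch.randn(2, 128, 512)                    # (batch B, length N, dim d)
print(TransformerBlock()(x).shape)              # torch.Size([2, 128, 512])
```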
Transformer Mode
The transformer can be used in three modes:
1. Encoder only (e.g. for classification): there is no restriction that the self-attention be causal, i.e. dependent solely on the past and present tokens.
2. Decoder only (e.g. for language modelling): the self-attention must be causal.
3. Encoder-decoder (e.g. for machine translation): contains standard self-attention modules in both the encoder and the decoder, as well as encoder-decoder cross-attention. The encoder self-attention and the encoder-decoder cross-attention can be non-causal, but the decoder self-attention must be causal.
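For the causal case, a decoder-style attention simply masks out future positions before the softmax. A minimal sketch (assuming PyTorch; the tensor names are made up):

```python
import torch

N = 6
scores = torch.randn(N, N)                                  # raw QK^T scores for one head
causal_mask = torch.tril(torch.ones(N, N, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float('-inf'))    # block attention to future tokens
attn = scores.softmax(dim=-1)                               # row i attends only to tokens <= i
```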
Efficient Transformer Models Categorization
Efficient transformer models either approximate the quadratic-cost attention matrix, or use segment-based recurrence. Each method applies some notion of sparsity to the otherwise dense attention mechanism.
1. Fixed Patterns (FP): Sparsify the self-attention matrix by limiting the field of view to fixed, predefined patterns such as local windows and block patterns of fixed strides.
   * Blockwise Patterns: Chunk the input sequence into fixed blocks, which reduces the complexity from $N^2$ to $B^2$, with $B \ll N$.
   * Strided Patterns: Strided attention refers to only attending at fixed intervals. E.g.: Sparse Transformer and Longformer employ strided or dilated windows.
   * Compressed Patterns: Use a pooling operator to down-sample the sequence length, giving another form of fixed pattern. E.g.: Compressed Attention uses a strided convolution to reduce the sequence length.
2. Combination of Patterns: The key idea here is to improve coverage by combining two or more distinct access patterns. E.g.: Sparse Transformer combines strided and local attention by assigning half of its heads to each pattern. Axial Transformer, given a high-dimensional tensor as input, applies a sequence of self-attention computations, each along a single axis of the input tensor. Combining multiple patterns not only reduces memory complexity but also improves the overall coverage of the self-attention mechanism.
3. Learnable Patterns: These methods learn a notion of token relevance and then assign tokens to buckets or clusters.
Work | Technique |
---|---|
Reformer (Kitaev et al. 2020) | Hash-based similarity to cluster tokens into chunks. |
Routing Transformer (Roy et al. 2020) | Apply online k-means clustering on tokens. |
Sinkhorn Sorting Network (Tay et al. 2020b) | Exposes sparsity in attention by learning to sort blocks of the input sequence. |
The key idea is still to exploit fixed patterns; however, these methods learn a sorting/clustering of the input tokens to enable a more optimal global view of the sequence while also maintaining the efficiency benefits of fixed-pattern approaches.
4. Memory: Use a side memory module. A common technique is to use a global memory that is able to access the entire sequence. With a limited number of memory tokens, we are able to perform a preliminary pooling operation over the input sequence. Some examples are ETC, Longformer, Set Transformer, etc.
5. Low-Rank Methods: Efficiency can also be improved by low-rank approximations of the $N \times N$ self-attention matrix. E.g.: the Linformer projects the length dimension of keys and values to a lower dimension (say some $k$ from $N$), so the $N \times N$ matrix is now decomposed to $N \times k$.
6. Kernels: Use kernels to avoid the explicit $N \times N$ computation.
7. Recurrence: The blocks in the blockwise method can be connected via recurrence.
Detailed Description
SUMMARY: Detailed summaries of each model are provided below the summary table.
Memory Constrained Transformer
(ICLR 2018) The main contributions are (i) localizing the attention span, and (ii) using memory-compressed attention.
Localizing the attention span is done by dividing the input sequence into blocks of similar length and running self-attention within each block independently. As a result, the number of activations scales linearly with the input length. Memory-compressed attention reduces the number of keys and values by using a strided convolution, while the queries remain unchanged. The size of the attention matrix is reduced depending on the kernel size and stride of the convolution. Memory-compressed attention lets the model exchange information globally, instead of only attending locally.
Complexity: For block size $b$, the computational and memory cost of self-attention within a block is $O(b^2)$. There are $n/b$ blocks, so the cost of local attention is $(n/b) \cdot b^2$, i.e. $O(n \cdot b)$. For memory-compressed attention, if a convolution with kernel size and stride $k$ is used, the complexity is $O(n \cdot n/k)$.
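A minimal sketch of memory-compressed attention as described above, assuming a strided 1-d convolution with kernel size equal to the stride $k$ applied to the keys and values while the queries are left untouched; all module and variable names are illustrative.

```python
import torch
import torch.nn as nn

class MemoryCompressedAttention(nn.Module):
    """Sketch: compress K and V along the length dimension with a strided
    convolution (kernel size == stride == k); queries are left unchanged,
    so the attention matrix shrinks from n x n to n x (n/k)."""
    def __init__(self, d_model=256, k=4):
        super().__init__()
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        self.compress_k = nn.Conv1d(d_model, d_model, kernel_size=k, stride=k)
        self.compress_v = nn.Conv1d(d_model, d_model, kernel_size=k, stride=k)

    def forward(self, x):                                                # x: (B, n, d)
        Q = self.to_q(x)
        K = self.compress_k(self.to_k(x).transpose(1, 2)).transpose(1, 2)  # (B, n/k, d)
        V = self.compress_v(self.to_v(x).transpose(1, 2)).transpose(1, 2)  # (B, n/k, d)
        scores = Q @ K.transpose(1, 2) / Q.shape[-1] ** 0.5                 # (B, n, n/k)
        return scores.softmax(dim=-1) @ V                                   # (B, n, d)

print(MemoryCompressedAttention()(torch.randn(2, 64, 256)).shape)  # torch.Size([2, 64, 256])
```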
Image Transformer
Similar to CNNs, the field of self-attention is restricted to local neighbourhoods. This helps the model scale up to process large batch sizes while keeping the likelihood loss tractable. It uses an encoder-decoder architecture, where the encoder generates a contextualized representation for every pixel-channel in the input and the decoder autoregressively generates one channel per pixel at each time step.
Local Attention Span: The input is partitioned into query blocks and associated memory blocks. For all queries from the same query block, the model attends to the same memory block. There are two types of local attention for choosing the query blocks and their associated memory block neighbourhoods: (i) 1-d local attention, and (ii) 2-d local attention.
For 1-d local attention, the image is flattened in raster order and partitioned into non-overlapping query blocks $Q$ of length $l_q$; for each query block, a memory block $M$ is built from the same pixels as $Q$ as well as a fixed number of additional pixels $l_m$. (Shown in part (a) of the figure above.) In the case of 2-d local attention, the image is divided into multiple non-overlapping rectangular blocks of length $l_q = w_q \times h_q$. Here, the memory block extends the query block to the top, left, and right by $h_m$ and $w_m$ pixels, which means $l_m = (w_q \times h_q) + 2 \times (h_m \times w_m)$.
The shape of each attention matrix is $l_q \times m$, where $l_q$ is the chosen length of the query blocks and $m = l_q + l_m$ is the length of the memory blocks. Since the query blocks do not overlap, overall $n / l_q$ attention matrices are computed. The computation and memory complexity is $O(n \cdot m)$.
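A rough sketch of the 1-d local attention scheme, assuming a generic `attend(queries, memories)` function; block sizes and names are made up for illustration.

```python
import torch

def local_attention_1d(x, l_q, l_m, attend):
    """Sketch of 1-d local attention: the flattened sequence x (n, dim) is split
    into non-overlapping query blocks of length l_q; each block attends to a
    memory block made of the same l_q positions plus the l_m positions before it.
    `attend` is any function (queries, memories) -> outputs of the same length as queries."""
    n = x.shape[0]
    out = []
    for start in range(0, n, l_q):
        q = x[start:start + l_q]                      # query block
        m = x[max(0, start - l_m):start + l_q]        # memory block (up to l_m + l_q pixels)
        out.append(attend(q, m))
    return torch.cat(out, dim=0)

x = torch.randn(1024, 64)
y = local_attention_1d(x, l_q=128, l_m=128, attend=lambda q, m: q)  # placeholder attend
```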
Set Transformer
(ICML 2019) This work adapts the Transformer model for set-valued inputs. The input is a set of features, and the output is some function of the set that is invariant to the permutation/ordering of the input features. It leverages attention to capture the interactions between the elements of the input set. The following components are used:
- $\text{MAB}(X, Y) = \text{LayerNorm}(H + \text{rFF}(H))$
- where $H = \text{LayerNorm}(X + \text{MultiheadAttention}(X, Y, Y))$
- $\text{SAB}(X) = \text{MAB}(X, X)$
- $\text{ISAB}_m(X) = \text{MAB}(X, \text{MAB}(I_m, X))$
- $\text{PMA}_k(X) = \text{MAB}(S_k, \text{rFF}(X))$
$\text{rFF}$ is a parameterized feed-forward layer that operates on each row of the input matrix separately. The input is $X \in \mathbb{R}^{N \times d}$. $I_m \in \mathbb{R}^{m \times d}$ is a set of $m$ trainable $d$-dimensional inducing points (similar in spirit to the memory in the Memory Constrained Transformer). $S_k \in \mathbb{R}^{k \times d}$ is a set of $k$ trainable $d$-dimensional seed vectors. $\text{ISAB}$ and $\text{SAB}$ are permutation equivariant, i.e. if the input is permuted in some way, then the corresponding output is permuted in exactly the same way. Set Transformer encoder: $N$ layers of either SAB or ISAB ($N$ is usually set to 2). Set Transformer decoder: $\text{rFF}(\text{SAB}(\text{PMA}_k(X)))$.
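A compact sketch of MAB and ISAB following the equations above, built on PyTorch's generic multi-head attention; dimensions, head counts, and class names are illustrative rather than the authors' implementation.

```python
import torch
import torch.nn as nn

class MAB(nn.Module):
    """MAB(X, Y) = LayerNorm(H + rFF(H)), with H = LayerNorm(X + MHA(X, Y, Y))."""
    def __init__(self, d=128, heads=4):
        super().__init__()
        self.mha = nn.MultiheadAttention(d, heads, batch_first=True)
        self.rff = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d))
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)

    def forward(self, X, Y):
        H = self.ln1(X + self.mha(X, Y, Y)[0])
        return self.ln2(H + self.rff(H))

class ISAB(nn.Module):
    """ISAB_m(X) = MAB(X, MAB(I_m, X)) with m trainable inducing points."""
    def __init__(self, d=128, m=16, heads=4):
        super().__init__()
        self.I = nn.Parameter(torch.randn(1, m, d))          # inducing points I_m
        self.mab1, self.mab2 = MAB(d, heads), MAB(d, heads)

    def forward(self, X):                                    # X: (B, N, d)
        H = self.mab1(self.I.expand(X.size(0), -1, -1), X)   # (B, m, d)
        return self.mab2(X, H)                               # (B, N, d), cost O(N*m)

print(ISAB()(torch.randn(2, 500, 128)).shape)                # torch.Size([2, 500, 128])
```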
Sparse Transformer
Introduces sparse factorizations of the attention matrix, which reduce the time and memory complexity to $O(n\sqrt{n})$ (a reduction from the quadratic complexity of full attention). The dense attention matrix is sparsified by computing attention only on a sparse subset of $(q_i, k_j)$ pairs. Uses fixed attention patterns defined by strides and local neighbourhoods.
Local Attention Heads: Half of the heads are dedicated to local attention. Attention is computed only if the query and the key fall in the same block: $A_{ij} = Q_i K_j^T$ if $\lfloor i/N \rfloor = \lfloor j/N \rfloor$, and $A_{ij} = 0$ otherwise.
Strided Attention Pattern: The other half of the heads are dedicated to a fixed strided pattern (also referred to as strides): $A_{ij} = Q_i K_j^T$ if $(i - j) \bmod N = 0$, and $A_{ij} = 0$ otherwise.
It is important to note that the parameter count of this model remains unchanged, since the Q, K, and V projections from the original model are retained; only the attention computation is sparsified. The memory complexity of the attention layer is however reduced from $O(n^2)$ to $O(n\sqrt{n})$.
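For illustration, the two fixed patterns can be written as dense boolean masks (the actual model relies on block-sparse kernels and never materializes $n \times n$ matrices; `block` below plays the role of $N$ in the equations above):

```python
import torch

def sparse_transformer_masks(n, block):
    """Dense boolean masks for the two fixed patterns (illustration only)."""
    i = torch.arange(n).unsqueeze(1)          # query index
    j = torch.arange(n).unsqueeze(0)          # key index
    local = (i // block) == (j // block)      # same block   -> local heads
    strided = (i - j) % block == 0            # fixed stride -> strided heads
    return local, strided

local, strided = sparse_transformer_masks(n=16, block=4)
```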
Restriction: The implementation requires custom sparse GPU kernels for a specific block-sparse variant of matrix-matrix multiplication. It is even more difficult to implement on other hardware such as TPUs.
Axial Transformer
[Link](https://arxiv.org/pdf/1912.12180.pdf)
Instead of applying attention to a flattened input, multiple attention computations are applied, each along a single axis of the input tensor. Each attention step uses information from a single axis, and information from the other axes is kept independent. This significantly reduces the cost, as the length of a single axis is much smaller than the total number of elements.
For a $d$-dimensional tensor where $N = S^d$, there are $S^{d-1}$ sequences of length $S$, so the complexity is $O(S^{d-1} \times S^2) = O(N^{(d-1)/d} \times N^{2/d})$. Compared to the regular attention cost of $O(N^2)$, this is a saving of a factor of $O(N^{(d-1)/d})$.
Axial attention over axis $k$, denoted Attention$_k$($x$), can be implemented by transposing all axes except $k$ to the batch axis, computing the attention, and then undoing the transpose. This is an advantage because no custom kernel is required. In the case of a 2-D image, Attention$_1$ and Attention$_2$ are called column and row attention respectively.
Masked attention is the causal version, where component $i$ of MaskedAttention$_k$($x$) depends only on components $1, 2, \ldots, i$ of $x$ along axis $k$.
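A sketch of the transpose trick described above, where `attention_1d` stands in for any standard self-attention over (batch, length, dim); the identity function is used as a placeholder in the example.

```python
import torch

def axial_attention(x, attention_1d, axis):
    """x: (B, S_1, ..., S_d, dim). Run self-attention along one spatial axis by
    moving it next to the channel dim and folding every other axis into the batch."""
    x = x.movedim(axis, -2)                        # (..., S_axis, dim)
    shape = x.shape
    x = x.reshape(-1, shape[-2], shape[-1])        # fold the other axes into the batch
    x = attention_1d(x)                            # ordinary (B', S_axis, dim) attention
    return x.reshape(shape).movedim(-2, axis)      # undo the transpose

# Example on a 2-D "image" of shape (B, H, W, dim): column and row attention.
identity = lambda t: t                             # placeholder attention
img = torch.randn(2, 32, 32, 64)
cols = axial_attention(img, identity, axis=1)      # Attention_1 (columns)
rows = axial_attention(img, identity, axis=2)      # Attention_2 (rows)
```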
Longformer
Variant of the Sparse Transformer. The key contribution is the Dilated Sliding Window, which enables better long-range coverage without sacrificing sparsity. The model gradually increases the receptive field by introducing gaps in the attention pattern as the model goes deeper. Lower layers are dedicated to modeling local patterns, and upper layers to modeling global patterns. Longformer scales linearly with sequence length.
This paper also introduces Longformer-Encoder-Decoder (LED), discussed later in the writeup.
Global Attention is included by adopting a few pre-selected global tokens (e.g. the CLS token) that have access to all input tokens. This is symmetric, i.e. the global tokens attend to all tokens across the sequence, and all tokens across the sequence attend to the global tokens. The count of global tokens is independent of the sequence length $n$, and hence the complexity is still $O(n)$.
Sliding Window Attention is fixed-size window attention surrounding each token (specifically, if the window size is $w$, then each token attends to $w/2$ tokens on each side). The computational complexity of this is $O(n \times w)$, which scales linearly with the sequence length $n$. If there are $l$ layers, then the receptive field at the top ($l$-th) layer is $l \times w$ ($w$, $2w$, $3w$, and so on).
Dilated Sliding Window Attention is a modification of the sliding window in which the window has gaps of size equal to the dilation $d$. The dilation increases the coverage: for $l$ layers, window size $w$, and dilation $d$, the coverage is $l \times w \times d$, which can reach tens of thousands of tokens even for small values of $d$. Combined with multi-head attention, this leads to better coverage. Using different dilation configurations per head improves performance by allowing some heads without dilation to focus on the local context, and some heads with dilation to focus on the longer context.
*Linear Projections for Attention Computation:* In the original transformer model, the linear projections K, Q, and V are used to compute attention. In the Longformer model, two sets of linear projections are used to compute the sliding window and the global attention separately ($Q_s$, $K_s$, $V_s$ and $Q_g$, $K_g$, $V_g$); both sets are however initialized with the same values.
In regular transformers, the expensive operation is the matrix multiplication $QK^T$, because both Q and K have $n$ (sequence length) projections. In Longformer, the dilated sliding window attention computes only a fixed number of diagonals of $QK^T$. Its implementation, however, requires a form of banded matrix multiplication which is not supported in existing deep learning libraries like PyTorch and TensorFlow.
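For illustration only, the dilated sliding window can be expressed as a dense banded boolean mask; the actual Longformer implementation uses a custom banded-matrix-multiplication kernel and never builds the full $n \times n$ matrix.

```python
import torch

def sliding_window_mask(n, w, d=1):
    """Boolean n x n mask: token i attends to tokens j with |i - j| <= (w//2)*d
    and (i - j) divisible by the dilation d (d=1 gives the plain sliding window)."""
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    offset = i - j
    return (offset.abs() <= (w // 2) * d) & (offset % d == 0)

mask = sliding_window_mask(n=4096, w=512, d=2)      # dilated window with gaps
```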
Longformer outperforms the comparable Transformer-XL model, matches the performance of Sparse Transformer, and slightly underperforms recent models that have twice as many parameters.
Pretraining and Finetuning: Longformer is pretrained on a document corpus and finetuned on six tasks such as QA, classification, and coreference resolution. The resulting model can process sequences up to 4096 tokens long, which is 8 times longer than BERT. MLM is used as the pretraining objective.
Increasing the window size from the bottom to the top layer leads to the best performance; arranging the window sizes the reverse way leads to worse performance; and using a fixed window size (the average of the window sizes of the other configurations) leads to performance in between. Adding some dilation to two heads leads to some improvement compared with no dilation at all.
Longformer Encoder Decoder
LED is a Longformer variant that has both the encoder and the decoder transformer stacks. The encoder uses Longformer's efficient local+global attention instead of full self-attention. The decoder uses full self-attention over the entire encoded tokens and the previously decoded locations. LED parameters are initialized from BART. Results of LED are reported on the arXiv summarization dataset.
Extended Transformer Construction
ETC is a variant of the Sparse Transformer. It addresses two key challenges of the standard architecture: (i) scaling the input length, and (ii) encoding structured inputs (an underlying graph structure or a hierarchical structure among the input tokens). Combining global-local attention with relative position encodings and a Contrastive Predictive Coding (CPC) pretraining objective allows ETC to encode structured inputs. Additionally, ETC can be initialized from pre-existing pretrained standard BERT models.
CPC improves performance for tasks where the structure is important. It plays the role similar as MLM but at a sentence level granularity.
The input is divided into two sequences: the global input and the long input. The main components of ETC are: Relative Position Encoding: Replaces the absolute position encoding with a relative PE, which encodes the relative positions of input tokens with respect to each other. The input sequence $x = (x_1, x_2, \ldots, x_n)$ can be thought of as a fully connected, directed graph, where the label of the edge between tokens $i$ and $j$ is $l_{ij}$. For a maximum clipping distance $k$, the possible relative position labels are the $2k+1$ labels $l_{-k}, \ldots, l_k$. For input pairs with $j - i \geq k$ the label $l_k$ is used, and for $j - i \leq -k$ the label $l_{-k}$ is used. Each label then becomes a learnable vector $a^{K}_{l}$.
Global-Local Attention: ETC receives two input sequences: the global input $x^g = (x_1^g, \ldots, x_{n_g}^g)$ and the long input $x^l = (x_1^l, \ldots, x_{n_l}^l)$. The long input is the regular input, while the global input is much smaller and consists of auxiliary tokens ($n_g \ll n_l$). Attention is then split into four components: (i) g2g (global-to-global), (ii) g2l (global-to-long), (iii) l2g (long-to-global), and (iv) l2l (long-to-long), with l2l attention restricted to a fixed radius $r \ll n_l$.
Complexity is $O(n_g(n_g+n_l) + n_l(n_g+2r+1))$. If we assume that $n_g = O(2r+1)$, then attention is linear in the size of the long input.
Note that $r = 1$ and $n_g = 1$ gives the Star Transformer. Similarly, placing all the tokens in the global input and setting $n_l = 0$ gives the standard transformer. For flexibility of implementation, boolean attention matrices $M^{g2g}$, $M^{g2l}$, $M^{l2g}$, and $M^{l2l}$ are used, with zeroes for those pairs of tokens that should not attend to each other.
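A sketch of the four boolean masks under the radius-$r$ restriction described above (structure-dependent masking is omitted; the sizes are made up):

```python
import torch

def etc_attention_masks(n_g, n_l, r):
    """Boolean masks for ETC's four attention pieces: global tokens see everything;
    long tokens see all global tokens and only long tokens within a local radius r."""
    M_g2g = torch.ones(n_g, n_g, dtype=torch.bool)
    M_g2l = torch.ones(n_g, n_l, dtype=torch.bool)
    M_l2g = torch.ones(n_l, n_g, dtype=torch.bool)
    i = torch.arange(n_l).unsqueeze(1)
    j = torch.arange(n_l).unsqueeze(0)
    M_l2l = (i - j).abs() <= r                    # local radius on the long input
    return M_g2g, M_g2l, M_l2g, M_l2l

masks = etc_attention_masks(n_g=128, n_l=4096, r=84)
```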
Pretraining: Uses MLM and CPC. The goal of CPC is to predict subsequent inputs in latent space, i.e., to predict the internal hidden representations of blocks of tokens. This is adapted in ETC by using global input sentence summary tokens: given an input sequence containing $n$ sentences, all the tokens corresponding to a subset of sentences are masked (but the sentence summary tokens are left in the global input). The model is then trained to minimize the difference between the hidden representation of the global sentence summary tokens for the masked sentences and that of a global summary token that can see the unmasked sentence and nothing else. A Noise Contrastive Estimation (NCE) loss, as in the work of Oord et al. (2018), is used.
BigBird
(NeurIPS 2020) The focus of BigBird is on adding random sparse attention patterns to global-local attention, which reduces the quadratic complexity to linear. The authors also show under which conditions models like BigBird/ETC are universal approximators of sequence functions and are Turing complete, thereby preserving the properties of the quadratic full-attention model. BigBird is built on top of ETC. The model comprises global attention, sliding window attention, and random attention.
The generalized attention mechanism is described by a directed graph $D$ whose vertex set is $[n] = \{1, \ldots, n\}$. The set of arcs (directed edges) represents the set of inner products that the attention mechanism will consider. The problem of reducing the quadratic complexity of self-attention can now be seen as a graph sparsification problem. It is well known that random graphs are expanders and can approximate complete graphs in a number of different contexts, including in their spectral properties. The authors argue that a sparse random graph for the attention mechanism should satisfy two desiderata: a small average path length between nodes and a notion of locality.
After experimenting with the Erdős–Rényi and Watts–Strogatz models, they find that random block and local window attention was insufficient to capture all the context necessary to compete with the performance of BERT. The solution is to introduce global tokens: either by making some existing tokens 'global' or by including additional 'global' tokens such as CLS.
The final attention mechanism for BigBird (Fig. 1d in the paper) has all three of these properties: queries attend to random keys, each query attends to $w/2$ tokens to the left of its location and $w/2$ to the right of its location, and the model contains global tokens (the global tokens can be existing tokens or extra added tokens).
Neither ETC nor BigBird can be used to decode autoregressively, making them encoder-only models.
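A token-level sketch of how the three components could be combined into one boolean mask (BigBird actually operates blockwise and designates existing or extra tokens as global; all sizes below are illustrative):

```python
import torch

def bigbird_mask(n, w=6, n_global=2, n_random=3, seed=0):
    """Boolean n x n mask = sliding window (w/2 on each side) + global tokens
    (first n_global positions attend everywhere and are attended by all)
    + n_random random keys per query."""
    g = torch.Generator().manual_seed(seed)
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    mask = (i - j).abs() <= w // 2                          # sliding window
    mask[:n_global, :] = True                               # global rows
    mask[:, :n_global] = True                               # global columns
    rand_keys = torch.randint(0, n, (n, n_random), generator=g)
    mask[torch.arange(n).unsqueeze(1), rand_keys] = True    # random attention
    return mask

mask = bigbird_mask(n=1024)
```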
Routing Transformer
This is a content-based sparse attention mechanism that learns dynamic sparse attention patterns using online k-means clustering, reducing the complexity to $O(n^{1.5})$.
Q and K are projected into a routing matrix $R$ of dimension $n \times d$: $R = Q W_R + K W_R$, where $W_R$ is a $d \times d$ orthonormal matrix.
Then k-means clustering is applied to the matrix $R$, and each token's distance to the cluster centroids is computed. Since the cluster centroids are trainable parameters, this behaves similarly to an all-attention layer. The routing strategy is then defined as $X'_i = \sum_{j \in C_i} A_{ij} V_j$, which means that token $i$ only attends to other tokens in the same cluster $C_i$.
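A sketch of the routing step under the description above: project Q and K into R, assign each position to its nearest centroid, and restrict attention to positions in the same cluster (the online k-means centroid updates are omitted; all shapes are illustrative):

```python
import torch

def routing_assignments(Q, K, W_R, centroids):
    """R = Q W_R + K W_R; each position is routed to its nearest centroid.
    Positions then attend only to other positions in the same cluster."""
    R = Q @ W_R + K @ W_R                             # (n, d) routing matrix
    dists = torch.cdist(R, centroids)                 # (n, num_clusters)
    return dists.argmin(dim=-1)                       # cluster id per position

n, d, c = 256, 64, 16                                 # c ~ sqrt(n) clusters
Q, K = torch.randn(n, d), torch.randn(n, d)
W_R = torch.linalg.qr(torch.randn(d, d)).Q            # orthonormal projection
centroids = torch.randn(c, d)                          # trainable in the real model
clusters = routing_assignments(Q, K, W_R, centroids)
same_cluster = clusters.unsqueeze(1) == clusters.unsqueeze(0)   # n x n attention mask
```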
Reformer
(ICLR 2020) (Sparse attention framework) Reformer finds the nearest neighbours of the attention query (those input tokens that would result in the highest attention weights) using locality-sensitive hashing (LSH) and uses only those for attention.
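A sketch of the angular LSH bucketing step (random projections followed by an argmax over the concatenation of the projections and their negations), assuming shared query/key vectors as in Reformer; attention would then be restricted to (chunks of) each bucket.

```python
import torch

def lsh_buckets(x, n_buckets, seed=0):
    """Angular LSH: project onto random directions and take the argmax over
    [xR ; -xR]; vectors with high cosine similarity land in the same bucket."""
    g = torch.Generator().manual_seed(seed)
    d = x.shape[-1]
    R = torch.randn(d, n_buckets // 2, generator=g)     # random projection matrix
    xR = x @ R                                          # (n, n_buckets/2)
    return torch.cat([xR, -xR], dim=-1).argmax(dim=-1)  # bucket id per position

qk = torch.randn(1024, 64)                              # shared query/key vectors
buckets = lsh_buckets(qk, n_buckets=32)
# Attention is then computed only among positions with matching bucket ids
# (after sorting by bucket and chunking, in the actual Reformer).
```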
Linformer
Linformer is based on the observation that the self-attention mechanism (i.e. the attention matrix $P$) is low-rank.
The authors apply singular value decomposition to $P$ across different layers and different heads of the model, and plot the normalized cumulative singular values averaged over 10k sentences. The results exhibit a clear long-tail spectrum distribution across each layer, head, and task. This implies that most of the information in the matrix $P$ can be recovered from the first few largest singular values.
This idea is utilized in the Linformer model by projecting the $N \times d$ dimensional K and V to $k \times d$ dimensions using projection layers. This is a reduction of the length dimension, not of the key and value (feature) dimensions.
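A minimal sketch of this projection, assuming learned projection matrices $E$ and $F$ applied along the length dimension of K and V; names and sizes are illustrative.

```python
import torch
import torch.nn as nn

class LinformerAttention(nn.Module):
    """Sketch: project the length dimension of K and V from n down to k,
    so the attention matrix is n x k instead of n x n."""
    def __init__(self, d=64, n=4096, k=256):
        super().__init__()
        self.E = nn.Linear(n, k, bias=False)   # projects keys   along the length dimension
        self.F = nn.Linear(n, k, bias=False)   # projects values along the length dimension

    def forward(self, Q, K, V):                # each: (B, n, d)
        K = self.E(K.transpose(1, 2)).transpose(1, 2)                     # (B, k, d)
        V = self.F(V.transpose(1, 2)).transpose(1, 2)                     # (B, k, d)
        A = (Q @ K.transpose(1, 2) / Q.shape[-1] ** 0.5).softmax(dim=-1)  # (B, n, k)
        return A @ V                                                      # (B, n, d)

attn = LinformerAttention()
out = attn(*(torch.randn(2, 4096, 64) for _ in range(3)))
print(out.shape)                               # torch.Size([2, 4096, 64])
```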
Performer
Others briefly scanned
Summary image from the ETC paper:
Adaptive Span Transformer
In the Adaptive Span Transformer (sparse attention framework), each attention head is associated with a decaying learnable masking function, which limits the number of tokens it can attend to. The authors show that lower layers use short attention spans, and only in higher layers is the attention span longer.
Transformer-XL
Transformer-XL (recurrence based) divides the input into segments and processes one segment at a time. At each layer, the model attends to the layer immediately below for both the current and the previous input segment. This leads to layer $k$ being influenced by the current segment and the $k-1$ previous segments.
HIBERT
HIBERT (hierarchical mechanism) uses the idea of hierarchical attention: splitting the input into blocks, ingesting each block independently to produce a representative summary embedding for the whole block, and then using separate layers that work with the concatenation of the block representations.
BP-Transformer
BP-Transformer (compressive attention) builds a binary partitioning tree over the input, and only lets the model attend to the leaves (i.e. the raw tokens) for nearby tokens, and to higher nodes in the tree (summaries of groups of tokens) as tokens grow more distant.
Star Transformer
Star Transformer (Guo et al. 2019) (compressive attention): each token can attend only to its immediate left/right neighbours and to a separate auxiliary token that represents a summary of the whole input.
Compressive Transformer
The Compressive Transformer (Rae et al. 2019) integrates the idea of summary tokens into Transformer-XL by compressing tokens in the input that are far away. Attention is computed over nearby tokens and over summarized information for more distant tokens.