Alternative activation functions The original transformer uses the ReLU activation function. Other activation functions were developed. The Llama series and PaLM used SwiGLU; both GPT-1 and BERT used GELU. Alternative activation functions are often used in combination with Gated Linear Units in the feedforward module.
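As an illustration of a gated feedforward block, here is a minimal NumPy sketch of a SwiGLU-style module; the weight shapes and helper names (swiglu_ffn, W_gate, W_up, W_down) are illustrative assumptions rather than any particular model's implementation.
<syntaxhighlight lang="python">
import numpy as np

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, W_gate, W_up, W_down):
    """Gated feedforward block in the SwiGLU style.

    x:       (seq_len, d_model) input activations
    W_gate:  (d_model, d_ff)    projection passed through SiLU (the gate)
    W_up:    (d_model, d_ff)    linear branch
    W_down:  (d_ff, d_model)    projection back to the model dimension
    """
    gate = silu(x @ W_gate)      # nonlinear gate
    up = x @ W_up                # linear branch
    return (gate * up) @ W_down  # elementwise gating, then down-projection

# Toy usage with random weights (shapes only for illustration).
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
out = swiglu_ffn(x,
                 rng.normal(size=(8, 32)),
                 rng.normal(size=(8, 32)),
                 rng.normal(size=(32, 8)))
print(out.shape)  # (4, 8)
</syntaxhighlight>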
Alternative normalizations The normalization used in the transformer can be different from LayerNorm. One example is RMSNorm, which is used in the Llama series. Other examples include CapsuleNorm, ScaleNorm, and FixNorm.
Alternative positional encodings The original transformer paper reported using a learned positional encoding, but found it not superior to the sinusoidal one. Later work found that causal masking itself provides enough signal to a transformer decoder that it can learn to implicitly perform absolute positional encoding without the positional encoding module.
RoPE RoPE (rotary positional embedding) is best explained by considering a list of 2-dimensional vectors [(x^{(1)}_1, x^{(2)}_1), (x^{(1)}_2, x^{(2)}_2), (x^{(1)}_3, x^{(2)}_3), ...]. Now pick some angle \theta. Then RoPE encoding is \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) = \begin{pmatrix} \cos m \theta & - \sin m \theta \\ \sin m \theta & \cos m \theta \end{pmatrix} \begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \\ \end{pmatrix} = \begin{pmatrix} x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m \theta \\ x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m \theta \\ \end{pmatrix} Equivalently, if we write the 2-dimensional vectors as complex numbers z_m := x^{(1)}_m + i x^{(2)}_m, then RoPE encoding is just a rotation by the angle m\theta: \text{RoPE}\big(z_m, m\big) = e^{i m\theta} z_m For a list of 2n-dimensional vectors, a RoPE encoder is defined by a sequence of angles \theta^{(1)}, ..., \theta^{(n)}. Then the RoPE encoding is applied to each pair of coordinates. The benefit of RoPE is that the dot-product between two vectors depends on their relative location only: \text{RoPE}\big(x, m\big)^T\text{RoPE}\big(y, n\big) = \text{RoPE}\big(x, m+k\big)^T\text{RoPE}\big(y, n+k\big) for any integer k.
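A minimal NumPy sketch of the 2n-dimensional RoPE described above; the geometric schedule of angles with base 10000 is a common choice rather than part of the definition. The final check illustrates that dot products depend only on the relative offset.
<syntaxhighlight lang="python">
import numpy as np

def rope(x, m, thetas):
    """Apply RoPE to a vector x of even dimension 2n at position m.

    Coordinates are taken in pairs (x[2i], x[2i+1]); pair i is rotated
    by the angle m * thetas[i].
    """
    out = np.empty_like(x)
    for i, theta in enumerate(thetas):
        c, s = np.cos(m * theta), np.sin(m * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = b * c + a * s
    return out

d = 8  # must be even
# A common (assumed) schedule: theta_i = base^(-2i/d) with base = 10000.
thetas = 10000.0 ** (-np.arange(d // 2) * 2.0 / d)

rng = np.random.default_rng(0)
q, k = rng.normal(size=d), rng.normal(size=d)

# The dot product depends only on the relative offset (here 3 in both cases):
s1 = rope(q, 5, thetas) @ rope(k, 2, thetas)
s2 = rope(q, 105, thetas) @ rope(k, 102, thetas)
print(np.allclose(s1, s2))  # True
</syntaxhighlight>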
ALiBi ALiBi (Attention with Linear Biases) is not a
replacement for the positional encoder on the original transformer. Instead, it is an
additional positional encoder that is directly plugged into the attention mechanism. Specifically, the ALiBi attention mechanism is \begin{align} \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\mathrm{T}}{\sqrt{d_k}} + s B\right)V \end{align} Here, s is a real number ("scalar"), and B is the
linear bias matrix defined by B = \begin{pmatrix} 0 & 1 & 2 & 3 & \cdots \\ -1 & 0 & 1 & 2 & \cdots \\ -2 & -1 & 0 & 1 & \cdots \\ -3 & -2 & -1 & 0 & \cdots \\ \vdots & \vdots & \vdots & \vdots & \ddots \\ \end{pmatrix} in other words, B_{i, j} = j - i. The idea is that the linear bias matrix is a softened mask. Just as 0 represents full attention paid, and -\infty represents no attention paid, the linear bias matrix increases attention paid in one direction and decreases attention paid in the other direction. ALiBi allows pretraining on short context windows, then fine-tuning on longer context windows. Since it is directly plugged into the attention mechanism, it can be combined with any positional encoder that is plugged into the "bottom" of the entire network (which is where the sinusoidal encoder on the original transformer, as well as RoPE and many others, are located).
Relative Position Encodings Relative Position Encodings are similar to ALiBi, but more generic: \begin{align} \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\mathrm{T}}{\sqrt{d_k}} + B\right)V \end{align} where B is a
Toeplitz matrix, that is, B_{i, j} = B_{i', j'} whenever i-j = i'-j'. This is contrasted with the original sinusoidal positional encoding, which is an "absolute positional encoding".
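A small NumPy sketch of attention with an additive positional bias, under toy shapes: with B_{i, j} = j - i and a per-head slope s this is the ALiBi bias described above, while any (possibly learned) Toeplitz B gives a relative positional encoding.
<syntaxhighlight lang="python">
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def biased_attention(Q, K, V, B, s=1.0):
    """softmax(Q K^T / sqrt(d_k) + s*B) V with an additive positional bias B."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k) + s * B
    return softmax(scores) @ V

n, d_k = 6, 4
# ALiBi-style linear bias: B[i, j] = j - i (a Toeplitz matrix).
idx = np.arange(n)
B = idx[None, :] - idx[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(n, d_k)) for _ in range(3))
# With s > 0, attention to tokens far in the past (j << i) is decreased
# and attention in the other direction is increased: a "softened mask".
out = biased_attention(Q, K, V, B, s=0.5)
print(out.shape)  # (6, 4)
</syntaxhighlight>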
Efficient implementation The transformer model has been implemented in standard deep learning
frameworks such as
TensorFlow and
PyTorch.
Transformers is a library produced by
Hugging Face that supplies transformer-based architectures and pretrained models.
KV caching When an autoregressive transformer is used for inference, the key and value vectors of already-processed tokens do not change from step to step, so they can be cached and reused instead of being recomputed for every new token. If a transformer is used with a baked-in prompt, such as ["You are a customer support agent..."], then the key and value vectors can be computed for the prompt and saved on disk. The saving in compute is significant when the model is used for many short real-time interactions, such as in online chatbots. In general, when a user uses an autoregressive transformer to generate a continuation to a sequence of tokens, the model first performs a forward pass on this sequence, during which the KV caches over this sequence are computed. This is called prefilling.
Hyperscalers serving large Transformer models may use
disaggregated inference, wherein prefilling and decoding are performed on separate hardware specialized for each phase.
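A minimal single-head sketch of prefilling and KV caching, assuming NumPy and hypothetical projection matrices: the prompt's keys and values are computed once in a single forward pass (prefill), and each later decoding step appends only the new token's key and value to the cache.
<syntaxhighlight lang="python">
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

class SingleHeadKVCache:
    """Toy single-head attention layer with a KV cache (illustrative only)."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.K = None  # cached keys,   shape (t, d)
        self.V = None  # cached values, shape (t, d)

    def prefill(self, X):
        # One forward pass over the whole prompt: compute and store all K, V.
        self.K, self.V = X @ self.Wk, X @ self.Wv

    def decode_step(self, x_new):
        # Append only the new token's key/value; attend over the full cache.
        k, v = x_new @ self.Wk, x_new @ self.Wv
        self.K = np.vstack([self.K, k])
        self.V = np.vstack([self.V, v])
        q = x_new @ self.Wq
        scores = q @ self.K.T / np.sqrt(q.shape[-1])
        return softmax(scores) @ self.V

d = 8
rng = np.random.default_rng(0)
layer = SingleHeadKVCache(*(rng.normal(size=(d, d)) for _ in range(3)))
prompt = rng.normal(size=(10, d))   # stand-in for the embedded prompt tokens
layer.prefill(prompt)               # "prefilling": KV cache built in one pass
out = layer.decode_step(rng.normal(size=(1, d)))  # autoregressive step reuses the cache
print(out.shape)  # (1, 8)
</syntaxhighlight>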
FlashAttention FlashAttention is an algorithm that implements the transformer attention mechanism efficiently on a
GPU. It is a communication-avoiding algorithm that performs
matrix multiplications in blocks, such that each block fits within the
cache of a GPU, and by careful management of the blocks it minimizes data copying between levels of the GPU memory hierarchy (as data movement is slow). See the page on
softmax for details. An improved version, FlashAttention-2, was developed to cater to the rising demand for language models capable of handling longer context lengths. It offers enhancements in work partitioning and parallelism, enabling it to achieve up to 230 TFLOPs/s on
A100 GPUs (
FP16/
BF16), a 2x speed increase over the original FlashAttention. Key advancements in FlashAttention-2 include the reduction of non-matmul FLOPs, improved parallelism over the sequence length dimension, better work partitioning between GPU warps, and added support for head dimensions up to 256 and multi-query attention (MQA) and grouped-query attention (GQA). Benchmarks revealed FlashAttention-2 to be up to 2x faster than FlashAttention and up to 9x faster than a standard attention implementation in PyTorch. Future developments include optimization for new hardware like
H100 GPUs and new data types like
FP8. FlashAttention-4 focuses on
pipelining to increase instruction
throughput, and was developed to perform particularly well on
Blackwell GPUs.
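What follows is not the GPU kernel itself, but a NumPy sketch of the blockwise, numerically stable ("online") softmax that the blocked computation described at the start of this subsection relies on: keys and values are processed one block at a time with running maxima and normalizers, so the full N x N score matrix is never materialized. Block size and shapes are illustrative.
<syntaxhighlight lang="python">
import numpy as np

def blockwise_attention(Q, K, V, block=64):
    """softmax(Q K^T / sqrt(d_k)) V computed one K/V block at a time.

    Keeps a running maximum m and normalizer l per query row (the
    "online softmax" trick), so the full attention matrix is never stored.
    """
    n, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)
    m = np.full((n, 1), -np.inf)          # running max of scores per row
    l = np.zeros((n, 1))                  # running sum of exp(scores - m)
    acc = np.zeros((n, V.shape[1]))       # running weighted sum of values

    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T * scale              # scores for this block only
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        # Rescale previous accumulators to the new maximum, then add the block.
        correction = np.exp(m - m_new)
        P = np.exp(S - m_new)
        l = l * correction + P.sum(axis=1, keepdims=True)
        acc = acc * correction + P @ Vb
        m = m_new
    return acc / l

def reference_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(200, 16)) for _ in range(3))
print(np.allclose(blockwise_attention(Q, K, V, block=64),
                  reference_attention(Q, K, V)))  # True
</syntaxhighlight>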
Multi-Query Attention Multi-Query Attention changes the Multihead Attention mechanism. Whereas normally, \text{MultiheadAttention}(Q, K, V) = \text{Concat}_{i \in [n_{\text{heads}}]}\left(\text{Attention}(XW^Q_i, XW^K_i, XW^V_i)\right) W^O with Multi-Query Attention, there is just one W^K, W^V, thus: \text{MultiQueryAttention}(Q, K, V) = \text{Concat}_{i \in [n_{\text{heads}}]}\left(\text{Attention}(XW^Q_i, XW^K, XW^V)\right) W^O This has a neutral effect on model quality and training speed, but increases inference speed. More generally, grouped-query attention (GQA) partitions attention heads into groups, each of which shares a single key-value projection pair. MQA is GQA with one group, while standard Multihead Attention is GQA with the maximal number of groups.
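The following NumPy sketch (toy dimensions, per-head weight lists as an assumed layout) illustrates grouped-query attention at the shape level; setting n_groups = 1 gives multi-query attention, and n_groups = n_heads recovers standard multihead attention.
<syntaxhighlight lang="python">
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(X, Wq, Wk, Wv, Wo, n_heads, n_groups):
    """Grouped-query attention: n_heads query projections, but only
    n_groups key/value projections, shared within each group.

    Wq: list of n_heads  (d_model, d_head) matrices
    Wk, Wv: lists of n_groups (d_model, d_head) matrices
    Wo: (n_heads * d_head, d_model)
    """
    heads = []
    heads_per_group = n_heads // n_groups
    for i in range(n_heads):
        g = i // heads_per_group              # which K/V group this head uses
        Q = X @ Wq[i]
        K, V = X @ Wk[g], X @ Wv[g]
        A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
        heads.append(A @ V)
    return np.concatenate(heads, axis=-1) @ Wo

d_model, d_head, n, n_heads, n_groups = 16, 4, 6, 4, 2
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d_model))
Wq = [rng.normal(size=(d_model, d_head)) for _ in range(n_heads)]
Wk = [rng.normal(size=(d_model, d_head)) for _ in range(n_groups)]   # shared K per group
Wv = [rng.normal(size=(d_model, d_head)) for _ in range(n_groups)]   # shared V per group
Wo = rng.normal(size=(n_heads * d_head, d_model))
print(grouped_query_attention(X, Wq, Wk, Wv, Wo, n_heads, n_groups).shape)  # (6, 16)
# n_groups = 1 gives multi-query attention; n_groups = n_heads gives standard MHA.
</syntaxhighlight>
Because keys and values are shared within a group, the KV cache shrinks by a factor of n_heads / n_groups relative to standard multihead attention.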
Multihead Latent Attention (MLA) is a low-rank approximation to standard MHA. Specifically, each hidden vector, before entering the attention mechanism, is first projected to two low-dimensional spaces ("latent spaces"), one for the query and one for the key-value (KV) vector. This design minimizes the KV cache, as only the low-dimensional KV vector needs to be cached.
Speculative decoding Speculative decoding is a method to accelerate token decoding. Similarly to
speculative execution in CPUs, future tokens are computed quickly, then verified. If the quickly computed tokens are incorrect, they are discarded and computed slowly. The key factor in speculative decoding is that a transformer decoder can verify faster than it can decode, in the following sense. Suppose we have two transformer models like GPT-3 and GPT-3-small, both with a context window size of 512. To generate an entire context window autoregressively with greedy decoding with GPT-3, it must be run 512 times, each time generating a token x_1, x_2, ..., x_{512}, taking time 512 T_{\text{GPT-3}}. However, if we had some educated guess for the values of these tokens, we could verify all of them in parallel, in one run of the model, by checking that each x_t is indeed the token with the largest log-likelihood in the t-th output. In speculative decoding, a smaller model or some other simple heuristic is used to generate a few speculative tokens that are subsequently verified by the larger model. For example, suppose we use GPT-3-small to generate four speculative tokens: \tilde{x}_1, \tilde{x}_2, \tilde{x}_3, \tilde{x}_4. This only takes 4 T_{\text{GPT-3-small}}. These tokens are then run through the larger GPT-3 in one go. Suppose that \tilde{x}_1 and \tilde{x}_2 are verified by GPT-3 as what it would have picked; those are kept, but \tilde{x}_3 is not, so \tilde{x}_3, \tilde{x}_4 are discarded, and GPT-3 is run on those. This would take 4 T_{\text{GPT-3-small}} + 3 T_{\text{GPT-3}}, which might be shorter than 4 T_{\text{GPT-3}}. For non-greedy decoding, similar ideas apply, except the speculative tokens are accepted or rejected stochastically, in a way that guarantees the final output distribution is the same as if speculative decoding were not used.
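A minimal sketch of the greedy verification scheme described above, assuming hypothetical draft_model and target_model callables that return next-token log-probabilities for every prefix position (the usual causal-decoder output).
<syntaxhighlight lang="python">
import numpy as np

def greedy_speculative_step(prefix, draft_model, target_model, n_draft=4):
    """One round of greedy speculative decoding.

    draft_model(tokens) and target_model(tokens) are assumed to return an
    array of shape (len(tokens), vocab) with next-token log-probabilities
    for every prefix length. Returns the accepted continuation tokens.
    """
    # 1. Draft: cheaply propose n_draft tokens autoregressively.
    draft = list(prefix)
    for _ in range(n_draft):
        logits = draft_model(draft)
        draft.append(int(np.argmax(logits[-1])))
    proposed = draft[len(prefix):]

    # 2. Verify: one forward pass of the large model over prefix + proposals.
    logits = target_model(draft)
    accepted = []
    for t, tok in enumerate(proposed):
        target_choice = int(np.argmax(logits[len(prefix) + t - 1]))
        if target_choice == tok:
            accepted.append(tok)            # the large model agrees; keep it
        else:
            accepted.append(target_choice)  # first disagreement: take the target's token
            break                           # and discard the remaining drafts
    return accepted
</syntaxhighlight>
Each round costs n_draft cheap draft passes plus one expensive verification pass and yields between one and n_draft tokens, matching the timing argument above.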
In Multi-Token Prediction, a single forward pass creates a final embedding vector, which is then un-embedded into a token probability. However, that vector can then be further processed by another transformer block to predict the next token, and so on for arbitrarily many steps into the future. This trades off accuracy for speed, since each new token costs just one more transformer block, rather than the entire stack.
Sub-quadratic transformers Training transformer-based architectures can be expensive, especially for long inputs. Many methods have been developed to address the issue. In the image domain, the Swin Transformer is an efficient architecture that performs attention inside shifting windows. In the audio domain, SepTr decouples the attention in time and frequency domains.
Long Range Arena (2020) is a standard benchmark for comparing the behavior of transformer architectures over long inputs.
Alternative attention graphs The standard attention graph is either all-to-all or causal, both of which scale as O(N^2) where N is the number of tokens in a sequence. Reformer (2020) reduces the computational load from O(N^2) to O(N\ln N) by using locality-sensitive hashing and reversible layers. Sparse attention uses attention graphs that grow more slowly than O(N^2). For example, BigBird (2020) uses random small-world networks, which grow as O(N). Ordinary transformers require a memory size that is quadratic in the size of the context window. Attention-free transformers reduce this to a linear dependence while still retaining the advantages of a transformer by linking the key to the value.
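As an illustration, the following NumPy sketch builds a BigBird-style sparse attention mask from a sliding window, a few global tokens, and random links (all sizes are toy assumptions); the number of allowed pairs grows linearly in the sequence length.
<syntaxhighlight lang="python">
import numpy as np

def bigbird_style_mask(n, window=2, n_global=1, n_random=2, seed=0):
    """Boolean (n, n) mask: True where attention is allowed.

    Combines a local sliding window, a few global tokens that attend
    everywhere (and are attended to by everyone), and random extra links,
    so the number of True entries grows linearly in n.
    """
    rng = np.random.default_rng(seed)
    mask = np.zeros((n, n), dtype=bool)
    idx = np.arange(n)
    # Sliding window: each token attends to its neighbors.
    mask |= np.abs(idx[:, None] - idx[None, :]) <= window
    # Global tokens: the first n_global tokens attend to and from everything.
    mask[:n_global, :] = True
    mask[:, :n_global] = True
    # Random links: each token gets a few random extra connections.
    for i in range(n):
        mask[i, rng.choice(n, size=n_random, replace=False)] = True
    return mask

m = bigbird_style_mask(16)
print(m.sum(), "allowed pairs out of", m.size)  # far fewer than n*n as n grows
</syntaxhighlight>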
Random Feature Attention Random Feature Attention (2021) uses
Fourier random features: \varphi(x) = \frac{1}{\sqrt D}[\cos\langle w_1, x\rangle, \sin\langle w_1, x\rangle, \cdots, \cos\langle w_D, x\rangle, \sin\langle w_D, x\rangle]^T where w_1, ..., w_D are independent samples from the normal distribution N(0, \sigma^2 I). This choice of parameters satisfies \mathbb E[\langle \varphi(x), \varphi(y)\rangle] = e^{-\frac{\|x-y\|^2}{2\sigma^2}}, or e^{\langle x, y\rangle/\sigma^2} = \mathbb E[\langle e^{\|x\|^2/2\sigma^2} \varphi(x), e^{\|y\|^2/2\sigma^2}\varphi(y)\rangle] \approx \langle e^{\|x\|^2/2\sigma^2} \varphi(x), e^{\|y\|^2/2\sigma^2}\varphi(y)\rangle Consequently, the one-headed attention, with one query, can be written as \text{Attention}(q, K, V) = \text{softmax}\left(\frac{qK^\mathrm{T}}{\sqrt{d_k}}\right)V \approx \frac{\varphi(q)^T \sum_i e^{\|k_i\|^2/2\sigma^2}\varphi(k_i) v_i^T}{\varphi(q)^T \sum_i e^{\|k_i\|^2/2\sigma^2}\varphi(k_i)} where \sigma = d_k^{1/4}. Similarly for multiple queries, and for multihead attention. This approximation can be computed in linear time, as we can compute the matrix \varphi(k_i) v_i^T first, then multiply it with the query. In essence, we have managed to obtain a more precise version of \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\mathrm{T}}{\sqrt{d_k}}\right)V \approx Q(K^TV/\sqrt{d_k}) Performer (2022) uses the same Random Feature Attention, but w_1, ..., w_D are first independently sampled from the normal distribution N(0, \sigma^2 I), then they are
Gram–Schmidt processed.
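A NumPy sketch of this random-feature approximation for a single query. The feature count D is illustrative, and the random vectors are drawn here with covariance \sigma^{-2} I, the convention under which the Gaussian-kernel identity above holds; the e^{\|k_i\|^2/2\sigma^2} factors are folded in exactly as in the displayed formula.
<syntaxhighlight lang="python">
import numpy as np

def phi(x, W):
    """Random Fourier feature map; rows of W are drawn so that
    E[phi(x) . phi(y)] = exp(-||x - y||^2 / (2 sigma^2))."""
    proj = W @ x
    D = W.shape[0]
    return np.concatenate([np.cos(proj), np.sin(proj)]) / np.sqrt(D)

def rfa_attention(q, K, V, W, sigma):
    """Linear-time approximation of softmax(q K^T / sqrt(d_k)) V for one query."""
    scale = np.exp(np.sum(K**2, axis=1) / (2 * sigma**2))      # e^{||k_i||^2 / 2 sigma^2}
    feats = np.stack([phi(k, W) for k in K])                   # (n, 2D)
    num = phi(q, W) @ (feats * scale[:, None]).T @ V           # numerator, shape (d_v,)
    den = phi(q, W) @ (feats * scale[:, None]).T.sum(axis=1)   # denominator, scalar
    return num / den

def softmax_attention(q, K, V):
    s = K @ q / np.sqrt(K.shape[1])
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V

d_k, n, D = 16, 32, 2048
sigma = d_k ** 0.25
rng = np.random.default_rng(0)
q, K = 0.5 * rng.normal(size=d_k), 0.5 * rng.normal(size=(n, d_k))
V = rng.normal(size=(n, 8))
# Rows of W sampled from N(0, sigma^{-2} I): the convention under which the
# Gaussian-kernel identity used in the text holds.
W = rng.normal(scale=1.0 / sigma, size=(D, d_k))
print(rfa_attention(q, K, V, W, sigma))
print(softmax_attention(q, K, V))   # should approximately agree; closer as D grows
</syntaxhighlight>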
Multimodality Transformers can also be used or adapted for modalities (input or output) beyond just text, usually by finding a way to "tokenize" the modality. Multimodal models can either be trained from scratch or obtained by finetuning a pretrained model. A 2022 study found that transformers pretrained only on natural language can be finetuned on only 0.03% of parameters and become competitive with
LSTMs on a variety of logical and visual tasks, demonstrating
transfer learning. LLaVA is a vision-language model composed of a language model (Vicuna-13B) and a vision model (
ViT-L/14), connected by a linear layer. Only the linear layer is finetuned.
Vision transformers adapt the transformer to computer vision by breaking down input images into a series of patches, turning them into vectors, and treating them like tokens in a standard transformer. Whisper follows the same pattern for speech recognition, first turning the speech signal into a spectrogram, which is then treated like an image, i.e. broken down into a series of patches, turned into vectors, and treated like embedding vectors of tokens in a standard transformer.
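As an illustration of this patch-tokenization step, the following NumPy sketch (toy patch size and projection, both assumptions) cuts a 2-D array such as a spectrogram into patches and projects each one to an embedding vector.
<syntaxhighlight lang="python">
import numpy as np

def patchify(x, patch=4):
    """Split a 2-D array (H, W) into non-overlapping patch x patch blocks,
    flattened into rows of shape (n_patches, patch*patch)."""
    H, W = x.shape
    x = x[:H - H % patch, :W - W % patch]            # drop any ragged border
    blocks = x.reshape(H // patch, patch, W // patch, patch)
    return blocks.transpose(0, 2, 1, 3).reshape(-1, patch * patch)

d_model, patch = 16, 4
rng = np.random.default_rng(0)
spectrogram = rng.normal(size=(80, 100))              # e.g. mel bins x time frames
W_embed = rng.normal(size=(patch * patch, d_model))   # learned projection (random here)
tokens = patchify(spectrogram, patch) @ W_embed       # one embedding vector per patch
print(tokens.shape)  # (20 * 25, 16) = (500, 16)
</syntaxhighlight>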
Perceivers are a variant of transformers designed for multimodality. For image generation, notable architectures are
DALL-E 1 (2021), Parti (2022), Phenaki (2023), and Muse (2023). Unlike later models, DALL-E is not a
diffusion model. Instead, it uses a decoder-only transformer that autoregressively generates a text, followed by the token representation of an image, which is then converted by a
variational autoencoder to an image. Parti is an encoder–decoder transformer, where the encoder processes a text prompt, and the decoder generates a token representation of an image. Muse is an encoder-only transformer that is trained to predict masked image tokens from unmasked image tokens. During generation, all input tokens are masked, and the highest-confidence predictions are included for the next iteration, until all tokens are predicted. Phenaki is a text-to-video model. It is a bidirectional masked transformer conditioned on pre-computed text tokens. The generated tokens are then decoded to a video. == Applications ==