I don’t assume readers have any familiarity with state space models, but I do assume some familiarity with machine learning and mathematical notation.
If at any point you spot any errors, typos, or confusing wording, please let me know!
Mamba is a state space model (SSM) architecture that improves upon the S4 architecture. Sometimes known as S6, it makes two important modifications to S4:
Mamba parallelizes well during training, scales well with context length, performs inference efficiently, and most importantly, displays strong empirical results.
Sequence models can be placed on a spectrum based on their approach to information representation, from highly compressed (e.g. RNNs) to highly explicit (e.g. transformers).
Consider a vanilla RNN:
\[\begin{aligned}h_t &= \tanh(W_{hh}h_{t-1} + W_{xh}x_t)\\y_t &= W_{hy}h_t\end{aligned}\]The fixed size state \(h_{t-1}\) represents all prior context in a sequence at time \(t\). This underpins the core tradeoffs associated with RNNs:
On the other end of that spectrum, consider a decoder-only transformer model, à la GPT-3. In particular, let’s focus on its scaled dot-product self-attention layer:
\[\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]The contrasting tradeoffs of transformers and RNNs highlight the crux of sequence modeling research: how can we improve model quality within the constraints of available compute?
Recently, the industry has made rapid forward progress not from algorithmic breakthroughs but instead from dramatic increases in compute, due both to increased funding and to continual improvements in hardware.
Which is to say, scaling isn’t particularly clever, but oh boy is it effective^{3}.
Perhaps Rich Sutton said it best in The Bitter Lesson:
One thing that should be learned from the bitter lesson is the great power of general purpose methods, of methods that continue to scale with increased computation even as the available computation becomes very great.
Speculatively, RNNs and transformers have a limited lifespan because they make poor use of increasingly abundant compute. It’s critical that we design models that better leverage compute while also preserving fundamental model quality. Hence the interest in Mamba.
Mamba is based on S4, which is a linear time invariant (LTI) state space model (SSM), a common and useful subset of state space models more generally. For now, let’s focus on LTI SSMs, starting with their continuous form:
Ah yes, a chunk of \(\LaTeX\), how intuitive. Let’s break it down.
In the continuous case, the dynamics of how \(\mathbf{h}\) evolves over time are determined via the differential equation \(\mathbf{h}'(t) = \mathbf{A} \mathbf{h}(t) + \mathbf{B}\mathbf{x}(t)\). That is, the current value of \(\mathbf{h}\) itself determines how \(\mathbf{h}\) is changing at that moment in time.
The discrete case is similar, but because we cannot differentiate discrete functions, we notate this self-modifying recursive behavior with, well, a recurrence^{4}. We also use subscript notation instead of function notation to help emphasize this distinction:
\[\begin{aligned} \mathbf{h}_{t} &= \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \mathbf{\overline{B}}\mathbf{x}_t \\ \mathbf{y}_t &= \mathbf{\overline{C}}\mathbf{h}_t + \mathbf{\overline{D}}\mathbf{x}_t \end{aligned}\]Furthermore, you may now also notice some horizontal bars over our SSM parameters, which is indicative of discretized parameters:
Because SSMs’ most general formulation is continuous, when working with discrete data there is (usually) a discretization step, where the “discretized” parameters are annotated with an overhead line, and are dependent on a learned “step size” parameter \(\Delta\) and fixed choices for discretization functions \(f_\mathbf{A}\), \(f_\mathbf{B}\), \(f_\mathbf{C}\), and \(f_\mathbf{D}\).
\[\begin{aligned} \mathbf{\overline{A}} &= f_{\mathbf{A}}(\Delta, \mathbf{A}) \\ \mathbf{\overline{B}} &= f_{\mathbf{B}}(\Delta, \mathbf{A}, \mathbf{B}) \\ \mathbf{\overline{C}} &= f_{\mathbf{C}}(\Delta, \mathbf{C}) \\ \mathbf{\overline{D}} &= f_{\mathbf{D}}(\Delta, \mathbf{D}) \\ \end{aligned}\]There are several reasonable strategies for discretization, but an intuitive exposition here would be longer than I’d like for this post, so for now you’ll have to just trust me (or go read Albert Gu’s 330 page thesis).
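That said, one common strategy, the zero-order hold (ZOH), is simple enough to sketch for a diagonal \(\mathbf{A}\) (the structure S4D and Mamba use). The following is my own illustrative numpy sketch (the function name is mine), not any library’s exact code:

```python
import numpy as np

def zoh_discretize(delta, a_diag, b):
    """Zero-order hold for a diagonal A (given by its diagonal entries a_diag).

    Holding x constant over a step of size delta, the exact solution of
    h'(t) = A h(t) + B x(t) gives:
        A_bar = exp(delta * A)
        B_bar = A^{-1} (exp(delta * A) - I) B   (elementwise, since A is diagonal)
    """
    a_bar = np.exp(delta * a_diag)
    b_bar = (a_bar - 1.0) / a_diag * b
    return a_bar, b_bar
```

As a sanity check, for scalar \(a = -1\), \(b = 2\), \(\Delta = 0.1\), a single discrete step from \(\mathbf{h}_0 = 0\) matches numerically integrating the continuous ODE over \([0, 0.1]\) with the input held constant.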
Discretization fine print and personal musings
S4 (or rather, S4D) goes a step further and parametrizes \(\mathbf{A}\) as a diagonal matrix, initialized via very clever approximations of the HiPPO matrix^{7} used in the original S4 paper. Some high level pointers:
Earlier, we described LTI SSMs as being able to handle inputs with multiple channels. In particular, we parametrized the SSM matrices as \(\mathbf{A} \in \mathbb{R}^{\texttt{N}\times\texttt{N}}\), \(\mathbf{B} \in \mathbb{R}^{\texttt{N}\times\texttt{D}}\), \(\mathbf{C} \in \mathbb{R}^{\texttt{V}\times\texttt{N}}\), and \(\mathbf{D} \in \mathbb{R}^{\texttt{V}\times\texttt{D}}\), where
However for S4 (and Mamba), the SSM is parametrized like so:
In this case, the \(\mathbf{A} \in \mathbb{R}^{\mathtt{N} \times \mathtt{N}}, \mathbf{B} \in \mathbb{R}^{\mathtt{N}\times 1}, \mathbf{C} \in \mathbb{R}^{1\times\mathtt{N}}\) matrices can all be represented by \(\mathtt{N}\) numbers. To operate over an input sequence \(\mathbf{x}\) of batch size \(\mathtt{B}\) and length \(\mathtt{L}\) with \(\mathtt{D}\) channels, the SSM is applied independently to each channel.
Structure fine print and personal musings
To recap, we discussed
The only actual model architecture change Mamba makes is the removal of linear time invariance for \(\mathbf{B}\), \(\mathbf{C}\), and \(\Delta\). They are now functions of \(\mathbf{x}_t\), i.e. the parameters are “selective”.
\[\begin{aligned} \mathbf{h}_{t} &= \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t \\ \mathbf{y}_t &= \overline{\mathbf{C}(\mathbf{x}_t)}\mathbf{h}_t + \mathbf{\overline{D}}\mathbf{x}_t \end{aligned}\]Linear time invariance was critical to S4 because it provided the foundation for its efficient implementation. The sequential (slow) implementation is the obvious one: just use a for-loop, applying the parameters to each input one at a time.
The parallel (fast) implementation is more complex. In brief, because the updates from the SSM matrices are time-invariant, you can compute the outputs by constructing a particular kernel and then performing a full-width convolution. Since full-width convolutions can be computed quickly with the FFT trick, this ends up being a very efficient implementation.
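To make that concrete, here is a hedged single-channel numpy sketch (names mine, \(\mathbf{\overline{D}}\) skip term omitted): the kernel is \(\mathbf{\overline{K}} = (\mathbf{\overline{C}}\mathbf{\overline{B}}, \mathbf{\overline{C}}\mathbf{\overline{A}}\mathbf{\overline{B}}, \mathbf{\overline{C}}\mathbf{\overline{A}}^2\mathbf{\overline{B}}, \dots)\), built naively here just to show the equivalence; the real S4 computes this kernel far more cleverly.

```python
import numpy as np

def ssm_conv_outputs(A_bar, B_bar, C_bar, xs):
    """y_t = sum_{k<=t} C_bar @ A_bar^k @ B_bar * x_{t-k}, via FFT convolution.

    Single input/output channel (B_bar, C_bar are length-N vectors);
    the D (skip) term is omitted for brevity.
    """
    L = len(xs)
    # Naively build the kernel K_k = C_bar @ A_bar^k @ B_bar for k = 0..L-1.
    kernel, v = [], B_bar
    for _ in range(L):
        kernel.append(float(C_bar @ v))
        v = A_bar @ v
    kernel = np.array(kernel)
    # Causal (full-width) convolution via the FFT trick, truncated to length L.
    n = 2 * L
    return np.fft.irfft(np.fft.rfft(kernel, n) * np.fft.rfft(xs, n), n)[:L]
```

This produces the same outputs as unrolling the recurrence \(\mathbf{h}_t = \mathbf{\overline{A}}\mathbf{h}_{t-1} + \mathbf{\overline{B}}x_t\), \(y_t = \mathbf{\overline{C}}\mathbf{h}_t\) one step at a time, but each FFT is parallel-friendly.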
However, LTI caps the expressivity of the model, as every update to the state is handled identically, regardless of what needs to be updated, whether the current input is relevant or irrelevant, etc. With linear time variance, instead of learning fixed matrix parameters, Mamba learns functions which ingest the input and state at time \(t\) and then control the output dynamics.
This idea of not just naively applying the same simple state update for each input is also why gated RNNs (e.g. the LSTM and GRU) are much more effective than vanilla RNNs, and indeed there is a strong theoretical connection between gating heuristics and selectivity in Mamba.
Since Mamba is not a linear time-invariant SSM, we can no longer rely on a full-width convolution as the basis of a fast (i.e. parallel) implementation. We need other strategies; otherwise we are still stuck with slow training and limited practical utility.
The most critical techniques the authors employ are the Blelloch work-efficient^{9} parallel scan and hardware-aware memory management, which together facilitate the fast (in real world wall-clock time) training of Mamba.
Pedagogical fine print and personal musings
In most of computer science, we lean pretty hard on computation models that hand-wave away a lot of real-world performance characteristics in favor of “constant time”. The random access machine is one popular model, but far from the only one. In any case, these models are akin to idealized models in an introductory physics class: they are exceedingly useful, but they also lie.
The real world is messier, with different hierarchies of performance in compute, memory, and bandwidth, all thanks to the very real implications of squeezing transistors on a silicon wafer of limited size.
Most GPU programs follow the same recipe:
Loading is by far the slowest part of this process, at least for most neural network applications. The more you can minimize memory loading the better, even if you need to spend some extra compute to do so.^{11}
One of the Mamba authors, Tri Dao, is well-known for his work on FlashAttention, which introduced hardware-aware computation of self-attention, bringing memory requirements from quadratic to linear and also providing dramatic wall-clock time speedups.
The core realization with FlashAttention was that the size of the intermediate computations dwarfed the actual size of the inputs and outputs. I.e., the \(QK^T\) matrix has size \(O(\mathtt{L}^2)\), even though the inputs themselves are only \(O(\mathtt{L})\).
The same principle is applicable here. The intermediate computations (the actual state machine mechanics) are again larger than the inputs and outputs, i.e. \(O(\mathtt{BLDN}) > O(\mathtt{BLD} + \mathtt{DN})\). Thus a similar approach works well here, where computations are done in a blockwise fashion, maximizing the amount of computation that can occur in SRAM before needing to load to/from HBM.
If this still feels a bit hand-wavey, then please go read Horace He’s excellent blog post Making Deep Learning Go Brrrr From First Principles, which is perhaps the best technical blog post I’ve read in the past few years.
And if you’re feeling brave, then I’d encourage you to dive into the core of the authors’ reference implementation, i.e. selective_scan_fwd_kernel.cuh
. In particular, pay attention to anything with smem
or shared
.
Of course, a very memory efficient implementation isn’t particularly useful if you can’t compute it in parallel, thanks to how modern hardware accelerators are built. Hence the desire for parallelism.
To be clear, parallel computation of linear RNNs is not something the Mamba authors invented, nor is it even particularly recent^{12}. But it’s critical to an efficient implementation for Mamba, hence discussing it here. At a high level, there are two key pieces to gaining intuition here for Mamba:
This section will focus on the former.
Before discussing the parallel scan, let’s first think about something simpler, the parallel reduce.
Suppose I have a list with \(k\) elements and I want to perform a reduce operation over some binary associative operator \(\oplus\) where:
One possible naive implementation is just looping through the inputs in order. For example, with inputs [3, 1, 7, 0, 4, 1, 6, 3]
:
def naive_reduce(op, identity, inputs):
    result = identity
    for i in inputs:
        result = op(result, i)
    return result
>>> naive_reduce(op=lambda x, y: x + y, identity=0, inputs=[3, 1, 7, 0, 4, 1, 6, 3])
25
This has linear time complexity w.r.t. the number of inputs, which is as good as we can hope for without parallelism, but what if we had more workers? We could split our input into pairs, compute those sums, and then recursively use those results as the inputs to another reduce, forming a computation tree:
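A sketch of that tree-shaped reduce (each level’s pairs are independent, so with enough workers a level costs one parallel step, for \(O(\log k)\) steps overall):

```python
def tree_reduce(op, identity, inputs):
    level = list(inputs)
    if not level:
        return identity
    while len(level) > 1:
        if len(level) % 2:  # pad odd-length levels with the identity
            level.append(identity)
        # Every pair below is independent, so a whole level can run in parallel.
        level = [op(level[i], level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]
```

For example, `tree_reduce(lambda x, y: x + y, 0, [3, 1, 7, 0, 4, 1, 6, 3])` returns `25`, matching the naive loop.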
A scan aggregates results from a sequence of elements by applying the operator cumulatively. For example, given the same inputs, the scan output would be [3, 4, 11, 11, 15, 16, 22, 25]
.
Closely related is the prescan which just shifts outputs by one, starting with an identity for the given operator (e.g., zero for summation): [0, 3, 4, 11, 11, 15, 16, 22]
. The Blelloch algorithm technically computes a prescan, although that is sufficient for our use case since it’s easy to go to/from a prescan.
The first step to a fast scan computation is to compute a parallel reduction, as we described in the previous section. However, this time, we preserve intermediate computations.
In the down sweep, we will maintain the invariant that every node contains the sum of all prior leaf nodes, as determined by visit order in a pre-order traversal, e.g.:
def pre_order_traversal(node: Node) -> None:
    if not node:
        return
    visit(node)
    pre_order_traversal(node.left)
    pre_order_traversal(node.right)
If every node contains the sum of all prior leaf nodes, then the values of the leaves themselves will be the results of a prescan!
Stepping through the down sweep with a concrete example:
When there are no prior leaf nodes, we use the identity value, e.g. 0 for summation.
We now need to be careful about maintaining the invariant. Filling in the down-sweep level by level, for any particular node N
:
downsweep[N].left.value = downsweep[N].value
downsweep[N].right.value = downsweep[N].value + upsweep[N].left.value
Voila! For a sketch of the proof of why this works, consider a node in a preorder traversal.
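Putting the up sweep and down sweep together in array form (a sketch assuming a power-of-two input length, with strides standing in for tree levels):

```python
def blelloch_prescan(op, identity, inputs):
    n = len(inputs)
    assert n and n & (n - 1) == 0, "sketch assumes a power-of-two length"
    tree = list(inputs)
    # Up sweep: a parallel reduce that leaves partial sums in place.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):  # independent -> parallelizable
            tree[i] = op(tree[i - d], tree[i])
        d *= 2
    # Down sweep: the root has no prior leaves, so it starts at the identity.
    tree[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):  # independent -> parallelizable
            left = tree[i - d]
            tree[i - d] = tree[i]          # left child <- parent's value
            tree[i] = op(tree[i], left)    # right child <- parent (+) up-sweep left
        d //= 2
    return tree
```

For example, `blelloch_prescan(lambda x, y: x + y, 0, [3, 1, 7, 0, 4, 1, 6, 3])` returns `[0, 3, 4, 11, 11, 15, 16, 22]`, the prescan from before. Note the operand order in the down sweep (parent first, then the up-sweep left value), which is what keeps this correct for non-commutative operators.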
The above examples used “sum”, but the Blelloch parallel scan works for all binary associative operators. For Mamba, we can define an associative operator that, when given the proper inputs, will compute the prescan of state vectors in parallel!
In fact, what we’ll discuss here is contextualized by Mamba, but is actually valid for all first-order recurrences of the following form^{13}:
\[h_t= \begin{cases} b_0 & t = 0 \\ (a_t\otimes h_{t-1}) \oplus b_t & t>0 \\ \end{cases}\]Where \(\oplus\) and \(\otimes\) meet the following criteria:
We massage our inputs into the sequence of \(c_t \equiv[a_t, b_t] \text{ for } t = 1, 2, ... , \mathtt{L}\), where
\[\begin{aligned} a_t &= \mathbf{\overline{A}} \\ b_t &= \overline{\mathbf{B(x}_t)}\mathbf{x}_t\\ \end{aligned}\]And define a new operator \(\bullet\) as follows:
\[\begin{aligned} c_i \bullet c_j &\equiv [c_{j,a} c_{i,a}, \, c_{j,a} c_{i,b} + c_{j,b}] \\ &\equiv [a_j a_i, \, a_j b_i + b_j] \end{aligned}\]The “identity” for this operator is \(c_0 = \left[\mathbf{\overline{A}}, \, \mathbf{h}_0 \right]\), where \(\mathbf{h}_0\) is the initial state vector, before having seen any inputs. This is analogous to \(0\) for the \(\text{add}\) operator.
We then apply the operator in parallel with the Blelloch (pre)scan algorithm, and the outputs at the second index will be the desired \(\mathbf{h}_t\) results, computed in an efficient manner!
From the definition of the operator and how we set up \(a_t\) and \(b_t\), the second component is basically guaranteed to compute the proper state vectors. Recall the first line of the Mamba SSM:
\[\mathbf{h}_{t} = \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t\]A sketch of the high level intuition here is:
This, unfortunately, is simply a wall of algebra:
\[\begin{aligned} &\text{Apply the definition of } \bullet: \\ (c_i \bullet c_j) \bullet c_k &= [c_{j,a} \odot c_{i,a}, \; (c_{j,a} \otimes c_{i,b}) \oplus c_{j,b}] \bullet c_k \\ &\text{Apply the definition of } \bullet \text{ again:} \\ &= [ c_{k,a} \odot (c_{j,a} \odot c_{i,a}), \; (c_{k,a} \otimes ((c_{j,a} \otimes c_{i,b}) \oplus c_{j,b})) \oplus c_{k,b}] \\ &\text{Associativity of } \odot : \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; (c_{k,a} \otimes ((c_{j,a} \otimes c_{i,b}) \oplus c_{j,b})) \oplus c_{k,b}] \\ &\otimes \text{ distributes } c_{k,a} \text{ over } \oplus \text{:} \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; ((c_{k,a} \otimes (c_{j,a} \otimes c_{i,b})) \oplus (c_{k,a} \otimes c_{j,b})) \oplus c_{k,b}] \\ &\text{Associativity of } \oplus \text{:} \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; (c_{k,a} \otimes (c_{j,a} \otimes c_{i,b})) \oplus ((c_{k,a} \otimes c_{j,b}) \oplus c_{k,b})] \\ &\text{Semiassociativity of } \otimes : \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; ((c_{k,a} \odot c_{j,a}) \otimes c_{i,b}) \oplus ((c_{k,a} \otimes c_{j,b}) \oplus c_{k,b})] \\ &\text{Apply the operator definition:} \\ &= c_i \bullet [c_{k,a} \odot c_{j,a}, \; ( c_{k,a} \otimes c_{j,b} ) \oplus c_{k,b}] \\ &\text{Apply the operator definition again:} \\ (c_i \bullet c_j) \bullet c_k &= c_i \bullet (c_j \bullet c_k) \end{aligned}\]The above seems to make sense, but perhaps you prefer python
to \(\LaTeX\) (I wouldn’t blame you).
As a first sanity check, if you’ve gotten this far, now is perhaps a good time to start perusing some reference implementations. E.g., in the authors’ CUDA code, you can see the above operator implemented quite literally at selective_scan_common.h:113
.
But also, let’s write our own “unit tests” as an additional sanity check. FYI, this will leverage the jax.lax.associative_scan
implementation, which is a batteries-included implementation^{14} of Blelloch’s algorithm.
# Various imports
from einops import einsum
import jax
# Jax.lax already has a convenient parallel scan implementation.
import jax.lax as lax
import jax.numpy as jnp
key = jax.random.PRNGKey(seed=42)
B = 1 # batch size
L = 8192 # context length
N = 64 # hidden state size
D = 2 # num in channels
V = 1 # num out channels
# Gets the various fake x_t inputs.
def generate_random_xs(key, num_inputs=L, num_channels=D):
    key, subkey = jax.random.split(key)
    xs = jax.random.lognormal(subkey, shape=(num_inputs, num_channels))
    return key, xs
# Gets various fake A matrices. This is actually constant in the paper,
# but it doesn't have to be.
def generate_random_As(key, num_inputs=L, state_size=N):
    key, subkey = jax.random.split(key)
    As = jax.random.lognormal(subkey, shape=(num_inputs, state_size, state_size))
    return key, As
# Gets various fake B(x_t) matrices.
def generate_random_Bxs(key, num_inputs=L, state_size=N, num_channels=D):
    key, subkey = jax.random.split(key)
    Bxs = jax.random.lognormal(subkey, shape=(num_inputs, state_size, num_channels))
    return key, Bxs
# Gets the b_t term.
def get_bs(xs, Bxs):
    return einsum(Bxs, xs, "l n d, l d -> l n")
# Jax plays nicest with jnp.arrays, so we'll stuff the values inside a
# single array and just unpack things here. I suppose I could use PyTrees
# but please forgive a bit of laziness/hackiness on my part.
def extract(c, state_size):
    assert c.ndim == 1
    assert c.shape[0] == state_size * state_size + state_size
    return (
        c[:state_size * state_size].reshape((state_size, state_size)),
        c[-state_size:].reshape((state_size,))
    )
def operator(c_prev, c_curr, num_inputs=L, state_size=N, num_channels=D):
    prev_a, prev_b = extract(c_prev, state_size)
    curr_a, curr_b = extract(c_curr, state_size)
    return jnp.concatenate([
        jnp.ravel(curr_a @ prev_a),
        jnp.ravel(curr_a @ prev_b + curr_b)
    ])
vectorized_operator = jax.vmap(operator, in_axes=(0, 0), out_axes=0)
# Actually generate some fake test data.
key, xs = generate_random_xs(key)
key, Bxs = generate_random_Bxs(key)
key, As = generate_random_As(key)
bs = get_bs(xs, Bxs)
cs = jnp.concatenate([As.reshape(-1, N * N), bs], axis=1)
# %%timeit results on a freebie Google Colab VM:
# 283 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lax_scanned = lax.associative_scan(vectorized_operator, cs)[:, -N:]
def naive_scan_hs(h_0, As, Bxs, xs):
    output = [h_0]
    for a, bx, x in zip(As, Bxs, xs):
        b = einsum(bx, x, "n d, d -> n")
        output.append(a @ output[-1] + b)
    return output[1:]
# %%timeit results on a freebie Google Colab VM:
# 3.34 s ± 313 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
naive_hs = jnp.vstack(
    naive_scan_hs(jnp.zeros((N,)), As, Bxs, xs)
)
# The following returns Array(True, dtype=bool)! Which means that
# we're getting identical results, (allowing for some floating point
# imprecision), regardless of if we're using the naive iterative
# implementation, or the fast parallel implementation.
jnp.allclose(naive_hs, lax_scanned)
Hooray, we’re getting identical results as well as a much faster wall time. These trends only get more extreme as you leverage more capable hardware (the above results were on a freebie Google Colaboratory VM, i.e. not hardware accelerated and using only two CPU cores).
Whew, that was a lot. Let’s do a quick recap of what we’ve covered:
In particular, we paid extra attention to
There are a bunch of other important details that I’ve glossed over here. Conceptually, I believe the selective SSM formulation and the scan implementation are the most important contributions, but a thorough discussion would also cover
And of course, I provided a bit of code to sanity check the discussion around parallel computation of the state vectors, but this is not a full implementation 🙂. Perhaps I’ll get around to writing a more thorough treatise in the future, but for now, I hope you’ve found this interesting!
Mamba is one of the most impressive papers I’ve read in years. It’s particularly impressive because the core concepts (selectivity and efficient implementation) are simple (where simple ≠ easy) yet effective.
In fact, throughout the paper there are not many particular insights that are demanding of galaxy brain intellect. For example:
But on the other hand, combining a useful knowledge of all of the above and distilling that into a cohesive and novel body of research? Super cool. I suppose that’s why both authors are younger than me but also tenure-track professors at top institutions.
I.e. post-training, the parameters involved in the updates are always the same. ↩
There are some tricks like KV caching which help here, but they unfortunately do not change asymptotic behavior. ↩
Not to say that scaling is easy either, because HPC is a fiendishly difficult thing to get right, it’s just a very simple concept to understand: use moar computer ↩
At the risk of offending proper mathematicians, you can kinda think of a first order differential equation as an infinitesimally precise recurrence. ↩
You should take any of my opinions with a large grain of salt. I, for one, have not been awarded any ICLR Outstanding Paper Awards 🙂 ↩
IMHO, this is where most people get stuck when trying to understand S4. I mean, how often does the typical computer scientist encounter Legendre polynomials? ↩
It is technically possible but quite silly to use a value smaller than \(\mathtt{D}\). Remember that the state itself is only \(\mathtt{N}\) dimensional, i.e. \(\mathbf{h}_t \in \mathbb{R}^\mathtt{N}\). It would be very challenging for a model to compress multiple \(\mathbb{R}^{\mathtt{D}}\) inputs into a single \(\mathbb{R}^{\mathtt{N}}\) if \(\mathtt{N} < \mathtt{D}\). ↩
There are other parallel scan implementations, e.g. the Hillis-Steele scan. In a nutshell, the Blelloch scan is more “work-efficient” yet less “step-efficient”. At a high level, when the amount of work to do exceeds the amount of available parallelism, the Blelloch implementation is more efficient. This matches the typical situation in most ML problems, where even a lot of parallel compute is dwarfed by the scale of data/input demands. ↩
This is where the GPU kernel comes in, as in “kernel fusion”, a scary name for a simple thing. Also, is there a more overloaded word in STEM other than kernel? Kernel trick, GPU kernel, Linux kernel, convolution kernel, etc. Enough to warrant a fairly long Kernel (disambiguation) Wikipedia article. Maybe entropy or set are more overloaded… ↩
I mean, to an extent; the KV-cache is still a thing. ↩
Guy Blelloch published his work in 1988 and 1990, building upon work by Hillis and Steele from 1986. In doing so, he explicitly highlighted the capability of the parallel scan to efficiently compute certain first-order (and higher-order) recurrences.
To my knowledge, Eric Martin and Chris Cundy were the first to publish work connecting the Blelloch parallel scan algorithm to linear RNNs in early 2018. The publication was successful by usual metrics, appearing in ICLR and garnering 43 citations as of early 2024, but it was perhaps overshadowed at the time by the waning popularity of recurrent architectures following the then-recent introduction of the transformer architecture. ↩
Note that this is slightly different than what you will find in the OG Blelloch paper, since the first-order recurrence referred to in that paper swaps the position of \(h_t\) and \(a_t\). All following proofs have been modified accordingly. Rest assured, they are the same in spirit, with only some very minor algebra changes. ↩
I’m a bit curious about whether the authors considered using JAX, its lax
module and corresponding associative_scan
implementation with Pallas whenever necessary. I suppose the implementation of the Blelloch scan is not actually that many lines of code, so perhaps it just comes down to preference. ↩
I.e., do as much as possible in SRAM since writing to/from SRAM/HBM is slow. ↩
The Perceiver family of models from DeepMind decouples context length from memory and compute requirements. Perceiver AR extends this with support for autoregressive generation. It also has a refreshingly simple implementation, since at its core it is just a small variation on top of an otherwise standard decoder-only transformer.
I’ve provided a lightweight implementation here and provide additional context in this post.
Let’s set some consistent notation:
For example, these inputs could be
Transformers are increasingly powerful and expressive with increasing training data and parameter count, as demonstrated by large transformers like GPT-4 and PaLM. However, transformers scale poorly w.r.t. both compute and memory as context size increases.
The most important operation in a transformer model is the attention operation, hence the seminal paper’s title being Attention Is All You Need.
If you’re unfamiliar with attention, then there are an inordinate number of intuitive explanations out there, but here’s my abbreviated take that highlights the scaling: at its core, self-attention tries to determine, for each input token, how it relates to each other input token. This is basically a doubly nested for-loop:
def self_attention(some_inputs):
    scores = [
        [0 for _ in range(len(some_inputs))]
        for _ in range(len(some_inputs))
    ]
    for r, query_tok in enumerate(some_inputs):
        for c, key_tok in enumerate(some_inputs):
            score_qk = relation(get_query_vector(query_tok), get_key_vector(key_tok))
            scores[r][c] = score_qk
    ...  # Normalize, combine with value vectors, etc
This is obviously quadratic.
If you’re familiar with basic matrix math, you’ll observe that you can rewrite this more compactly with some matrix operations. I.e. given \(Q, K, V \in \mathbb{R}^{M \times C}\), the attention operation is:
\[\text{softmax} \left( \frac { QK^T }{\sqrt C} \right) V\]Here you can see that attention scales quadratically via observing that \(QK^T \in \mathbb{R}^{M \times M}\).
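The same computation in vectorized numpy (shapes annotated; the function name is mine):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Q, K, V: (M, C). The (M, M) score matrix is where the quadratic cost lives."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])         # (M, M)
    scores -= scores.max(axis=-1, keepdims=True)    # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V                              # (M, C)
```

Note that no matter how you arrange the matrix multiplies, the \((M, M)\) scores matrix must be materialized (or at least computed blockwise).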
A standard transformer has \(L\) blocks, which means both self-attention and subsequent MLP are run \(L\) times, which brings overall complexity to \(O(M^2L)\) and \(O(ML)\) for the attention and MLP operations respectively.
Suppose by some miracle we are able to reduce self-attention’s complexity down to \(O(M)\), perhaps via some clever approximations. Throwing the math under the rug of Big-O notation, this would imply that a transformer model scales with \(O(ML)\). This is of course better than \(O(M^2L)\), but for large M (e.g. long context models), this is still rather poor scaling.
This incentivizes model designers to come up with clever and bespoke ways of reducing context length (e.g. the ViT, which famously converts an image into patches first). These are effective, but it’s not always clear how to handle other domains, and it’s tremendously easy to get input sizes that are in the deep hundreds of thousands, if not larger. For example, 10 seconds of time-domain audio (\(10s \times 44.1 \text{kHz} = 441 \text{ thousand}\)) or 10 seconds of low resolution (e.g. 240p) grayscale video without any audio (\(\text{240 px high} \times \text{362 px wide} \times 24 \text{Hz} \times 10s \approx 20 \text{ million}\)).
Therefore, there’s a desire to not only handle the quadratic nature of self-attention, but also to further decouple context length from computation (past processing the inputs, of course).
The crux of the Perceiver architecture is to use cross-attention for the first block. The only difference between cross-attention and self-attention is that cross-attention uses query vectors that attend to key and value vectors that may come from different underlying inputs^{1}:
\[Q \in \mathbb{R} ^ {N \times C}; \; K, V \in \mathbb{R} ^{M \times C}\]Consider \(N\) to be a hyperparameter, selected such that \(N \ll M\). Now \(QK^T \in \mathbb{R}^{N \times M}\); as in, the first attention layer is now an \(O(NM)\) operation! And, since the output latents are of length \(N\), each subsequent self-attention is \(O(N^2)\), and each MLP layer is also only \(O(N)\) now.
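Shape-wise, the change is tiny; a numpy sketch differing from self-attention only in \(Q\) (again, names mine):

```python
import numpy as np

def cross_attention(Q, K, V):
    """Q: (N, C); K, V: (M, C). Scores are (N, M), so cost is O(NM), not O(M^2)."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])         # (N, M)
    scores -= scores.max(axis=-1, keepdims=True)    # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V                              # (N, C): a length-N latent
```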
This is all well and good, but it depends on our ability to actually get reasonable queries.
For the initial query, both the original Perceiver and Perceiver IO used a fixed latent, analogous to the initial hidden state in an RNN. Perceiver AR uses an arguably simpler approach, where the initial query just comes from the last \(N\) inputs. This way, you can still apply causal masking for autoregressive generation!
Since \(N\) is now a hyperparameter choice, if one uses \(N=M\), then this is actually equivalent to a standard decoder-only transformer, which I find particularly elegant.
Conceptually this makes sense, so what do we need to actually modify? Well, as you can also see from the above diagram, there’s not too much to change. You need to
x = torch.stack([data[i : i + block_size] for i in ix])
y = torch.stack(
    [data[i + block_size + 1 - query_size : i + block_size + 1] for i in ix]
)
For a standard transformer, query_size
is identical to block_size
.
def forward(self, inputs_q, inputs_kv):
    ...
    # Causal masking.
    mask = torch.tril(
        torch.ones(q_time, kv_time, device=attention.device),
        diagonal=kv_time - q_time,
    )
    attention = attention.masked_fill(~mask.bool(), float("-inf"))
For a standard transformer, the diagonal
argument is just 0
, the default value.
def forward(self, x: torch.Tensor):
    inputs_q, inputs_kv = x[:, -self.query_size :, :], x
    normed_q, normed_kv = self.ln1(inputs_q), self.ln1(inputs_kv)
    x = inputs_q + self.attn(inputs_q=normed_q, inputs_kv=normed_kv)
    x = x + self.mlp(self.ln2(x))
    return x
For a standard transformer, the main Block
’s forward
function only passes the input to attention, since it’s doing self-attention, but since we’re now doing cross-attention on inputs, we need to handle inputs_q
and inputs_kv
separately.
The attention operations in the middle of the transformer are self-attentions over the \(\mathbb{R} ^ {N \times C}\) latent space, so for these layers inputs_q == inputs_kv
, since taking the last self.query_size
elements of the latent just yields the full latent anyways.
And… that’s about it. Feel free to see the full repo here.^{2} The repo takes inspiration (and code) from Karpathy’s NanoGPT repo, and reuses the core logic around a simple training script over a single plain text file Shakespeare “corpus”. I encourage you to mess around with some of the parameters in the train.py
script. I was able to train surprisingly long context models, caveated that I was running locally on a 2019 Macbook, but of course you’re free to use a proper accelerator of your choice.
If you want to see a more fully fleshed out implementation, then you’ll be pleased to know that DeepMind has actually open sourced a repo here. It’s a research codebase so it definitely was not designed for pedagogical friendliness, but it is supremely flexible^{3}.
The original Perceiver paper goes further and uses \(Q \in \mathbb{R} ^ {N \times D}\), which requires further projections. Fun fact, the original implementation refers to this projection as conv_1d
even though it’s just a Linear
. I’m not exactly sure why they’d call it this; perhaps due to other patterns in computer vision where a 1x1 convolution has been used to reduce channel size? ↩
This repo doesn’t experiment with any more sophisticated optimizers, nonlinearities, or further-optimized attention chunking/computation. This repo also assumes you are not meddling with channel dimensions, e.g. not projecting into larger or smaller channel sizes. ↩
It’s also implemented in JAX, which is a pro or con depending on whether or not you’re a Googler (jk, kind of 🙂). ↩
My favorite research findings are ones that make me go “Why didn’t I think of that?” due to a paradoxical combination of simplicity and cleverness. The OpenAI paper “Efficient Training of Language Models to Fill in the Middle” (FIM) is the most recent thing I’ve read that makes me feel this way.
For the uninitiated, modern language model (pre)training involves surprisingly few steps:
I’m being a bit facetious here because each step is complex, and also because this doesn’t discuss things like fine-tuning, inference, etc. But for pretraining, this is a reasonable approximation.
Some research focuses on step 1, e.g. how to construct the most useful training dataset, given the expected tradeoffs between quality and quantity. An enormous amount of research focuses on step 2, e.g. how to architect your model to improve its performance in any number of ways (memory efficiency, modeling distant relationships, etc). And some research focuses on step 3, e.g. how to set up your task. BART has a denoising approach (input text is corrupted), decoder-only transformers tend to do next token prediction (i.e., all tokens are visible except the last one), etc.
Efficient Training of Language Models to Fill in the Middle tackles step 3. It maintains the standard “next token” autoregressive task of most decoder-only transformers, but with the simple twist that some fraction of those tokens have their order modified.
For simplicity, we’ll also assume that we’re doing whitespace/punctuation tokenization.
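Such a toy whitespace/punctuation tokenizer might look like the following (purely illustrative, not what the paper actually uses):

```python
import re

def tokenize(text):
    # Split into word tokens, keeping each punctuation mark as its own token.
    return re.findall(r"\w+|[^\w\s]", text)
```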
Now consider the sentence “What I cannot create, I do not understand”. Under the typical autoregressive training setup, your input and target would look something like
Input: ["<bos>", "What", "I", "cannot", "create", ",", "I", "do", "not", "understand"]
Target: ["What", "I", "cannot", "create", ",", "I", "do", "not", "understand", "<eos>"]
Where <bos> and <eos> are special tokens representing the beginning and end of an input sequence. For FIM, the model maintains the same mechanics around training and optimization, but the data is transformed a bit: it’s chunked into a prefix, middle, and suffix, represented with corresponding special tokens:
Input: ["<pre>", "What", "I", "<suf>", "do", "not", "understand", "<mid>", "cannot", "create", ",", "I"]
Target: ["What", "I", "<suf>", "do", "not", "understand", "<mid>", "cannot", "create", ",", "I", "<eot>"]
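The transformation above can be sketched as follows; the sentinel names and the uniform span choice are simplifying assumptions (the paper describes the exact scheme):

```python
import random

PRE, SUF, MID, EOT = "<pre>", "<suf>", "<mid>", "<eot>"

def to_fim(tokens, rng=random):
    """Reorder tokens as prefix, suffix, middle; targets are inputs shifted by one."""
    # Pick two cut points, splitting the document into three spans.
    i, j = sorted(rng.sample(range(len(tokens) + 1), 2))
    prefix, middle, suffix = tokens[:i], tokens[i:j], tokens[j:]
    seq = [PRE, *prefix, SUF, *suffix, MID, *middle, EOT]
    return seq[:-1], seq[1:]  # (input, target)
```

Crucially, the training objective is untouched: it’s still next-token prediction, just over reordered sequences.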
This way, during inference time, you can prompt a model to fill in the middle very naturally by just providing context up to <mid>. And… that’s all folks.
Well, not exactly. There are other details and context which, of course, you can find in the paper. But the crux of the paper really is that simple!
Chatbots are a natural extension to autoregressive models, and relatively convenient for the engineers implementing them. But not every product or task is most naturally represented as an extension of “predict the most probable next word”. Compared to an off-the-shelf autoregressively pretrained LM, filling in the middle seems like a more intuitive way to generate text for certain types of products, e.g.:^{1}
Admittedly, there is a grab bag of overlapping techniques to address the problem of alignment. That is, there is ongoing research in turning something that predicts the most probable next token into something that is useful. These techniques may include standard supervised fine-tuning, instruction tuning, RLHF^{2}, or simply prompt engineering.^{3}
But there is no reason these techniques wouldn’t work with a FIM pretrained model. To top it off, FIM also seems to grant these new middle-filling capabilities without harming overall model performance, making it just about the closest thing I’ve seen to a free lunch in a little while.
To be fair, it seems that filling in the middle may be a fundamentally more challenging problem than left-to-right (i.e. typical) sampling, perhaps because generated text needs to flow naturally with both a prefix and suffix, as opposed to just a prefix. ↩
RLHF comes with its own grab bag of challenges, such as detecting/preventing over-optimization of the reward model, the high VRAM requirements of actor-critic policy gradient methods, and the overarching sample inefficiency of basically all RL techniques. ↩
“Prompt engineering” always felt like kind of a bizarre phrase to me. “Engineering” ideally implies some sort of systematic understanding and predictability, but prompt engineering seems more like an art at best and voodoo magic at worst. ↩
Perhaps it’s a new approach that will finally™ solve the problem of quadratic scaling of transformers w.r.t. context length, be it via clever tweaks inspired by convolutions, literally using convolutions, more clever utilization of accelerators, or various memory bottlenecks.^{1} Perhaps it’s any number of new models that have been fine-tuned by hobbyists, perhaps using leaked LLaMA weights or ChatGPT/ShareGPT data.^{2}
But there is another thing that hasn’t gotten as much mainstream attention. That is, just how easy it has become to experiment with some seriously advanced models, models that would have quite recently been state of the art and required non-trivial capital to train. Of course, researchers have been publishing models and code for a while now, but the current state of affairs with easy to use APIs, reasonably good documentation, and emphasis placed on community interaction and contribution? That feels rather new.^{3}
As an example, I wanted to walk through a small language model that I trained for my own amusement.
I wanted to train a model that could talk a bit more informally, perhaps a bit more like how I text with friends. It’s no secret that LLMs tend to output text that errs on the side of formal/verbose. My suspicion is that this is due to some combination of
My hope (inspired by this paper) was that it would take a relatively minimal amount of fine tuning to get a language model to chat more like me. That is, talk a bit more informally, use different punctuation (e.g. newlines instead of periods), etc.
Getting training data for this was fairly simple. I’ve been using Facebook Messenger for a long time now, and Facebook provides a convenient way to download all of your data as a bunch of JSONs. Then it’s just a bit of Python to parse the messages, yielding the ones where I’m responding. That generator can then be used directly to create a Dataset, via Hugging Face’s datasets API; specifically from_generator.
To be precise, the training data was in the format of “prompts” and “responses”, where prompts were contiguous blocks of messages from anybody that wasn’t me, joined with a pipe char (|). Responses were of the same format, just comprising messages that I sent. I used a pipe char to avoid any preprocessing shenanigans regarding newlines that some tokenizers may perform (e.g., transforming newlines into spaces), and chose a character that wouldn’t normally come up in texting, since substituting other punctuation (e.g. a period) can change perceived tone.^{5}
On Hugging Face, there are many language models appropriate for conversational interaction. I chose to run things using a 400M-parameter distilled BlenderBot, which uses a standard seq2seq (i.e. encoder-decoder) transformer. The paper came out in 2020, which is comparatively old, but these models are convenient since they’ve already been fine-tuned on conversational prompts. In particular, the 400M-parameter model isn’t so large that one needs to start thinking about model parallelism yet and has the added benefit of knowledge distillation from the bigger versions. I.e., it ought to be fine for some quick/fun hacking.
I used a bone-stock PyTorch Lightning training loop, which is almost a one-liner. Of course one can choose to implement their own checkpointing, looping over epochs, etc., but why bother reinventing the wheel, especially for a one-off just-for-fun training run?
As far as compute goes, my 2019-era MacBook is woefully underpowered, but Colab Pro is cheap and good enough here. An instance with 50GB RAM and a 16GB-VRAM GPU, albeit an old one, was plenty for my purpose, and cost roughly $0.20/hr.^{6}
I didn’t want to fine-tune the entire model, since that would’ve taken a while with the admittedly slower NVIDIA T4 that I was using. However, the peft library makes it surprisingly easy to leverage SOTA fine-tuning methods. For me, using AdaLoRA, an improvement on low-rank adaptation that was only published in March, was three operations: an import, a config initialization, and then an assignment.
One quirk of the library is that it has a hard-coded mapping from transformer model architectures (as strings!) to the modules to actually adapt via AdaLoRA; see the source here.^{7} These mappings don’t work out of the box for BlenderBot, but you can just inspect the module names^{8} and then it’s no problem:
import pytorch_lightning as pl
from peft import AdaLoraConfig, TaskType, get_peft_model
from transformers import BlenderbotForConditionalGeneration

class Imitator(pl.LightningModule):
    def __init__(self):
        super().__init__()
        pretrained_model = BlenderbotForConditionalGeneration.from_pretrained(
            "facebook/blenderbot-400M-distill"
        )
        # Wrap the pretrained model so only the AdaLoRA adapters are trained.
        peft_config = AdaLoraConfig(
            peft_type="ADALORA",
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
            lora_dropout=0.1,
        )
        self.model = get_peft_model(pretrained_model, peft_config)
        ...
> dude i need your advice
> she said she wanted to "take a break"
> LOL
> i'm not sure what to say to that
> i don't really know what advice to give
Alright, perhaps a bit unempathetic and not a particularly specific response, but it definitely talks more like me. In particular, the model’s learned that I tend to encapsulate thoughts in distinct messages, often without ending punctuation.
> dude
> gels are such a game changer
> felt amazing at mile 15 today
> oh wow
> that's a lot of running
> i'm impressed
> i can't run a mile
This isn’t perfectly accurate since I can indeed run a mile and also advocate for gels^{9} on runs longer than 90 minutes, but this is plausibly something I would’ve said before getting into running.
The punchline isn’t that I took a pre-trained model and then fine-tuned it on another dataset; after all that’s been done before.
The punchline is that, not including the script I used to parse my Facebook messages, this was only ~50 lines of code. Code that, by my assessment, is rather explicit and not reminiscent of code golf.
It’s incredible that the open source ecosystem has advanced to the stage where you can experiment with very modern techniques (transformers, parameter efficient fine-tuning, etc) in just O(dozens) of LoC.
This lowers the barrier of entry for not only hobbyists and enthusiasts, but also for professionals who have requirements that aren’t met by the existing ecosystem of model inference APIs, or simply prefer driving stick.
What’s even more fascinating than this research outright, is just how much other research continues to be done on top of bone stock transformers, even pretty hype research. E.g., DeepMind’s Gato was trained over a decoder-only transformer “for simplicity and scalability”. ↩
What’s also interesting here is that both of these seem to be in a legal gray area, since LLaMA weights were leaked and OpenAI’s terms of service prohibit their products from being used to train competitor models; despite that, they’re both incredibly popular approaches. ↩
This is perhaps biased by my experience, which has been mostly with heavy duty infrastructure that originated from within Alphabet. ↩
Some may not consider Reddit comments to be “high quality”, but it’s important to compare it to internet text en masse. Seriously, take a look at some examples from these canonical large web scrapes. There’s an amusing amount of SEO spam, websites that just aren’t even parsed correctly, e.g. “This website requires JavaScript …”. ↩
Here’s a fun article from the New York Times discussing this in more detail. ↩
This is an estimate, since the Colab pricing model depends on “compute credits” per hour, and it’s not entirely clear how those rates are calculated. Regardless, you can get 100 credits per month for $10, and high-RAM instances with an NVIDIA T4 were consistently around 2.xx credits per hour. ↩
Quite a few of these are also commented out, for reasons that aren’t super clear to me. Perhaps it’s the library maintainers being strict about tests? ↩
This is easy, via nn.Module.modules(). It’s always nice when something does exactly what it says on the tin. ↩
For the uninitiated, energy gels are portable and easy-to-digest carbs for endurance athletes. ↩