Mamba No. 5 (A Little Bit Of...)
In this post, I attempt to provide a walkthrough of the essence of the Mamba state space model architecture, occasionally sacrificing some rigor for intuition and overall pedagogical friendliness.
I don’t assume readers have any familiarity with state space models, but I do assume some familiarity with machine learning and mathematical notation.
If at any point you spot any errors, typos, or confusing wording, please let me know!
 TL;DR
 Setting the stage
 Linear time-invariant state space models
 Mamba SSM
 Hardware-aware resource management
 The Blelloch parallel prefix scan
 A binary associative operator for Mamba
 Closing thoughts
 Assorted references
 Footnotes
TL;DR
Mamba is a state space model (SSM) architecture that improves upon the S4 architecture. Sometimes known as S6, it makes two important modifications to S4:
 Selective SSM parameters
 Efficient implementation via parallel scan
Mamba parallelizes well during training, scales well with context length, performs inference efficiently, and most importantly, displays strong empirical results.
Setting the stage
Sequence models can be placed on a spectrum based on their approach to information representation, from highly compressed (e.g. RNNs) to highly explicit (e.g. transformers).
Exhibit A: the RNN
Consider a vanilla RNN:
\[\begin{aligned}h_t &= \tanh(W_{hh}h_{t-1} + W_{xh}x_t)\\y_t &= W_{hy}h_t\end{aligned}\]The fixed-size state \(h_{t-1}\) represents all prior context in a sequence at time \(t\). This underpins the core tradeoffs associated with RNNs:
Pros
 Efficient autoregressive inference: Since \(h_{t}\) encapsulates prior inputs, the model only needs to consider a small and constant set of new information for each subsequent input.
 No limits to context length: There is nothing in the formulation that explicitly constrains the model to a maximal sequence length.
Cons
 Ineffective modeling of complex dependencies: All prior context must be compressed, via static^{1} updates, into a fixed number of bits.
 Slow training: Training requires sequential backpropagation through time, making poor use of hardware accelerators, e.g. GPUs or TPUs. Accelerators have enormous throughput for parallel computation, but are otherwise surprisingly slow at sequential computation.
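The recurrence above can be made concrete with a minimal NumPy sketch (all sizes here are hypothetical, chosen purely for illustration): note how the loop body at step \(t\) cannot start until step \(t-1\) has finished.

```python
import numpy as np

# Hypothetical sizes for illustration: 3 input dims, 4 hidden, 2 output.
rng = np.random.default_rng(0)
W_hh = rng.standard_normal((4, 4)) * 0.1
W_xh = rng.standard_normal((4, 3)) * 0.1
W_hy = rng.standard_normal((2, 4)) * 0.1

def rnn_forward(xs):
    """h_t = tanh(W_hh h_{t-1} + W_xh x_t); y_t = W_hy h_t."""
    h = np.zeros(4)
    ys = []
    for x in xs:  # inherently sequential: each step depends on the last
        h = np.tanh(W_hh @ h + W_xh @ x)
        ys.append(W_hy @ h)
    return np.stack(ys)

ys = rnn_forward(rng.standard_normal((10, 3)))
print(ys.shape)  # (10, 2): one output per timestep
```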
Exhibit B: the transformer
On the other end of that spectrum, consider a decoder-only transformer model, à la GPT-3. In particular, let’s focus on its scaled dot-product self-attention layer:
\[\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]Pros
 Unreasonably effective at modeling complex dependencies: Every token gets to explicitly attend to all other prior tokens, instead of relying on a fixed-size state as a “summary”.
 Highly parallel training: There are no dependencies along the time dimension, and the core operations are matrix multiplications, which hardware accelerators have been excellent at parallelizing for decades.
Cons
 Quadratic scaling with context length: Since every input attends to all prior inputs, the total computation required grows quadratically with the number of tokens.
 Autoregressive inference is expensive^{2}: Unlike RNNs, there is no fixed-size compressed representation of the prior tokens; each new token must explicitly attend to all prior tokens.
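For reference, a minimal single-head NumPy sketch of causal scaled dot-product attention (hypothetical shapes): the explicit \(L \times L\) score matrix is where the quadratic cost comes from.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a causal mask, so that token t
    can only attend to tokens <= t."""
    L, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)  # the (L, L) matrix: quadratic in L
    mask = np.triu(np.ones((L, L), dtype=bool), k=1)
    scores[mask] = -np.inf           # block attention to future tokens
    return softmax(scores) @ V

rng = np.random.default_rng(0)
L, d = 8, 16
out = causal_attention(rng.standard_normal((L, d)),
                       rng.standard_normal((L, d)),
                       rng.standard_normal((L, d)))
print(out.shape)  # (8, 16)
```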
Why Mamba? Why now?
The contrasting tradeoffs of transformers and RNNs highlight the crux of sequence modeling research: how can we improve model quality within the constraints of available compute?
Recently, the industry has made rapid forward progress not from algorithmic breakthroughs but from dramatic increases in compute, due to both increased funding and continual improvements in hardware.
Which is to say, scaling isn’t particularly clever, but oh boy is it effective^{3}.
Perhaps Rich Sutton said it best in The Bitter Lesson:
One thing that should be learned from the bitter lesson is the great power of general purpose methods, of methods that continue to scale with increased computation even as the available computation becomes very great.
Speculatively, RNNs and transformers have a limited lifespan because they make poor use of increasingly abundant compute. It’s critical that we design models that better leverage compute while preserving fundamental model quality. Hence, the interest in Mamba.
Linear time-invariant state space models
Mamba is based on S4, which is a linear time-invariant (LTI) state space model (SSM), a common and useful subset of state space models more generally. For now, let’s focus on LTI SSMs, starting with their continuous form:
Continuous form
\[\begin{aligned} \mathbf{h}'(t) &= \mathbf{A} \mathbf{h}(t) + \mathbf{B}\mathbf{x}(t) \\ \mathbf{y}(t) &= \mathbf{C}\mathbf{h}(t) + \mathbf{D}\mathbf{x}(t) \end{aligned}\]Ah yes, a chunk of \(\LaTeX\), how intuitive. Let’s break it down.
Components
 \(t\) represents time, and is a scalar real number, i.e. \(t \in \mathbb{R}\).
 Although integers are a subset of the real numbers (\(\mathbb{Z} \subset \mathbb{R}\)), there is some specific handling and notation for the discrete case. We’ll discuss this further in the next section.
 \(\mathbf{x}(t) \in \mathbb{R}^{\mathtt{D}}\) is the input to our model at time \(t\), which has dimensionality \(\mathtt{D}\) (i.e., has \(\mathtt{D}\) channels).
 E.g., if you are doing modeling over raw audio files, then \(\mathtt{D} = 1\), \(t \in \mathbb{R}^+\), and \(\mathbf{x}(t)\) is the amplitude from the microphone’s recording at time \(t\).
 E.g., if you are doing modeling over text token embeddings, then \(\mathtt{D} = \text{embedding dimensionality}\), \(t \in \mathbb{Z}^+\), and \(\mathbf{x}(t)\) is the \(\mathtt{D}\)length embedding vector for the token at position index \(t\).
 \(\mathbf{y}(t) \in \mathbb{R}^{\mathtt{V}}\) is the corresponding output of our model, at time \(t\).
 E.g. for binary music/notmusic classification over raw audio files: \(\mathtt{V} = 1\), \(t \in \mathbb{R}^+\), and \(\mathbf{y}(t) \in [0, 1]\) is the predicted probability that all prior audio context was music at time \(t\).
 E.g. for nexttoken language modeling: \(\mathtt{V}=\text{tokenizer vocab size}\), \(t \in \mathbb{Z}^+\), and \(\mathbf{y}(t) \in [0, 1]^\mathtt{V}\) is the predicted probability distribution over all tokens in the vocabulary.
 Similar to a vanilla RNN, the state vector \(\mathbf{h}(t) \in \mathbb{R}^\mathtt{N}\) encapsulates all prior inputs at time \(t\).
 The matrices \(\mathbf{A} \in \mathbb{R}^{\texttt{N}\times\texttt{N}}\), \(\mathbf{B} \in \mathbb{R}^{\texttt{N}\times\texttt{D}}\), \(\mathbf{C} \in \mathbb{R}^{\texttt{V}\times\texttt{N}}\), and \(\mathbf{D} \in \mathbb{R}^{\texttt{V}\times\texttt{D}}\), known as the state matrix, input matrix, output matrix, and feedthrough matrix respectively, comprise the actual parameters of the SSM. These parameters, along with inputs and state, determine the outputs. These parameters also determine how the state itself evolves over the sequence of inputs.
Discrete form
In the continuous case, the dynamics of how \(\mathbf{h}\) evolves over time are determined via the differential equation \(\mathbf{h}'(t) = \mathbf{A} \mathbf{h}(t) + \mathbf{B}\mathbf{x}(t)\). That is, the current value of \(\mathbf{h}\) itself determines how \(\mathbf{h}\) is changing at that moment in time.
The discrete case is similar, but because we cannot differentiate discrete functions, we notate this self-modifying recursive behavior with, well, a recurrence^{4}. We also use subscript notation instead of function notation to help emphasize this distinction:
\[\begin{aligned} \mathbf{h}_{t} &= \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \mathbf{\overline{B}}\mathbf{x}_t \\ \mathbf{y}_t &= \mathbf{\overline{C}}\mathbf{h}_t + \mathbf{\overline{D}}\mathbf{x}_t \end{aligned}\]Furthermore, you may now also notice some horizontal bars over our SSM parameters, which is indicative of discretized parameters:
Parameter discretization
Because SSMs’ most general formulation is continuous, when working with discrete data there is (usually) a discretization step, where the “discretized” parameters are annotated with an overhead line, and are dependent on a learned “step size” parameter \(\Delta\) and fixed choices for discretization functions \(f_\mathbf{A}\), \(f_\mathbf{B}\), \(f_\mathbf{C}\), and \(f_\mathbf{D}\).
\[\begin{aligned} \mathbf{\overline{A}} &= f_{\mathbf{A}}(\Delta, \mathbf{A}) \\ \mathbf{\overline{B}} &= f_{\mathbf{B}}(\Delta, \mathbf{A}, \mathbf{B}) \\ \mathbf{\overline{C}} &= f_{\mathbf{C}}(\Delta, \mathbf{C}) \\ \mathbf{\overline{D}} &= f_{\mathbf{D}}(\Delta, \mathbf{D}) \\ \end{aligned}\]There are several reasonable strategies for discretization, but an intuitive exposition here would be longer than I’d like for this post, so for now you’ll have to just trust me (or go read Albert Gu’s 330 page thesis).
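That said, for the curious, here is a hedged sketch of one common choice, zero-order hold (ZOH), specialized to a diagonal \(\mathbf{A}\) so that the matrix exponential reduces to an elementwise exponential. The formulas follow the standard ZOH derivation; treat this as illustrative rather than the reference implementation.

```python
import numpy as np

# A sketch of zero-order hold discretization for a diagonal A:
#   A_bar = exp(delta * A)
#   B_bar = (delta * A)^{-1} (exp(delta * A) - I) * delta * B
# With A diagonal, both expressions are elementwise along the state dim.
def zoh_discretize_diag(a_diag, B, delta):
    dA = delta * a_diag                # (N,): diagonal of delta * A
    A_bar = np.exp(dA)                 # exp of a diagonal matrix
    B_bar = ((A_bar - 1.0) / dA)[:, None] * (delta * B)
    return A_bar, B_bar

# Sanity check: as delta -> 0, A_bar -> 1 (the identity) and B_bar -> 0,
# i.e. a vanishing step size leaves the state untouched.
a_diag = np.array([-1.0, -2.0])
B = np.ones((2, 1))
A_bar, B_bar = zoh_discretize_diag(a_diag, B, delta=1e-6)
print(np.allclose(A_bar, 1.0, atol=1e-5), np.allclose(B_bar, 0.0, atol=1e-5))
```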
Discretization fine print and personal musings
 The authors prefer to drop \(\mathbf{\overline{D}}\) entirely from their exposition of SSMs, because it “can be viewed as a skip connection that does not interact with the state…, the most important part of the SSM.” To be clear, it is still explicitly parametrized in their actual reference S4 and Mamba implementations, so IMHO^{5} it’s a bit more clear to leave it in the exposition even at the cost of verbosity.
 \(\mathbf{C}\) is not actually discretized per se, because \(\mathbf{A}\) and \(\mathbf{B}\) are already discretized, so in the Mamba paper there is no \(\mathbf{\overline{C}}\). Equivalently, in the S4 paper, \(\mathbf{C}\) is in fact discretized, where its discretization function is the identity. I personally prefer the notation of the latter, but it’s not really a big deal.
 There are at least a few valid methods of discretization: the Euler method, the zero-order hold (ZOH) method, and the bilinear method. The Euler method is the weakest, but choosing between the latter two is nuanced. In fact, the S4 paper goes with the bilinear method, but Mamba highlights ZOH instead.
 I am no mathematician, but I find the setup of predefined discretization functions to be a touch surprising, because it goes against my personal deeplearning intuition of “let the model learn everything”. I’m curious how well learned discretization functions would perform, in addition to simply a learned step size parameter.
 Whether or not discretization is necessary at all is also an interesting question. There are various nice properties and interpretations, such as viewing \(\mathbf{x}_t\) as being sampled from continuous data, i.e. \(\mathbf{x}_t = \mathbf{x}(t\Delta)\). Yet, the authors also mention how other SSM research, e.g. Zhang et al., 2018, Effectively Modeling Time Series with Simple^{6} Discrete State Spaces, did not explicitly discretize and still achieved good results.
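With discretized parameters in hand, the discrete recurrence itself is just a short loop. A minimal NumPy sketch (all sizes hypothetical, chosen for illustration):

```python
import numpy as np

# Hypothetical sizes: N=4 states, D=2 input channels, V=3 output channels.
rng = np.random.default_rng(0)
N, D, V, L = 4, 2, 3, 16
A_bar = rng.standard_normal((N, N)) * 0.1
B_bar = rng.standard_normal((N, D)) * 0.1
C_bar = rng.standard_normal((V, N)) * 0.1
D_bar = rng.standard_normal((V, D)) * 0.1

def ssm_recurrence(xs):
    h = np.zeros(N)
    ys = []
    for x in xs:
        h = A_bar @ h + B_bar @ x          # h_t = A_bar h_{t-1} + B_bar x_t
        ys.append(C_bar @ h + D_bar @ x)   # y_t = C_bar h_t + D_bar x_t
    return np.stack(ys)

ys = ssm_recurrence(rng.standard_normal((L, D)))
print(ys.shape)  # (16, 3): one V-dim output per timestep
```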
Structured SSMs, à la S4/S4D
Structure for A, i.e. initialization
S4 (or rather, S4D) goes a step further and parametrizes \(\mathbf{A}\) as a diagonal matrix, initialized via very clever approximations of the HiPPO matrix^{7} used in the original S4 paper. Some high level pointers:
 HiPPO initialization was required to get good performance from the S4 architecture. A random initialization with S4 produced a model with middling performance.
 Matrix-vector multiplications are, of course, much faster/cheaper if your matrix is diagonal.
 S4D style initialization continues to work well with Mamba, but now random initialization also performs pretty well!
Inputs and shapes
Earlier, we described LTI SSMs as being able to handle inputs with multiple channels. In particular, we parametrized the SSM matrices as \(\mathbf{A} \in \mathbb{R}^{\texttt{N}\times\texttt{N}}\), \(\mathbf{B} \in \mathbb{R}^{\texttt{N}\times\texttt{D}}\), \(\mathbf{C} \in \mathbb{R}^{\texttt{V}\times\texttt{N}}\), and \(\mathbf{D} \in \mathbb{R}^{\texttt{V}\times\texttt{D}}\), where
 \(\mathtt{N}\) is the state space dimension size^{8}
 \(\mathtt{D}\) is the number of input channels
 \(\mathtt{V}\) is the number of output channels
However for S4 (and Mamba), the SSM is parametrized like so:
In this case, the \(\mathbf{A} \in \mathbb{R}^{\mathtt{N} \times \mathtt{N}}, \mathbf{B} \in \mathbb{R}^{\mathtt{N}\times 1}, \mathbf{C} \in \mathbb{R}^{1\times\mathtt{N}}\) matrices can all be represented by \(\mathtt{N}\) numbers. To operate over an input sequence \(\mathbf{x}\) of batch size \(\mathtt{B}\) and length \(\mathtt{L}\) with \(\mathtt{D}\) channels, the SSM is applied independently to each channel.
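As a sketch of what “applied independently to each channel” means, here is a small NumPy example with a diagonal \(\mathbf{A}\) (hypothetical sizes, and batch dimension omitted for brevity): the same \(\mathtt{N}\) parameters run once per channel.

```python
import numpy as np

# With diagonal A_bar (N numbers) and B_bar, C_bar as length-N vectors,
# the same single-channel SSM is applied to each of the D channels
# independently. Hypothetical sizes: N=4 states, D=3 channels, L=8 steps.
rng = np.random.default_rng(0)
N, D, L = 4, 3, 8
a = rng.uniform(0.5, 0.9, size=N)   # diagonal of A_bar
b = rng.standard_normal(N)          # B_bar, shape (N, 1) flattened
c = rng.standard_normal(N)          # C_bar, shape (1, N) flattened
xs = rng.standard_normal((L, D))

def single_channel(x_seq):
    h = np.zeros(N)
    ys = []
    for x in x_seq:            # x is a scalar for one channel
        h = a * h + b * x      # diagonal A => elementwise state update
        ys.append(c @ h)
    return np.array(ys)

# The same parameters, applied independently to every channel.
ys = np.stack([single_channel(xs[:, d]) for d in range(D)], axis=1)
print(ys.shape)  # (8, 3): (L, D), one scalar output per channel per step
```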
Structure fine print and personal musings
 You can confirm this parametrization by inspecting the authors’ reference Mamba implementation here (or their S4 / S4D implementations) which simply repeats the same instantiation across all channel dimensions.
 I’m curious how the models perform if one does not restrict to this parametrization. As in, the Mamba paper assumes that \(\mathtt{N}\) is an expansion factor for a single channel, but how would things behave if all channels were handled together? This is less meaningful if we’re doing an S4D-Real style of initialization, but how about with random initialization?
Mamba SSM
To recap, we discussed
 LTI SSMs in their more general continuous form
 Discrete LTI SSMs
 Discrete LTI structured SSMs (à la S4)
The only actual model architecture change in Mamba is the removal of linear time-invariance for \(\mathbf{B}\), \(\mathbf{C}\), and \(\Delta\). They are now functions of \(\mathbf{x}_t\), i.e. the parameters are “selective”.
\[\begin{aligned} \mathbf{h}_{t} &= \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t \\ \mathbf{y}_t &= \overline{\mathbf{C}(\mathbf{x}_t)}\mathbf{h}_t + \mathbf{\overline{D}}\mathbf{x}_t \end{aligned}\]A departure from S4’s linear time invariance
Linear time invariance was critical to S4 because it provided the foundation for its efficient implementation. The sequential (slow) implementation is the obvious one: just use a forloop, applying the parameters to each input one at a time.
The parallel (fast) implementation is more complex. In brief, because the updates from the SSM matrices are time-invariant, you can compute the outputs by constructing a particular kernel and then performing a full-width convolution. Since full-width convolutions can be computed quickly with the FFT trick, this ends up being a very efficient implementation.
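To make the kernel view concrete, here is a small NumPy sketch (hypothetical sizes, \(\mathbf{D}\) dropped for brevity): unrolling the recurrence with \(\mathbf{h}_0 = 0\) gives \(y_t = \sum_{k=0}^{t} \mathbf{\overline{C}}\mathbf{\overline{A}}^k\mathbf{\overline{B}} \, x_{t-k}\), a causal convolution. The kernel is built naively here; the FFT trick is what makes this fast in practice.

```python
import numpy as np

# Hypothetical sizes for illustration.
rng = np.random.default_rng(0)
N, L = 4, 32
A_bar = rng.standard_normal((N, N)) * 0.2  # scaled down so powers stay tame
B_bar = rng.standard_normal((N, 1))
C_bar = rng.standard_normal((1, N))
xs = rng.standard_normal(L)

# Kernel K_k = C_bar A_bar^k B_bar, one scalar per lag k.
K = np.array([(C_bar @ np.linalg.matrix_power(A_bar, k) @ B_bar).item()
              for k in range(L)])
y_conv = np.convolve(xs, K)[:L]  # causal convolution (naive, O(L^2))

# Compare against the sequential recurrence.
h = np.zeros((N, 1))
y_rec = []
for x in xs:
    h = A_bar @ h + B_bar * x
    y_rec.append((C_bar @ h).item())

print(np.allclose(y_conv, np.array(y_rec)))  # True: same outputs either way
```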
However, LTI caps the expressivity of the model, as every update to state is handled identically, no matter what needs to be updated or whether the current input is relevant or irrelevant. By dropping time invariance, instead of learning fixed matrix parameters, Mamba learns functions which ingest the input at time \(t\) and then control the state update and output dynamics.
This idea of not naively applying the same simple state update for each input is also why gated RNNs (e.g. LSTM, GRU) are much more effective than vanilla RNNs, and indeed there is a strong theoretical connection between gating heuristics and selectivity in Mamba.
Fast implementation
Since Mamba is no longer a linear time-invariant SSM, we can no longer rely on a full-width convolution as the basis of a fast (i.e. parallel) implementation. We need other strategies; otherwise we are still stuck with slow training and limited practical utility.
The most critical techniques the authors employ are the Blelloch work-efficient^{9} parallel scan and hardware-aware memory management, which together facilitate the fast (in real-world wall-clock time) training of Mamba.
Pedagogical fine print and personal musings
 The above notation is my own, but I’ve tried to clarify some things, e.g.
 \(\mathbf{A}\) and \(\mathbf{D}\) remain linear timeinvariant, although I am curious about ablations that also parametrize those two parameters.
 \(\mathbf{B}\) and \(\mathbf{C}\) are parametrized by \(\mathbf{x}_t\), and then discretized.
 All pedagogically friendly implementations of Mamba I was able to find actually omit the parallel scan. I’m personally conflicted about the utility of this, because it does simplify the implementation, but how you can implement something like this efficiently is half of the reason why Mamba is valuable in the first place.
Hardware-aware resource management
In most of computer science, we lean pretty hard on a computation model that hand waves away a lot of real world performance characteristics in favor of “constant time”. The random access machine is one popular model, but far from the only one. In any case, these models are akin to idealized models in an introductory physics class: they are exceedingly useful, but they also lie.
The real world is messier, with different hierarchies of performance in compute, memory, and bandwidth, all thanks to the very real implications of squeezing transistors on a silicon wafer of limited size.
A simple GPU program
Most GPU programs follow the same recipe:
 Copy relevant data from host DRAM to GPU HBM.
 Load it into GPU SRAM.
 Perform the desired computation.^{10}
 Write the results back to GPU HBM.
 Copy them back to host DRAM.
Loading is by far the slowest part of this process, at least for most neural network applications. The more you can minimize memory loading the better, even if you need to spend some extra compute to do so.^{11}
One of the Mamba authors, Tri Dao, is well-known for his work on FlashAttention, which introduced hardware-aware computation of self-attention, bringing memory requirements from quadratic to linear and also providing dramatic wall-clock time speedups.
The core realization with FlashAttention was that the size of the intermediate computations dwarfed the actual size of the inputs and outputs. I.e., the \(QK^T\) matrix has size \(O(\mathtt{L}^2)\), even though the inputs themselves are only \(O(\mathtt{L})\).
Hardware-aware Mamba
The same principle is applicable here. The intermediate computations (the actual state machine mechanics) are again larger than the inputs and outputs, i.e. \(O(\mathtt{BLDN}) > O(\mathtt{BLD} + \mathtt{DN})\). Thus a similar approach works well here, where computations are done in a blockwise fashion, maximizing the amount of computation that can occur in SRAM before needing to load to/from HBM.
If this still feels a bit handwavey, then please go read Horace He’s excellent blog post Making Deep Learning Go Brrrr From First Principles, which is perhaps the best technical blog post I’ve read in the past few years.
And if you’re feeling brave, then I’d encourage you to dive into the core implementation in the authors’ reference implementation, i.e. selective_scan_fwd_kernel.cuh. In particular, pay attention to anything with smem or shared.
The Blelloch parallel prefix scan
Of course, a very memory-efficient implementation isn’t particularly useful if you can’t compute it in parallel, thanks to how modern hardware accelerators are built. Hence the desire for parallelism.
To be clear, parallel computation of linear RNNs is not something the Mamba authors invented, nor is it even particularly recent^{12}. But it’s critical to an efficient implementation for Mamba, hence discussing it here. At a high level, there are two key pieces to gaining intuition here for Mamba:
 Understanding how the Blelloch parallel prefix scan algorithm works for simple binary associative operators, e.g. summation.
 Understanding how to represent the core bits of Mamba, or any linear RNN, as a binary operator.
This section will focus on the former.
Warm up: the parallel reduce
Before discussing the parallel scan, let’s first think about something simpler, the parallel reduce.
Suppose I have a list with \(k\) elements and I want to perform a reduce operation over some binary associative operator \(\oplus\) where:
 An operator is a function that takes in multiple elements of one type and returns a single element of the same type. A binary operator takes in two elements.
 A binary operator is associative if \(x \oplus (y \oplus z)= (x \oplus y) \oplus z\) for all \(x,y,z\).
 A reduce, also sometimes known as a fold, is a function that recursively applies a binary operation to aggregate a sequence of values into a single cumulative result. E.g., \((((x_1 + x_2) + x_3) + ... + x_{k-1}) + x_k\) is the reduction of the plus operator over a list of inputs \(x_1\) to \(x_k\).
One possible naive implementation is just looping through the inputs in order. For example, with inputs [3, 1, 7, 0, 4, 1, 6, 3]:
def naive_reduce(op, identity, inputs):
    result = identity
    for i in inputs:
        result = op(result, i)
    return result

>>> naive_reduce(op=lambda x, y: x + y, identity=0, inputs=[3, 1, 7, 0, 4, 1, 6, 3])
25
This has linear time complexity w.r.t. the number of inputs, which is as good as we can hope for without parallelism, but what if we had more workers? We could split our input into pairs, compute those sums, and then recursively use those results as the inputs to another reduce, forming a computation tree:
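The pairwise splitting can be sketched as a level-by-level loop; each level’s pairs are independent of one another, which is exactly where parallel workers would come in, giving \(O(\log k)\) depth instead of \(O(k)\).

```python
def tree_reduce(op, identity, inputs):
    """Reduce by repeatedly combining adjacent pairs. With enough workers,
    each level is one parallel step, so the depth is logarithmic."""
    values = list(inputs)
    if not values:
        return identity
    while len(values) > 1:
        if len(values) % 2:          # pad odd-length levels with identity
            values.append(identity)
        values = [op(values[i], values[i + 1])
                  for i in range(0, len(values), 2)]
    return values[0]

print(tree_reduce(lambda x, y: x + y, 0, [3, 1, 7, 0, 4, 1, 6, 3]))  # 25
```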
The Blelloch parallel scan
A scan aggregates results from a sequence of elements by applying the operator cumulatively. For example, given the same inputs, the scan output would be [3, 4, 11, 11, 15, 16, 22, 25].
Closely related is the prescan, which just shifts outputs by one, starting with an identity for the given operator (e.g., zero for summation): [0, 3, 4, 11, 11, 15, 16, 22]. The Blelloch algorithm technically computes a prescan, although that is sufficient for our use case since it’s easy to go to/from a prescan.
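Before the tree version, it helps to pin down the scan/prescan relationship with a short sequential sketch:

```python
def sequential_scan(op, identity, inputs):
    """Inclusive scan: output[t] = x_0 op x_1 op ... op x_t."""
    out, acc = [], identity
    for x in inputs:
        acc = op(acc, x)
        out.append(acc)
    return out

def to_prescan(identity, scan_out):
    """Shift an inclusive scan right by one to get the exclusive prescan."""
    return [identity] + scan_out[:-1]

xs = [3, 1, 7, 0, 4, 1, 6, 3]
scanned = sequential_scan(lambda a, b: a + b, 0, xs)
print(scanned)                 # [3, 4, 11, 11, 15, 16, 22, 25]
print(to_prescan(0, scanned))  # [0, 3, 4, 11, 11, 15, 16, 22]
```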
First, the up sweep
The first step to a fast scan computation is to compute a parallel reduction, as we described in the previous section. However, this time, we preserve intermediate computations.
Next, the down sweep
In the down sweep, we will maintain the invariant that every node contains the sum of all prior leaf nodes, as determined by visit order in a pre-order traversal, e.g.:
def pre_order_traversal(node: Node) -> None:
    if not node:
        return
    visit(node)
    pre_order_traversal(node.left)
    pre_order_traversal(node.right)
If every node can contain the sum of all prior leaf nodes, then the values of the leaves themselves will be the results of a prescan!
Stepping through the down sweep with a concrete example:
When there are no prior leaf nodes, we use the identity value, e.g. 0 for summation.
We now need to be careful about maintaining the invariant. Filling in the down sweep level by level, for any particular node N:
downsweep[N].left.value = downsweep[N].value
 For the following diagrams, a blue node indicates the contribution from the parent.
downsweep[N].right.value = downsweep[N].value + upsweep[N].left.value
 For the following diagrams, a red node indicates a contribution from the downsweep tree, and the yellow node indicates the contribution from the upsweep tree, and orange indicates the combined result.
Voila! For a sketch of the proof of why this works, consider a node in a pre-order traversal.
 It is either a left child, or a right child (or the root).
 If it is a left child, then it and its parent have the same set of prior leaf nodes, so it can get its value just by copying its parent’s value.
 If it is a right child, then there are leaves from two possible regions:
 all the leaves that are strictly “after” the parent node but before the right node (the contribution of the sibling/left node)
 all leaves that are “before” the parent node entirely
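Putting the up sweep and down sweep together, here is a minimal array-based sketch of the whole Blelloch prescan (a sketch, assuming a power-of-two input length for simplicity). The inner iterations of each level are independent, which is what a parallel implementation exploits.

```python
def blelloch_prescan(op, identity, inputs):
    """Exclusive (pre)scan via Blelloch's up sweep + down sweep."""
    a = list(inputs)
    n = len(a)
    # Up sweep: a parallel reduce, keeping partial sums in place.
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            a[i] = op(a[i - step], a[i])
        step *= 2
    # Down sweep: set the root to identity, then push prefixes down.
    a[n - 1] = identity
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            left = a[i - step]
            a[i - step] = a[i]     # left child copies its parent's value
            a[i] = op(a[i], left)  # right child: parent plus old left child
        step //= 2
    return a

print(blelloch_prescan(lambda x, y: x + y, 0, [3, 1, 7, 0, 4, 1, 6, 3]))
# [0, 3, 4, 11, 11, 15, 16, 22]
```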
A binary associative operator for Mamba
The above examples used “sum”, but the Blelloch parallel scan works for all binary associative operators. For Mamba, we can define an associative operator that, when given the proper inputs, will compute the prescan of state vectors in parallel!
Prerequisites
In fact, what we’ll discuss here is contextualized by Mamba, but is actually valid for all first-order recurrences of the following form^{13}:
\[h_t= \begin{cases} b_0 & t = 0 \\ (a_t\otimes h_{t-1}) \oplus b_t & t>0 \\ \end{cases}\]Where \(\oplus\) and \(\otimes\) meet the following criteria:
 \(\oplus\) must be associative, i.e. \((x \oplus y) \oplus z = x \oplus (y \oplus z)\)
 Notice that vectorvector addition satisfies this!
 \(\otimes\) must be semi-associative, i.e. there exists a binary associative operator \(\odot\) such that \(x \otimes (y \otimes z) = (x \odot y) \otimes z\)
 Notice that \(\odot\) as matrix-matrix multiplication and \(\otimes\) as matrix-vector multiplication satisfies this!
 \(\otimes\) distributes over \(\oplus\): \(x \otimes (y \oplus z) = (x \otimes y) \oplus (x \otimes z)\)
 Notice that above matrix/vector addition/multiplication operators satisfy this!
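These three criteria are easy to spot-check numerically. A small NumPy sketch, taking \(\oplus\) as vector addition, \(\otimes\) as matrix-vector multiplication, and \(\odot\) as matrix-matrix multiplication:

```python
import numpy as np

# Spot check the three criteria with random matrices/vectors.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((3, 3)), rng.standard_normal((3, 3))
u, v, w = (rng.standard_normal(3) for _ in range(3))

# 1. ⊕ is associative: (u ⊕ v) ⊕ w == u ⊕ (v ⊕ w)
print(np.allclose((u + v) + w, u + (v + w)))    # True
# 2. ⊗ is semi-associative via ⊙: X ⊗ (Y ⊗ w) == (X ⊙ Y) ⊗ w
print(np.allclose(X @ (Y @ w), (X @ Y) @ w))    # True
# 3. ⊗ distributes over ⊕: X ⊗ (v ⊕ w) == (X ⊗ v) ⊕ (X ⊗ w)
print(np.allclose(X @ (v + w), X @ v + X @ w))  # True
```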
Defining the operator
We massage our inputs into the sequence of \(c_t \equiv[a_t, b_t] \text{ for } t = 1, 2, ... , \mathtt{L}\), where
\[\begin{aligned} a_t &= \mathbf{\overline{A}} \\ b_t &= \overline{\mathbf{B(x}_t)}\mathbf{x}_t\\ \end{aligned}\]And define a new operator \(\bullet\) as follows:
\[\begin{aligned} c_i \bullet c_j &\equiv [c_{j,a} c_{i,a}, \, c_{j,a} c_{i,b} + c_{j,b}] \\ &\equiv [a_j a_i, \, a_j b_i + b_j] \end{aligned}\]The “identity” for this operator is \(c_0 = \left[\mathbf{\overline{A}}, \, \mathbf{h}_0 \right]\), where \(\mathbf{h}_0\) is the initial state vector, before having seen any inputs. This is analogous to \(0\) for the \(\text{add}\) operator.
We then apply the operator in parallel with the Blelloch (pre)scan algorithm, and the outputs at the second index will be the desired \(\mathbf{h}_t\) results, computed in an efficient manner!
Some proofs
From the setup of the operator and how we defined \(a_t\) and \(b_t\), the second component is basically guaranteed to compute the proper state vectors. Recall the first line of the Mamba SSM:
\[\mathbf{h}_{t} = \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t\]A sketch of the high level intuition here is:
 We can show that the \(\bullet\) operator with our specified initializations computes \(\mathbf{h}_t\) with a sequential scan.
 We can show that the \(\bullet\) operator is associative, so that we may use the Blelloch parallel scan algorithm in particular.
Proof of part 1
 We initialize \(b_0 = \mathbf{h}_0\).
 For time \(t \ge 1\), if \(b_{t-1}\) is equal to \(\mathbf{h}_{t-1}\), then \(a_t b_{t-1} + b_t = \mathbf{h}_t\). This is because we have set \(a_t = \mathbf{\overline{A}}\) and \(b_t = \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t\) for all \(t\).
 Via induction, since \(b_0 = \mathbf{h}_0\) is initialized correctly and subsequent \(b_t = \mathbf{h}_t\) are computed correctly if provided \(b_{t-1} = \mathbf{h}_{t-1}\), then scanning with this operator and these initializations correctly computes the desired state vectors.
Proof of part 2
This, unfortunately, is simply a wall of algebra:
\[\begin{aligned} &\text{Apply the definition of } \bullet \text{:} \\ (c_i \bullet c_j) \bullet c_k &= [c_{j,a} \odot c_{i,a}, \; (c_{j,a} \otimes c_{i,b}) \oplus c_{j,b}] \bullet c_k \\ &\text{Apply the definition of } \bullet \text{ again:} \\ &= [ c_{k,a} \odot (c_{j,a} \odot c_{i,a} ) , \; (c_{k,a} \otimes ((c_{j,a} \otimes c_{i,b}) \oplus c_{j,b})) \oplus c_{k,b}] \\ &\text{Associativity of } \odot \text{:} \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; (c_{k,a} \otimes ((c_{j,a} \otimes c_{i,b}) \oplus c_{j,b})) \oplus c_{k,b}] \\ &\otimes \text{ distributes } c_{k,a} \text{ over } \oplus \text{:} \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; ((c_{k,a} \otimes (c_{j,a} \otimes c_{i,b})) \oplus (c_{k,a} \otimes c_{j,b})) \oplus c_{k,b}] \\ &\text{Associativity of } \oplus \text{:} \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; (c_{k,a} \otimes (c_{j,a} \otimes c_{i,b})) \oplus ((c_{k,a} \otimes c_{j,b}) \oplus c_{k,b})] \\ &\text{Semi-associativity of } \otimes \text{:} \\ &= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \; ((c_{k,a} \odot c_{j,a}) \otimes c_{i,b}) \oplus ((c_{k,a} \otimes c_{j,b}) \oplus c_{k,b})] \\ &\text{Apply the operator definition:} \\ &= c_i \bullet [c_{k,a} \odot c_{j,a}, \; ( c_{k,a} \otimes c_{j,b} ) \oplus c_{k,b}] \\ &\text{Apply the operator definition again:} \\ (c_i \bullet c_j) \bullet c_k &= c_i \bullet (c_j \bullet c_k) \end{aligned}\]Sanity checking
The above seems to make sense, but perhaps you prefer python to \(\LaTeX\) (I wouldn’t blame you).
As a first sanity check, if you’ve gotten this far, now is perhaps a good time to start perusing some reference implementations. E.g., in the author’s CUDA code, you can see the above operator implemented quite literally at selective_scan_common.h:113.
But also, let’s write our own “unit tests” as an additional sanity check. FYI, this will leverage the jax.lax.associative_scan implementation, which is a batteries-included implementation^{14} of Blelloch’s algorithm.
Miscellaneous setup
# Various imports
from einops import einsum
import jax
# jax.lax already has a convenient parallel scan implementation.
import jax.lax as lax
import jax.numpy as jnp

key = jax.random.PRNGKey(seed=42)

B = 1     # batch size
L = 8192  # context length
N = 64    # hidden state size
D = 2     # num in channels
V = 1     # num out channels

# Gets the various fake x_t inputs.
def generate_random_xs(key, num_inputs=L, num_channels=D):
    key, subkey = jax.random.split(key)
    xs = jax.random.lognormal(subkey, shape=(num_inputs, num_channels))
    return key, xs

# Gets various fake A matrices. This is actually constant in the paper,
# but it doesn't have to be.
def generate_random_As(key, num_inputs=L, state_size=N):
    key, subkey = jax.random.split(key)
    As = jax.random.lognormal(subkey, shape=(num_inputs, state_size, state_size))
    return key, As

# Gets various fake B(x_t) matrices.
def generate_random_Bxs(key, num_inputs=L, state_size=N, num_channels=D):
    key, subkey = jax.random.split(key)
    Bxs = jax.random.lognormal(subkey, shape=(num_inputs, state_size, num_channels))
    return key, Bxs

# Gets the b_t term.
def get_bs(xs, Bxs):
    return einsum(Bxs, xs, "l n d, l d -> l n")

# Jax plays nicest with jnp.arrays, so we'll stuff the values inside a
# single array and just unpack things here. I suppose I could use PyTrees
# but please forgive a bit of laziness/hackiness on my part.
def extract(c, state_size):
    assert c.ndim == 1
    assert c.shape[0] == state_size * state_size + state_size
    return (
        c[:state_size * state_size].reshape((state_size, state_size)),
        c[state_size * state_size:].reshape((state_size,))
    )
The operator implementation and test logic
def operator(c_prev, c_curr, num_inputs=L, state_size=N, num_channels=D):
    prev_a, prev_b = extract(c_prev, state_size)
    curr_a, curr_b = extract(c_curr, state_size)
    return jnp.concatenate([
        jnp.ravel(curr_a @ prev_a),
        jnp.ravel(curr_a @ prev_b + curr_b)
    ])
vectorized_operator = jax.vmap(operator, in_axes=(0, 0), out_axes=0)
# Actually generate some fake test data.
key, xs = generate_random_xs(key)
key, Bxs = generate_random_Bxs(key)
key, As = generate_random_As(key)
bs = get_bs(xs, Bxs)
cs = jnp.concatenate([As.reshape(L, N * N), bs], axis=1)

# %%timeit results on a freebie Google Colab VM:
# 283 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lax_scanned = lax.associative_scan(vectorized_operator, cs)[:, N * N:]
def naive_scan_hs(h_0, As, Bxs, xs):
    output = [h_0]
    for a, bx, x in zip(As, Bxs, xs):
        b = einsum(bx, x, "n d, d -> n")
        output.append(a @ output[-1] + b)
    return output[1:]

# %%timeit results on a freebie Google Colab VM:
# 3.34 s ± 313 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
naive_hs = jnp.vstack(
    naive_scan_hs(jnp.zeros((N,)), As, Bxs, xs)
)
# The following returns Array(True, dtype=bool)! Which means that
# we're getting identical results, (allowing for some floating point
# imprecision), regardless of if we're using the naive iterative
# implementation, or the fast parallel implementation.
jnp.allclose(naive_hs, lax_scanned)
Hooray, we’re getting identical results as well as a much faster wall time. These trends only get more extreme as you leverage more capable hardware (the above results were on a freebie Google Colaboratory VM, i.e. not hardware accelerated and using only two CPU cores).
Closing thoughts
Whew, that was a lot. Let’s do a quick recap of what we’ve covered:
Summary of above topics
 Mamba is a state space model, and is at its core, recurrent/sequential.
 Because Mamba relies on a fixed-size state representation, it is efficient at inference time, unlike transformers.
 Mamba, like FlashAttention, maximizes efficiency in a real world wallclock sense by doing as much work as possible in SRAM.
 Mamba can be computed in parallel via the Blelloch parallel scan algorithm, making it efficient at training time, unlike RNNs.
 Mamba’s parameters are selective, making it for better at modeling long term dependencies than vanilla RNNs.
In particular, we paid extra attention to
 Mamba’s specific SSM formulation.
 How linear RNNs, Mamba included, can be computed efficiently.
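To crystallize that second point, here is a tiny self-contained sketch (scalar state and made-up numbers, not the full matrix version from earlier) of computing the linear recurrence \(h_t = a_t h_{t-1} + b_t\) with a parallel scan:

```python
import jax.numpy as jnp
from jax import lax

# Pairs (a_t, b_t) compose as (a2 * a1, a2 * b1 + b2); the running
# b-component of the prefix composition is exactly h_t for the
# recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0.
def combine(e1, e2):
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

a = jnp.array([0.5, 2.0, 1.5, 0.25])
b = jnp.array([1.0, -1.0, 0.5, 3.0])

_, hs_parallel = lax.associative_scan(combine, (a, b))

# Reference: the plain sequential loop.
h, hs_sequential = 0.0, []
for a_t, b_t in zip(a, b):
    h = a_t * h + b_t
    hs_sequential.append(h)

assert jnp.allclose(hs_parallel, jnp.array(hs_sequential))
```

Swap the scalars for the \(\mathtt{N} \times \mathtt{N}\) matrices and \(\mathtt{N}\)-dimensional vectors used above and you recover the full Mamba-style state computation.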
Remaining topics
There is a bunch of other important material that I’ve glossed over here. Conceptually, I believe the selective SSM formulation and the scan implementation are the most important contributions, but a thorough discussion would also cover
 The hardware-aware computation^{15}
 Mamba’s fused “one block” architecture.
 Other fairly standard optimization tricks, e.g. recomputation, kernel fusion, etc.
 Theoretical ties between heuristic gating and selectivity.
And of course, I provided a bit of code to sanity check the discussion around parallel computation of the state vectors, but this is not a full implementation 🙂. Perhaps I’ll get around to writing a more thorough treatise in the future, but for now, I hope you’ve found this interesting!
Musings on breadth
Mamba is one of the most impressive papers I’ve read in years. It’s particularly impressive because the core concepts (selectivity and efficient implementation) are simple (where simple ≠ easy) yet effective.
In fact, throughout the paper there are not many particular insights that are demanding of galaxy brain intellect. For example:
 Non-LTI SSMs are really just the more general version of SSMs you might learn about in a control theory class.
 There have been other examples of fused architectures.
 The Blelloch scan algorithm is old, originally published in 1990, and is something you’re very likely to have encountered if you took some parallel programming courses. There’s even a Udacity course that touches on this!
 The “selectivity” functions are very simple, comprising just matmuls, a softplus, and a broadcast.
 I am inexperienced at writing kernel code, but I assume like any other engineering skill this is very doable with some practice.
 There is a bit of insight that comes with knowing how hardware accelerators are built. Becoming an expert at this is, like all things, very difficult. But you can get some pretty useful knowledge by doing some high bang-for-buck reading, e.g. Making Deep Learning Go Brrrr From First Principles.
 Legendre polynomials, à la HiPPO initialization, could very feasibly appear in one’s undergraduate coursework if they enjoy math (e.g., Math 442 at my alma mater).
But on the other hand, combining a useful knowledge of all of the above and distilling that into a cohesive and novel body of research? Super cool. I suppose that’s why both authors are younger than me but also tenure-track professors at top institutions.
Assorted references
 https://stacks.stanford.edu/file/druid:mb976vf9362/gu_dissertationaugmented.pdf
 https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf
 https://srush.github.io/annotateds4/
 https://arxiv.org/abs/2312.00752
 https://arxiv.org/abs/2111.00396
Footnotes

I.e. post-training, the parameters involved in the updates are always the same. ↩

There are some tricks like KV caching which help here, but they unfortunately do not change asymptotic behavior. ↩

Not to say that scaling is easy either, because HPC is a fiendishly difficult thing to get right; it’s just a very simple concept to understand: use moar computer. ↩

At the risk of offending proper mathematicians, you can kinda think of a first-order differential equation as an infinitesimally precise recurrence. ↩

You should take any of my opinions with a large grain of salt. I, for one, have not been awarded any ICLR Outstanding Paper Awards 🙂 ↩

IMHO, this is where most people get stuck when trying to understand S4. I mean, how often does the typical computer scientist encounter Legendre polynomials? ↩

It is technically possible but quite silly to use a value smaller than \(\mathtt{D}\). Remember that the state itself is only \(\mathtt{N}\) dimensional, i.e. \(\mathbf{h}_t \in \mathbb{R}^\mathtt{N}\). It would be very challenging for a model to compress multiple \(\mathbb{R}^{\mathtt{D}}\) inputs into a single \(\mathbb{R}^{\mathtt{N}}\) if \(\mathtt{N} < \mathtt{D}\). ↩

There are other parallel scan implementations, e.g. the Hillis-Steele scan. In a nutshell, the Blelloch scan is more “work-efficient” yet less “step-efficient”. At a high level, when the amount of work to do exceeds the amount of available parallelism, the Blelloch implementation is more efficient. This matches the typical situation in most ML problems, where even a lot of parallel compute is dwarfed by the scale of data/input demands. ↩
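For the curious, here is a minimal illustrative sketch (mine, not from either paper) of the Hillis-Steele pattern for an inclusive prefix sum — note how each of the \(\log n\) doubling steps touches nearly all \(n\) elements, which is where the extra work comes from:

```python
import numpy as np

def hillis_steele_scan(xs):
    """Inclusive prefix sum via the Hillis-Steele doubling pattern.

    Takes O(log n) steps, but each step does O(n) work, for O(n log n)
    total work -- versus the Blelloch scan's O(n) total work.
    """
    xs = np.asarray(xs, dtype=float).copy()
    shift = 1
    while shift < len(xs):
        # In step k, element i accumulates its neighbor 2**k to the left.
        # (The right-hand side is materialized before assignment, so this
        # mimics the parallel "read all, then write all" semantics.)
        xs[shift:] = xs[shift:] + xs[:-shift]
        shift *= 2
    return xs

result = hillis_steele_scan([1, 2, 3, 4])  # inclusive sums: 1, 3, 6, 10
```

With a work-efficient operator like addition this overhead hardly matters, but with Mamba's matrix-valued elements every redundant combination is an \(\mathtt{N} \times \mathtt{N}\) matrix multiply, so work efficiency pays off.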

This is where the GPU kernel comes in, as in “kernel fusion”, a scary name for a simple thing. Also, is there a more overloaded word in STEM than kernel? Kernel trick, GPU kernel, Linux kernel, convolution kernel, etc. Enough to warrant a fairly long Kernel (disambiguation) Wikipedia article. Maybe entropy or set are more overloaded… ↩

I mean, to an extent; the KV cache is still a thing. ↩

Guy Blelloch published his work in 1988 and 1990, building upon work by Hillis and Steele from 1986. In doing so, he explicitly highlighted the capability of the parallel scan to efficiently compute certain first-order (and higher-order) recurrences.
To my knowledge, Eric Martin and Chris Cundy were the first to publish work connecting the Blelloch parallel scan algorithm to linear RNNs in early 2018. The publication was successful by usual metrics, appearing in ICLR and garnering 43 citations as of early 2024, but it was perhaps overshadowed at the time by the decreasing popularity of recurrent architectures following the then-recent introduction of the transformer architecture. ↩ 
Note that this is slightly different than what you will find in the OG Blelloch paper, since the firstorder recurrence referred to in that paper swaps the position of \(h_t\) and \(a_t\). All following proofs have been modified accordingly. Rest assured, they are the same in spirit, with only some very minor algebra changes. ↩

I’m a bit curious about whether the authors considered using JAX, its lax module, and the corresponding associative_scan implementation, with Pallas whenever necessary. I suppose the implementation of the Blelloch scan is not actually that many lines of code, so perhaps it just comes down to preference. ↩ 
I.e., do as much work as possible in SRAM, since transferring data between SRAM and HBM is slow. ↩