<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.9.5">Jekyll</generator><link href="https://jmschndev.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://jmschndev.github.io/" rel="alternate" type="text/html" /><updated>2024-03-10T02:13:04+00:00</updated><id>https://jmschndev.github.io/feed.xml</id><title type="html">Sparse Notes</title><author><name>James Chen</name></author><entry><title type="html">Mamba No. 5 (A Little Bit Of…)</title><link href="https://jmschndev.github.io/jekyll/update/2024/02/12/mamba.html" rel="alternate" type="text/html" title="Mamba No. 5 (A Little Bit Of…)" /><published>2024-02-12T19:50:08+00:00</published><updated>2024-02-12T19:50:08+00:00</updated><id>https://jmschndev.github.io/jekyll/update/2024/02/12/mamba</id><content type="html" xml:base="https://jmschndev.github.io/jekyll/update/2024/02/12/mamba.html"><![CDATA[<p>In this post, I attempt to provide a walkthrough of the essence of the Mamba state space model architecture, occasionally sacrificing some rigor for intuition and overall pedagogical friendliness.</p>

<p>I don’t assume readers have any familiarity with state space models, but I do assume some familiarity with machine learning and mathematical notation.</p>

<p>If at any point you spot any errors, typos, or confusing wording, please let me know!</p>

<ul>
  <li><a href="#tldr">TL;DR</a></li>
  <li><a href="#setting-the-stage">Setting the stage</a>
    <ul>
      <li><a href="#exhibit-a-the-rnn">Exhibit A: the RNN</a>
        <ul>
          <li><a href="#pros">Pros</a></li>
          <li><a href="#cons">Cons</a></li>
        </ul>
      </li>
      <li><a href="#exhibit-b-the-transformer">Exhibit B: the transformer</a>
        <ul>
          <li><a href="#pros-1">Pros</a></li>
          <li><a href="#cons-1">Cons</a></li>
        </ul>
      </li>
      <li><a href="#why-mamba-why-now">Why Mamba? Why now?</a></li>
    </ul>
  </li>
  <li><a href="#linear-time-invariant-state-space-models">Linear time-invariant state space models</a>
    <ul>
      <li><a href="#continuous-form">Continuous form</a>
        <ul>
          <li><a href="#components">Components</a></li>
        </ul>
      </li>
      <li><a href="#discrete-form">Discrete form</a>
        <ul>
          <li><a href="#parameter-discretization">Parameter discretization</a></li>
        </ul>
      </li>
      <li><a href="#structured-ssms-à-la-s4s4d"><em>Structured</em> SSMs, à la S4/S4D</a>
        <ul>
          <li><a href="#structure-for-a-ie-initialization">Structure for <strong>A</strong>, i.e. initialization</a></li>
          <li><a href="#inputs-and-shapes">Inputs and shapes</a></li>
        </ul>
      </li>
    </ul>
  </li>
  <li><a href="#mamba-ssm">Mamba SSM</a>
    <ul>
      <li><a href="#a-departure-from-s4s-linear-time-invariance">A departure from S4’s linear time invariance</a></li>
      <li><a href="#fast-implementation">Fast implementation</a></li>
    </ul>
  </li>
  <li><a href="#hardware-aware-resource-management">Hardware-aware resource management</a>
    <ul>
      <li><a href="#a-simple-gpu-program">A simple GPU program</a></li>
      <li><a href="#hardware-aware-mamba">Hardware-aware Mamba</a></li>
    </ul>
  </li>
  <li><a href="#the-blelloch-parallel-prefix-scan">The Blelloch parallel prefix scan</a>
    <ul>
      <li><a href="#warm-up-the-parallel-reduce">Warm up: the parallel reduce</a></li>
      <li><a href="#the-blelloch-parallel-scan">The Blelloch parallel scan</a>
        <ul>
          <li><a href="#first-the-up-sweep">First, the up sweep</a></li>
          <li><a href="#next-the-down-sweep">Next, the down sweep</a></li>
        </ul>
      </li>
    </ul>
  </li>
  <li><a href="#a-binary-associative-operator-for-mamba">A binary associative operator for Mamba</a>
    <ul>
      <li><a href="#prerequisites">Prerequisites</a></li>
      <li><a href="#defining-the-operator">Defining the operator</a></li>
      <li><a href="#some-proofs">Some proofs</a>
        <ul>
          <li><a href="#proof-of-part-1">Proof of part 1</a></li>
          <li><a href="#proof-of-part-2">Proof of part 2</a></li>
        </ul>
      </li>
      <li><a href="#sanity-checking">Sanity checking</a>
        <ul>
          <li><a href="#miscellaneous-setup">Miscellaneous setup</a></li>
          <li><a href="#the-operator-implementation-and-test-logic">The operator implementation and test logic</a></li>
        </ul>
      </li>
    </ul>
  </li>
  <li><a href="#closing-thoughts">Closing thoughts</a>
    <ul>
      <li><a href="#summary-of-above-topics">Summary of above topics</a></li>
      <li><a href="#remaining-topics">Remaining topics</a></li>
      <li><a href="#musings-on-breadth">Musings on breadth</a></li>
    </ul>
  </li>
  <li><a href="#assorted-references">Assorted references</a></li>
  <li><a href="#footnotes">Footnotes</a></li>
</ul>

<h2 id="tldr">TL;DR</h2>

<p>Mamba is a <em>state space model</em> (SSM) architecture that improves upon the S4 architecture. Sometimes known as <em>S6</em>, it makes two important modifications to S4:</p>

<ul>
  <li><em>Selective</em> SSM parameters</li>
  <li>Efficient implementation via parallel <em>scan</em></li>
</ul>

<p>Mamba parallelizes well during training, scales well with context length, performs inference efficiently, and most importantly, displays <a href="https://twitter.com/_albertgu/status/1731727672286294400">strong empirical results</a>.</p>

<h2 id="setting-the-stage">Setting the stage</h2>

<p>Sequence models can be placed on a spectrum based on their approach to information representation, from highly compressed (e.g. RNNs) to highly explicit (e.g. transformers).</p>

<h3 id="exhibit-a-the-rnn">Exhibit A: the RNN</h3>

<p>Consider a vanilla RNN:</p>

\[\begin{aligned}h_t &amp;= \tanh(W_{hh}h_{t-1} + W_{xh}x_t)\\y_t &amp;= W_{hy}h_t\end{aligned}\]

<p>The fixed size state \(h_{t-1}\) represents all prior context in a sequence at time \(t\). This underpins the core tradeoffs associated with RNNs:</p>
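<p>Concretely, one step of this recurrence can be sketched in a few lines of NumPy (a toy illustration with made-up shapes, not any particular library’s API):</p>

```python
import numpy as np

def rnn_step(h_prev, x_t, W_hh, W_xh, W_hy):
    """One step of the vanilla RNN recurrence above."""
    h_t = np.tanh(W_hh @ h_prev + W_xh @ x_t)  # all prior context squeezed into h_t
    y_t = W_hy @ h_t                           # output read off the fixed-size state
    return h_t, y_t

# Toy shapes: state size 4, input size 3, output size 2 (all hypothetical).
rng = np.random.default_rng(0)
W_hh = rng.normal(size=(4, 4))
W_xh = rng.normal(size=(4, 3))
W_hy = rng.normal(size=(2, 4))

h = np.zeros(4)
for x_t in rng.normal(size=(5, 3)):  # a length-5 input sequence
    h, y = rnn_step(h, x_t, W_hh, W_xh, W_hy)
```

<p>No matter how long the input sequence grows, <code class="language-plaintext highlighter-rouge">h</code> stays the same size.</p>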

<h4 id="pros">Pros</h4>

<ul>
  <li><strong>Efficient autoregressive inference:</strong> Since \(h_{t}\) encapsulates prior inputs, the model only needs to consider a small and constant set of new information for each subsequent input.</li>
  <li><strong>No limits to context length:</strong> There is nothing in the formulation that explicitly constrains the model to a maximal sequence length.</li>
</ul>

<h4 id="cons">Cons</h4>

<ul>
  <li><strong>Ineffective modeling of complex dependencies:</strong> All prior context must be compressed, via static<sup id="fnref:static" role="doc-noteref"><a href="#fn:static" class="footnote" rel="footnote">1</a></sup> updates, into a fixed amount of bits.</li>
  <li><strong>Slow training:</strong> Training requires <a href="https://en.wikipedia.org/wiki/Backpropagation_through_time">sequential backpropagation through time</a>, making poor use of hardware accelerators, e.g. GPUs or TPUs. Accelerators have enormous throughput for parallel computation, but are otherwise surprisingly slow at sequential computation.</li>
</ul>

<h3 id="exhibit-b-the-transformer">Exhibit B: the transformer</h3>

<p>On the other end of that spectrum, consider a decoder-only transformer model, à la GPT-3. In particular, let’s focus on its scaled dot-product self-attention layer:</p>

\[\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
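<p>As a minimal sketch (single head, no causal mask, hypothetical names), this layer can be written as:</p>

```python
import numpy as np

def self_attention(Q, K, V):
    """Scaled dot-product attention as in the formula above (no mask, no heads)."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # (L, L): every token vs. every token
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
L, d_k = 6, 8  # toy sequence length and head dimension
Q, K, V = (rng.normal(size=(L, d_k)) for _ in range(3))
out = self_attention(Q, K, V)
```

<p>Note that <code class="language-plaintext highlighter-rouge">scores</code> is an \(\mathtt{L} \times \mathtt{L}\) matrix: every token interacts with every other token, which is both the source of the transformer’s modeling power and its quadratic cost. A decoder-only model would additionally mask out future positions.</p>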

<h4 id="pros-1">Pros</h4>

<ul>
  <li><strong>Unreasonably effective at modeling complex dependencies:</strong> Every token gets to <em>explicitly</em> attend to all other prior tokens, instead of relying on a fixed-sized state as a “summary”.</li>
  <li><strong>Highly parallel training:</strong> There are no dependencies along the time dimension, and the core operations are matrix multiplications, which hardware accelerators have been excellent at parallelizing for decades.</li>
</ul>

<h4 id="cons-1">Cons</h4>

<ul>
  <li><strong>Quadratic scaling with context length:</strong> Since every input attends to all prior inputs, the total amount of computation required grows quadratically with the number of tokens.</li>
  <li><strong>Autoregressive inference is expensive<sup id="fnref:expensive" role="doc-noteref"><a href="#fn:expensive" class="footnote" rel="footnote">2</a></sup>:</strong> Unlike RNNs, there is no fixed-sized compressed representation of the prior tokens; each new token must explicitly attend to all prior tokens.</li>
</ul>

<h3 id="why-mamba-why-now">Why Mamba? Why now?</h3>

<p>The contrasting tradeoffs of transformers and RNNs highlight the crux of sequence modeling research: how can we improve model quality within the constraints of available compute?</p>

<p>Recently, the industry has made rapid forward progress not from algorithmic breakthroughs but instead from dramatic increases in compute, due to both increased funding and continual improvements in hardware development.</p>

<p>Which is to say, scaling isn’t particularly <em>clever</em>, but oh boy is it <em>effective</em><sup id="fnref:effective" role="doc-noteref"><a href="#fn:effective" class="footnote" rel="footnote">3</a></sup>.</p>

<p>Perhaps Rich Sutton said it best in <a href="http://www.incompleteideas.net/IncIdeas/BitterLesson.html">The Bitter Lesson</a>:</p>

<blockquote>
  <p>One thing that should be learned from the bitter lesson is the great power of general purpose methods, of methods that continue to scale with increased computation even as the available computation becomes very great.</p>
</blockquote>

<p>Speculatively, RNNs and transformers have a limited lifespan because they make poor use of increasingly abundant compute. It’s critical that we design models that better leverage compute while also maintaining or improving fundamental model quality. Hence, the interest in Mamba.</p>

<h2 id="linear-time-invariant-state-space-models">Linear time-invariant state space models</h2>

<p>Mamba is based on S4, which is a <strong><em>linear time-invariant (LTI) state space model (SSM)</em></strong>, a common and useful subset of state space models more generally. For now, let’s focus on LTI SSMs, starting with their continuous form:</p>

<h3 id="continuous-form">Continuous form</h3>

\[\begin{aligned}
\mathbf{h}'(t) &amp;= \mathbf{A} \mathbf{h}(t) + \mathbf{B}\mathbf{x}(t) \\
\mathbf{y}(t) &amp;= \mathbf{C}\mathbf{h}(t) + \mathbf{D}\mathbf{x}(t)
\end{aligned}\]

<p>Ah yes, a chunk of <a href="https://i.imgur.com/yAU72Pl.png">\(\LaTeX\)</a>, how intuitive. Let’s break it down.</p>

<h4 id="components">Components</h4>

<ul>
  <li>\(t\) represents time, and is a scalar real number, i.e. \(t \in \mathbb{R}\).
    <ul>
      <li>Although the integers are a subset of the real numbers (\(\mathbb{Z} \subset \mathbb{R}\)), there is some specific handling and notation for the discrete case. We’ll discuss this further in the next section.</li>
    </ul>
  </li>
  <li>\(\mathbf{x}(t) \in \mathbb{R}^{\mathtt{D}}\) is the input to our model at time \(t\), which has dimensionality \(\mathtt{D}\) (i.e., has \(\mathtt{D}\) channels).
    <ul>
      <li>E.g., if you are doing modeling over raw audio files, then \(\mathtt{D} = 1\), \(t \in \mathbb{R}^+\), and \(\mathbf{x}(t)\) is the amplitude from the microphone’s recording at time \(t\).</li>
      <li>E.g., if you are doing modeling over text token embeddings, then \(\mathtt{D} = \text{embedding dimensionality}\), \(t \in \mathbb{Z}^+\), and \(\mathbf{x}(t)\) is the \(\mathtt{D}\)-length embedding vector for the token at position index \(t\).</li>
    </ul>
  </li>
  <li>\(\mathbf{y}(t) \in \mathbb{R}^{\mathtt{V}}\) is the corresponding output of our model, at time \(t\).
    <ul>
      <li>E.g. for binary music/not-music classification over raw audio files: \(\mathtt{V} = 1\), \(t \in \mathbb{R}^+\), and \(\mathbf{y}(t) \in [0, 1]\) is the predicted probability that all prior audio context was music at time \(t\).</li>
      <li>E.g. for next-token language modeling: \(\mathtt{V}=\text{tokenizer vocab size}\), \(t \in \mathbb{Z}^+\), and \(\mathbf{y}(t) \in [0, 1]^\mathtt{V}\) is the predicted probability distribution over all tokens in the vocabulary.</li>
    </ul>
  </li>
  <li>Similar to a vanilla RNN, the state vector \(\mathbf{h}(t) \in \mathbb{R}^\mathtt{N}\) encapsulates all prior inputs at time \(t\).</li>
  <li>The matrices \(\mathbf{A} \in \mathbb{R}^{\texttt{N}\times\texttt{N}}\), \(\mathbf{B} \in \mathbb{R}^{\texttt{N}\times\texttt{D}}\), \(\mathbf{C} \in \mathbb{R}^{\texttt{V}\times\texttt{N}}\), and \(\mathbf{D} \in \mathbb{R}^{\texttt{V}\times\texttt{D}}\), known as the state matrix, input matrix, output matrix, and feedthrough matrix respectively, comprise the actual parameters of the SSM. These parameters, along with inputs and state, determine the outputs. These parameters also determine how the state itself evolves over the sequence of inputs.</li>
</ul>

<h3 id="discrete-form">Discrete form</h3>

<p>In the continuous case, the dynamics of how \(\mathbf{h}\) evolves over time are determined via the differential equation \(\mathbf{h}'(t) = \mathbf{A} \mathbf{h}(t) + \mathbf{B}\mathbf{x}(t)\). That is, the current value of \(\mathbf{h}\) itself determines how \(\mathbf{h}\) is changing at that moment in time.</p>

<p>The discrete case is similar, but because we <a href="https://images.squarespace-cdn.com/content/v1/57a9d8dcd482e9bbf179f445/1477647080565-2BE7N7RA4YLGNUPWWFEK/The+limit+does+not+exist.jpg?format=1500w">cannot differentiate discrete functions</a>, we notate this self-modifying recursive behavior with, well, a recurrence<sup id="fnref:recurrence" role="doc-noteref"><a href="#fn:recurrence" class="footnote" rel="footnote">4</a></sup>. We also use subscript notation instead of function notation to help emphasize this distinction:</p>

\[\begin{aligned}
\mathbf{h}_{t} &amp;= \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \mathbf{\overline{B}}\mathbf{x}_t \\
\mathbf{y}_t &amp;= \mathbf{\overline{C}}\mathbf{h}_t + \mathbf{\overline{D}}\mathbf{x}_t
\end{aligned}\]

<p>Furthermore, you may now also notice some horizontal bars over our SSM parameters, which is indicative of <strong><em>discretized</em></strong> parameters:</p>

<h4 id="parameter-discretization">Parameter discretization</h4>

<p>Because SSMs’ most general formulation is continuous, when working with discrete data there is (usually) a discretization step, where the “discretized” parameters are annotated with an overline and depend on a <strong>learned</strong> “step size” parameter \(\Delta\) and <strong>fixed</strong> choices for discretization functions \(f_\mathbf{A}\), \(f_\mathbf{B}\), \(f_\mathbf{C}\), and \(f_\mathbf{D}\).</p>

\[\begin{aligned}
\mathbf{\overline{A}} &amp;= f_{\mathbf{A}}(\Delta, \mathbf{A}) \\
\mathbf{\overline{B}} &amp;= f_{\mathbf{B}}(\Delta, \mathbf{A}, \mathbf{B})  \\
\mathbf{\overline{C}} &amp;= f_{\mathbf{C}}(\Delta, \mathbf{C}) \\
\mathbf{\overline{D}} &amp;= f_{\mathbf{D}}(\Delta, \mathbf{D}) \\
\end{aligned}\]

<p>There are several reasonable strategies for discretization, but an intuitive exposition here would be longer than I’d like for this post, so for now you’ll have to just trust me (or go read Albert Gu’s <a href="https://stacks.stanford.edu/file/druid:mb976vf9362/gu_dissertation-augmented.pdf">330 page thesis</a>).</p>
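<p>That said, one common recipe, the zero-order-hold (ZOH) method, is compact enough to sketch for the diagonal-\(\mathbf{A}\) case discussed below. This is my own toy code, not the reference implementation:</p>

```python
import numpy as np

def zoh_discretize(delta, A_diag, B):
    """Zero-order-hold (ZOH) discretization, assuming a diagonal state matrix.

    With A diagonal (stored as a length-N vector), the matrix exponential
    reduces to an elementwise exp:
        A_bar = exp(delta * A)
        B_bar = (delta * A)^{-1} (exp(delta * A) - I) * delta * B
              = ((exp(delta * A) - 1) / A) * B   # the deltas cancel
    """
    A_bar = np.exp(delta * A_diag)
    B_bar = ((A_bar - 1.0) / A_diag)[:, None] * B
    return A_bar, B_bar

A_diag = -np.arange(1.0, 5.0)  # negative real parts keep exp(delta * A) stable (< 1)
B = np.ones((4, 1))
A_bar, B_bar = zoh_discretize(0.1, A_diag, B)
```

<p>Intuitively, \(\Delta\) controls how much continuous “time” elapses between discrete steps: as \(\Delta \to 0\), \(\mathbf{\overline{A}} \to \mathbf{I}\) and the state barely changes per step.</p>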

<p><strong>Discretization fine print and personal musings</strong></p>

<ul>
  <li>The authors prefer to drop \(\mathbf{\overline{D}}\) entirely from their exposition of SSMs, because it “<a href="https://stacks.stanford.edu/file/druid:mb976vf9362/gu_dissertation-augmented.pdf">can be viewed as a skip connection that does not interact with the state…, the most important part of the SSM.</a>” To be clear, it is still explicitly parametrized in their actual reference <a href="https://github.com/state-spaces/s4/blob/a246043077dbff8563a6b172426443dced9a9d96/models/s4/s4.py#L1676">S4</a> and <a href="https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/modules/mamba_simple.py#L114">Mamba</a> implementations, so IMHO<sup id="fnref:imho" role="doc-noteref"><a href="#fn:imho" class="footnote" rel="footnote">5</a></sup> it’s a bit more clear to leave it in the exposition even at the cost of verbosity.</li>
  <li>\(\mathbf{C}\) is not actually discretized per se, because \(\mathbf{A}\) and \(\mathbf{B}\) are already discretized, so in the Mamba paper there is no \(\mathbf{\overline{C}}\). Equivalently, in the S4 paper, \(\mathbf{C}\) is in fact discretized, where its discretization function is the identity. I personally prefer the notation of the latter, but it’s not really a big deal.</li>
  <li>There are at least a few valid methods of discretization: the Euler method, the zero-order-hold (ZOH) method, and the bilinear method. The Euler method is the weakest, but choosing between the latter two is nuanced. In fact, the S4 paper goes with the bilinear method, but Mamba highlights ZOH instead.
    <ul>
      <li>I am no mathematician, but I find the setup of predefined discretization functions to be a touch surprising, because it goes against my personal deep-learning intuition of “let the model learn everything”. I’m curious how well learned discretization <strong><em>functions</em></strong> would perform, in addition to simply a learned step size parameter.</li>
    </ul>
  </li>
  <li>Whether or not discretization is necessary at all is also an interesting question. There are various nice properties and interpretations, such as viewing \(\mathbf{x}_t\) as being sampled from continuous data, i.e. \(\mathbf{x}_t = \mathbf{x}(t\Delta)\). Yet, the authors also mention how other SSM research, e.g. <a href="https://arxiv.org/abs/2303.09489">Zhang et al., 2023</a>, <em>Effectively Modeling Time Series with Simple<sup id="fnref:simple" role="doc-noteref"><a href="#fn:simple" class="footnote" rel="footnote">6</a></sup> Discrete State Spaces</em>, did not explicitly discretize and still achieved good results.</li>
</ul>

<h3 id="structured-ssms-à-la-s4s4d"><em>Structured</em> SSMs, à la S4/S4D</h3>

<h4 id="structure-for-a-ie-initialization">Structure for <strong>A</strong>, i.e. initialization</h4>

<p>S4 (or rather, <a href="https://arxiv.org/pdf/2206.11893.pdf">S4D</a>) goes a step further and parametrizes \(\mathbf{A}\) as a diagonal matrix, initialized via <a href="https://arxiv.org/pdf/2206.12037.pdf">very clever approximations of the HiPPO matrix</a><sup id="fnref:hippo" role="doc-noteref"><a href="#fn:hippo" class="footnote" rel="footnote">7</a></sup> used in the original S4 paper. Some high level pointers:</p>

<ul>
  <li>HiPPO initialization was required to get good performance from the S4 architecture. A random initialization with S4 produced a model with middling performance.</li>
  <li>Matrix-vector multiplications are, of course, much faster/cheaper if your matrix is diagonal.</li>
  <li>S4D style initialization continues to work well with Mamba, but now random initialization also performs pretty well!</li>
</ul>

<h4 id="inputs-and-shapes">Inputs and shapes</h4>

<p>Earlier, we described LTI SSMs as being able to handle inputs with multiple channels. In particular, we parametrized the SSM matrices as \(\mathbf{A} \in \mathbb{R}^{\texttt{N}\times\texttt{N}}\), \(\mathbf{B} \in \mathbb{R}^{\texttt{N}\times\texttt{D}}\), \(\mathbf{C} \in \mathbb{R}^{\texttt{V}\times\texttt{N}}\), and \(\mathbf{D} \in \mathbb{R}^{\texttt{V}\times\texttt{D}}\), where</p>

<ul>
  <li>\(\mathtt{N}\) is the state space dimension size<sup id="fnref:size" role="doc-noteref"><a href="#fn:size" class="footnote" rel="footnote">8</a></sup></li>
  <li>\(\mathtt{D}\) is the number of input channels</li>
  <li>\(\mathtt{V}\) is the number of output channels</li>
</ul>

<p>However for S4 (and Mamba), the SSM is parametrized like so:</p>

<blockquote>
  <p>In this case, the \(\mathbf{A} \in \mathbb{R}^{\mathtt{N} \times \mathtt{N}}, \mathbf{B} \in \mathbb{R}^{\mathtt{N}\times 1}, \mathbf{C} \in \mathbb{R}^{1\times\mathtt{N}}\) matrices can all be represented by \(\mathtt{N}\) numbers. To operate over an input sequence \(\mathbf{x}\) of batch size \(\mathtt{B}\) and length \(\mathtt{L}\) with \(\mathtt{D}\) channels, the SSM is applied independently to each channel.</p>
</blockquote>
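<p>A toy sketch of that per-channel application (hypothetical names and shapes, with the \(\mathbf{\overline{D}}\) skip term dropped for brevity):</p>

```python
import numpy as np

def ssm_per_channel(A_bar, B_bar, C, x):
    """Apply a diagonal SSM independently to each of D input channels (sketch).

    x: (L, D) input sequence; A_bar, B_bar: (D, N) per-channel diagonal
    dynamics; C: (D, N) per-channel readout. Returns y: (L, D).
    """
    L, D = x.shape
    N = A_bar.shape[-1]
    h = np.zeros((D, N))                       # one length-N state per channel
    y = np.empty((L, D))
    for t in range(L):
        h = A_bar * h + B_bar * x[t][:, None]  # elementwise: channels never mix
        y[t] = (C * h).sum(-1)                 # per-channel scalar readout
    return y

x = np.array([[1.0, 2.0], [3.0, 4.0]])  # L=2 steps, D=2 channels
y = ssm_per_channel(np.zeros((2, 3)), np.ones((2, 3)), np.ones((2, 3)), x)
```

<p>Because every operation here is elementwise over the channel dimension, the \(\mathtt{D}\) independent SSMs batch together trivially.</p>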

<p><strong>Structure fine print and personal musings</strong></p>

<ul>
  <li>You can confirm this parametrization by inspecting the authors’ reference Mamba implementation <a href="https://github.com/state-spaces/mamba/blob/2a3704fd47ba817b415627b06fd796b971fdc137/mamba_ssm/modules/mamba_simple.py#L104">here</a> (or their <a href="https://github.com/state-spaces/s4/blob/main/models/s4/s4.py#L596">S4</a> / <a href="https://github.com/state-spaces/s4/blob/a246043077dbff8563a6b172426443dced9a9d96/models/s4/s4d.py#L27">S4D</a> implementations) which simply repeats the same instantiation across all channel dimensions.</li>
  <li>I’m curious how the models perform if one does not restrict to this parametrization. As in, the Mamba paper assumes that \(\mathtt{N}\) is an expansion factor for a single channel, but how would things behave if all channels were handled together? This is less meaningful if we’re doing an S4D-Real style of initialization, but how about with random initialization?</li>
</ul>

<h2 id="mamba-ssm">Mamba SSM</h2>

<p>To recap, we discussed</p>

<ol>
  <li><strong><em>LTI</em></strong> SSMs in their more general continuous form</li>
  <li><strong><em>Discrete</em></strong> LTI SSMs</li>
  <li>Discrete LTI <strong><em>structured</em></strong> SSMs (à la S4)</li>
</ol>

<p>The only actual model architecture change Mamba makes is the removal of linear time invariance for \(\mathbf{B}\), \(\mathbf{C}\), and \(\Delta\). They are now <em>functions</em> of \(\mathbf{x}_t\), i.e. the parameters are “selective”.</p>

\[\begin{aligned}
\mathbf{h}_{t} &amp;= \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t \\
\mathbf{y}_t &amp;= \overline{\mathbf{C}(\mathbf{x}_t)}\mathbf{h}_t + \mathbf{\overline{D}}\mathbf{x}_t
\end{aligned}\]

<h3 id="a-departure-from-s4s-linear-time-invariance">A departure from S4’s linear time invariance</h3>

<p>Linear time invariance was critical to S4 because it provided the foundation for its efficient implementation. The sequential (slow) implementation is the obvious one: just use a for-loop, applying the parameters to each input one at a time.</p>
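<p>For reference, that for-loop implementation might look like this (a toy single-channel sketch of the discrete recurrence from earlier, not the reference code):</p>

```python
import numpy as np

def ssm_sequential(A_bar, B_bar, C_bar, D_bar, xs):
    """Discrete SSM recurrence, one input at a time (the slow path)."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x in xs:
        h = A_bar @ h + B_bar @ x         # state update
        ys.append(C_bar @ h + D_bar @ x)  # output readout
    return np.stack(ys)

# Toy 1-d SSM: an input impulse decays geometrically through the state.
A_bar = np.array([[0.5]])
B_bar = C_bar = np.array([[1.0]])
D_bar = np.array([[0.0]])
ys = ssm_sequential(A_bar, B_bar, C_bar, D_bar, np.array([[1.0], [0.0], [0.0]]))
```
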

<p>The parallel (fast) implementation is more complex. In brief, because the updates from the SSM matrices are time-invariant, you can compute the outputs by constructing a particular kernel and then performing a full-width convolution. Since full-width convolutions can be computed quickly with the <a href="https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_ch18.pdf">FFT trick</a>, this ends up being a very efficient implementation.</p>
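<p>A toy sketch of the idea for a single-channel SSM follows. Caveats: the kernel is materialized naively here, whereas actual S4 computes it far more cleverly, and the \(\mathbf{\overline{D}}\) skip term is dropped:</p>

```python
import numpy as np

def ssm_conv(A_bar, B_bar, C_bar, xs):
    """LTI SSM as a causal convolution (toy sketch; kernel built naively)."""
    L = len(xs)
    K = np.empty(L)
    v = B_bar.astype(float)
    for t in range(L):  # K_t = C_bar @ A_bar^t @ B_bar
        K[t] = C_bar @ v
        v = A_bar @ v
    n = 2 * L           # pad to avoid circular wrap-around
    return np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(xs, n), n)[:L]

# Same toy 1-d SSM as the sequential version: the impulse response is the kernel.
A_bar = np.array([[0.5]])
B_bar = np.array([1.0])
C_bar = np.array([1.0])
ys = ssm_conv(A_bar, B_bar, C_bar, np.array([1.0, 0.0, 0.0, 0.0]))
```

<p>The key point: because the parameters never change over time, the entire output sequence is one fixed kernel convolved with the input, and the FFT computes that in \(O(\mathtt{L} \log \mathtt{L})\).</p>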

<p>However, LTI caps the expressivity of the model: every update to the state is handled identically, no matter what needs to be updated or whether the current input is even relevant. With linear time <em>variance</em>, instead of learning fixed matrix parameters, Mamba learns <em>functions</em> which ingest the input and state at time \(t\) and <em>then</em> control the output dynamics.</p>

<p>This idea of not naively applying the same simple state update to every input is also why gated RNNs (e.g. LSTM, GRU) are much more effective than vanilla RNNs, and indeed there is a strong theoretical connection between gating heuristics and selectivity in Mamba.</p>

<h3 id="fast-implementation">Fast implementation</h3>

<p>Since Mamba is not a linear time-invariant SSM, we can no longer rely on a full-width convolution as the basis of a fast (i.e. parallel) implementation. We need other strategies; otherwise we are still stuck with slow training and limited practical utility.</p>

<p>The most critical techniques the authors employ are the <em>Blelloch work-efficient<sup id="fnref:work-efficient" role="doc-noteref"><a href="#fn:work-efficient" class="footnote" rel="footnote">9</a></sup> parallel scan</em> and hardware-aware memory management, which together facilitate the fast (in real world wall-clock time) training of Mamba.</p>

<p><strong>Pedagogical fine print and personal musings</strong></p>

<ul>
  <li>The above notation is my own, but I’ve tried to clarify some things, e.g.
    <ul>
      <li>\(\mathbf{A}\) and \(\mathbf{D}\) remain linear time-invariant, although I am curious about ablations that also parametrize those two parameters.</li>
      <li>\(\mathbf{B}\) and \(\mathbf{C}\) are parametrized by \(\mathbf{x}_t\), and <strong><em>then</em></strong> discretized.</li>
    </ul>
  </li>
  <li>All pedagogically friendly implementations of Mamba I was able to find actually omit the parallel scan. I’m personally conflicted about the utility of this, because it does simplify the implementation, but <strong><em>how</em></strong> you can implement something like this efficiently is half of the reason why Mamba is valuable in the first place.</li>
</ul>

<h2 id="hardware-aware-resource-management">Hardware-aware resource management</h2>

<p>In most of computer science, we lean pretty hard on a computation model that hand waves away a lot of real world performance characteristics in favor of “constant time”. The random access machine is one popular model, but far from the only one. In any case, these models are akin to idealized models in an introductory physics class: they are exceedingly useful, but they also lie.</p>

<p>The real world is messier, with different hierarchies of performance in compute, memory, and bandwidth, all thanks to the very real implications of squeezing transistors on a silicon wafer of limited size.</p>

<h3 id="a-simple-gpu-program">A simple GPU program</h3>

<p>Most GPU programs follow the same recipe:</p>

<ol>
  <li>Load relevant data from host DRAM to GPU HBM.</li>
  <li>Load it into GPU SRAM.</li>
  <li>Perform desired computation.<sup id="fnref:kernel" role="doc-noteref"><a href="#fn:kernel" class="footnote" rel="footnote">10</a></sup></li>
  <li>Load it back into GPU HBM.</li>
  <li>Load it back into CPU DRAM.</li>
</ol>

<p>Loading is <em>by far</em> the slowest part of this process, or at least for most neural network applications. The more you can minimize memory loading the better, even if you need to spend some compute to do so.<sup id="fnref:kvcache" role="doc-noteref"><a href="#fn:kvcache" class="footnote" rel="footnote">11</a></sup></p>

<p>One of the Mamba authors, Tri Dao, is well-known for his work on FlashAttention, which introduced hardware-aware computation of self-attention, bringing memory requirements from quadratic to linear and also providing dramatic wall-clock time speedups.</p>

<p>The core realization with FlashAttention was that the size of the intermediate computations dwarfed the actual size of the inputs and outputs. I.e., the \(QK^T\) matrix has size \(O(\mathtt{L}^2)\), even though the inputs themselves are only \(O(\mathtt{L})\).</p>

<h3 id="hardware-aware-mamba">Hardware-aware Mamba</h3>

<p>The same principle is applicable here. The intermediate computations (the actual state machine mechanics) are again larger than the inputs and outputs, i.e. \(O(\mathtt{BLDN}) &gt; O(\mathtt{BLD} + \mathtt{DN})\). Thus a similar approach works well here, where computations are done in a blockwise fashion, maximizing the amount of computation that can occur in SRAM before needing to load to/from HBM.</p>

<p>If this still feels a bit hand-wavey, then please go read Horace He’s excellent blog post <a href="https://horace.io/brrr_intro.html">Making Deep Learning Go Brrrr From First Principles</a>, which is perhaps the best technical blog post I’ve read in the past few years.</p>

<p>And if you’re feeling brave, then I’d encourage you to dive into the core implementation in the authors’ reference implementation, i.e. <a href="https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/csrc/selective_scan/selective_scan_fwd_kernel.cuh#L82"><code class="language-plaintext highlighter-rouge">selective_scan_fwd_kernel.cuh</code></a>. In particular, pay attention to anything with <code class="language-plaintext highlighter-rouge">smem</code> or <code class="language-plaintext highlighter-rouge">shared</code>.</p>

<h2 id="the-blelloch-parallel-prefix-scan">The Blelloch parallel prefix scan</h2>

<p>Of course, a very memory-efficient implementation isn’t particularly useful if you can’t compute it in parallel, thanks to how modern hardware accelerators are built. Hence the desire for parallelism.</p>

<p>To be clear, parallel computation of linear RNNs is not something the Mamba authors invented, nor is it even particularly recent<sup id="fnref:recent" role="doc-noteref"><a href="#fn:recent" class="footnote" rel="footnote">12</a></sup>. But it’s critical to an efficient implementation for Mamba, hence discussing it here. At a high level, there are two key pieces to gaining intuition here for Mamba:</p>

<ol>
  <li>Understanding how the Blelloch parallel prefix scan algorithm works for simple binary associative operators, e.g. summation.</li>
  <li>Understanding how to represent the core bits of Mamba, or any linear RNN, as a binary operator.</li>
</ol>

<p>This section will focus on the former.</p>

<h3 id="warm-up-the-parallel-reduce">Warm up: the parallel reduce</h3>

<p>Before discussing the parallel scan, let’s first think about something simpler, the parallel reduce.</p>

<p>Suppose I have a list with \(k\) elements and I want to perform a <strong><em>reduce</em></strong> operation over some <strong><em>binary associative operator</em></strong> \(\oplus\) where:</p>

<ul>
  <li>An <strong><em>operator</em></strong> is a function that takes in multiple elements of one type and returns a single element of the same type. A <strong><em>binary</em></strong> operator takes in two elements.</li>
  <li>A binary operator is <strong><em>associative</em></strong> if \(x \oplus (y \oplus z)= (x \oplus y) \oplus z\) for all \(x,y,z\).</li>
  <li>A <strong><em>reduce</em></strong>, also sometimes known as a <strong><em>fold</em></strong>, is a function that recursively applies a binary operation to aggregate a sequence of values into a single cumulative result. E.g., \((((x_1 + x_2) + x_3) + ... + x_{k-1}) + x_k\) is the reduction of the plus operator over a list of inputs \(x_1\) to \(x_k\).</li>
</ul>

<p>One possible naive implementation is just looping through the inputs in order. For example, with inputs <code class="language-plaintext highlighter-rouge">[3, 1, 7, 0, 4, 1, 6, 3]</code>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">naive_reduce</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">identity</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
    <span class="n">result</span> <span class="o">=</span> <span class="n">identity</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
        <span class="n">result</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">result</span>

<span class="o">&gt;&gt;&gt;</span> <span class="n">naive_reduce</span><span class="p">(</span><span class="n">op</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">identity</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">inputs_</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
<span class="mi">25</span>
</code></pre></div></div>

<p>This has linear time complexity w.r.t. the number of inputs, which is as good as we can hope for without parallelism. But what if we had more workers? We could split our input into pairs, compute those sums in parallel, and then recursively use those results as the inputs to another reduce, forming a computation tree:</p>

<div class="mermaid">
graph BT;
    00[3] --&gt; 10[4]
    01[1] --&gt; 10[4]
    02[7] --&gt; 11[7]
    03[0] --&gt; 11[7]
    04[4] --&gt; 12[5]
    05[1] --&gt; 12[5]
    06[6] --&gt; 13[9]
    07[3] --&gt; 13[9]

  10[4] --&gt; 20[11]
  11[7] --&gt; 20[11]
  12[5] --&gt; 21[14]
  13[9] --&gt; 21[14]

  20[11] --&gt; 30[25]
  21[14] --&gt; 30[25]
</div>
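<p>To make the tree concrete, here is a plain-Python sketch of this pairwise reduction (a hypothetical helper, not from any Mamba codebase; each level’s pair-combines are independent, so with enough workers each level takes constant time and there are only logarithmically many levels):</p>

```python
def tree_reduce(op, identity, inputs):
    # Each recursive call processes one level of the computation tree.
    if not inputs:
        return identity
    if len(inputs) == 1:
        return inputs[0]
    # Combine adjacent pairs; these ops are independent, hence parallelizable.
    paired = [op(inputs[i], inputs[i + 1])
              for i in range(0, len(inputs) - 1, 2)]
    if len(inputs) % 2:  # carry any leftover element up unchanged
        paired.append(inputs[-1])
    return tree_reduce(op, identity, paired)
```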

<h3 id="the-blelloch-parallel-scan">The Blelloch parallel scan</h3>

<p>A <strong><em>scan</em></strong> aggregates results from a sequence of elements by applying the operator cumulatively. For example, given the same inputs, the scan output would be <code class="language-plaintext highlighter-rouge">[3, 4, 11, 11, 15, 16, 22, 25]</code>.</p>

<p>Closely related is the <strong><em>prescan</em></strong>, which shifts the outputs right by one, starting with the identity for the given operator (e.g., zero for summation): <code class="language-plaintext highlighter-rouge">[0, 3, 4, 11, 11, 15, 16, 22]</code>. The Blelloch algorithm technically computes a prescan, but that is sufficient for our use case since it’s easy to convert between a scan and a prescan.</p>
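<p>As a quick illustration (with hypothetical helper names), a sequential scan and the scan-to-prescan shift look like this:</p>

```python
def naive_scan(op, identity, inputs):
    # output[i] = inputs[0] op inputs[1] op ... op inputs[i]
    result, acc = [], identity
    for x in inputs:
        acc = op(acc, x)
        result.append(acc)
    return result

def naive_prescan(op, identity, inputs):
    # Shift the scan right by one, prepending the identity.
    return [identity] + naive_scan(op, identity, inputs)[:-1]
```

<p>For the running example, <code class="language-plaintext highlighter-rouge">naive_scan</code> returns <code class="language-plaintext highlighter-rouge">[3, 4, 11, 11, 15, 16, 22, 25]</code> and <code class="language-plaintext highlighter-rouge">naive_prescan</code> returns <code class="language-plaintext highlighter-rouge">[0, 3, 4, 11, 11, 15, 16, 22]</code>, matching the sequences above.</p>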

<h4 id="first-the-up-sweep">First, the up sweep</h4>

<p>The first step to a fast scan computation is to compute a parallel reduction, as we described in the previous section. However, this time, we preserve intermediate computations.</p>

<h4 id="next-the-down-sweep">Next, the down sweep</h4>

<p>In the down sweep, we will maintain the invariant that <em>every node contains the sum of all prior leaf nodes</em>, as determined by visit order in a pre-order traversal, e.g.:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">pre_order_traversal</span><span class="p">(</span><span class="n">node</span><span class="p">:</span> <span class="n">Node</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">node</span><span class="p">:</span>
        <span class="k">return</span>
    <span class="n">visit</span><span class="p">(</span><span class="n">node</span><span class="p">)</span>
    <span class="n">pre_order_traversal</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">left</span><span class="p">)</span>
    <span class="n">pre_order_traversal</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">right</span><span class="p">)</span>
</code></pre></div></div>

<p>If every node contains the sum of all prior leaf nodes, then the values of the leaves themselves will be exactly the results of a prescan!</p>

<p><strong>Stepping through the down sweep with a concrete example:</strong></p>

<p>When there are no prior leaf nodes, we use the identity value, e.g. 0 for summation.</p>

<div class="mermaid">
graph TD;

     10[?] --&gt; 00[?]
     10[?] --&gt; 01[?]
     11[?] --&gt; 02[?]
     11[?] --&gt; 03[?]
     12[?] --&gt; 04[?]
     12[?] --&gt; 05[?]
     13[?] --&gt; 06[?]
     13[?] --&gt; 07[?]

     20 --&gt; 10[?]
     20 --&gt; 11[?]
     21[?] --&gt; 12[?]
     21[?] --&gt; 13[?]

     30 --&gt; 20[?]
     30[0] --&gt; 21[?]
</div>

<p>We now need to be careful about maintaining the invariant. Filling in the down-sweep level by level, for any particular node <code class="language-plaintext highlighter-rouge">N</code>:</p>

<ul>
  <li><code>downsweep[N].left.value = downsweep[N].value</code>
    <ul>
      <li>For the following diagrams, a <span style="background-color: #5bd5ff;"><strong>blue</strong></span> node indicates the contribution from the parent.</li>
    </ul>
  </li>
  <li><code>downsweep[N].right.value = downsweep[N].value + upsweep[N].left.value</code>
    <ul>
      <li>For the following diagrams, a <span style="background-color: #ff5b5b;"><strong>red</strong></span> node indicates a contribution from the downsweep tree, and the <span style="background-color: #ffff5b;"><strong>yellow</strong></span> node indicates the contribution from the upsweep tree, and <span style="background-color: #ffad5b;"><strong>orange</strong></span> indicates the combined result.</li>
    </ul>
  </li>
</ul>

<div style="display: flex; justify-content: space-between;">
  <!-- Left Column for Odd Diagrams -->
  <div style="width: 48%; background-color: #ededed;">
    <center><b>Up sweep</b></center>
    <div class="mermaid">
      graph BT;
        classDef orange fill:#ffad5b,stroke:#333,stroke-width:2px;
        classDef yellow fill:#ffff5b,stroke:#333,stroke-width:2px;
        classDef red fill:#ff5b5b,stroke:#333,stroke-width:2px;

          00[3] --&gt; 10[4]
          01[1] --&gt; 10[4]
          02[7] --&gt; 11[7]
          03[0] --&gt; 11[7]
          04[4] --&gt; 12[5]
          05[1] --&gt; 12[5]
          06[6] --&gt; 13[9]
          07[3] --&gt; 13[9]

        10[4] --&gt; 20[11]
        11[7] --&gt; 20[11]
        12[5] --&gt; 21[14]
        13[9] --&gt; 21[14]

        20[11] --&gt; 30[25]
        21[14] --&gt; 30[25]

        class 20 yellow
    </div>

    <div class="mermaid">
      graph BT;
        classDef orange fill:#ffad5b,stroke:#333,stroke-width:2px;
        classDef yellow fill:#ffff5b,stroke:#333,stroke-width:2px;
        classDef red fill:#ff5b5b,stroke:#333,stroke-width:2px;

          00[3] --&gt; 10[4]
          01[1] --&gt; 10[4]
          02[7] --&gt; 11[7]
          03[0] --&gt; 11[7]
          04[4] --&gt; 12[5]
          05[1] --&gt; 12[5]
          06[6] --&gt; 13[9]
          07[3] --&gt; 13[9]

        10[4] --&gt; 20[11]
        11[7] --&gt; 20[11]
        12[5] --&gt; 21[14]
        13[9] --&gt; 21[14]

        20[11] --&gt; 30[25]
        21[14] --&gt; 30[25]

        class 10,12 yellow
    </div>

    <div class="mermaid">
      graph BT;
        classDef orange fill:#ffad5b,stroke:#333,stroke-width:2px;
        classDef yellow fill:#ffff5b,stroke:#333,stroke-width:2px;
        classDef red fill:#ff5b5b,stroke:#333,stroke-width:2px;

          00[3] --&gt; 10[4]
          01[1] --&gt; 10[4]
          02[7] --&gt; 11[7]
          03[0] --&gt; 11[7]
          04[4] --&gt; 12[5]
          05[1] --&gt; 12[5]
          06[6] --&gt; 13[9]
          07[3] --&gt; 13[9]

        10[4] --&gt; 20[11]
        11[7] --&gt; 20[11]
        12[5] --&gt; 21[14]
        13[9] --&gt; 21[14]

        20[11] --&gt; 30[25]
        21[14] --&gt; 30[25]

        class 00,02,04,06 yellow
    </div>
  </div>

  <!-- Right Column for Even Diagrams -->
  <div style="width: 48%; background-color: #ededed;">
    <center><b>Down sweep</b></center>
    <div class="mermaid">
      graph TD;
        classDef orange fill:#ffad5b,stroke:#333,stroke-width:2px;
        classDef yellow fill:#ffff5b,stroke:#333,stroke-width:2px;
        classDef red fill:#ff5b5b,stroke:#333,stroke-width:2px;
        classDef blue fill:#5bd5ff,stroke:#333,stroke-width:2px;

           10[?] --&gt; 00[?]
           10[?] --&gt; 01[?]
           11[?] --&gt; 02[?]
           11[?] --&gt; 03[?]
           12[?] --&gt; 04[?]
           12[?] --&gt; 05[?]
           13[?] --&gt; 06[?]
           13[?] --&gt; 07[?]

           20 --&gt; 10[?]
           20 --&gt; 11[?]
           21[?] --&gt; 12[?]
           21[?] --&gt; 13[?]

           30 --&gt; 20[0]
           30[0] --&gt; 21[11]

          class 30 red
          class 21 orange
          class 20 blue
    </div>

    <div class="mermaid">
      graph TD;
        classDef orange fill:#ffad5b,stroke:#333,stroke-width:2px;
        classDef yellow fill:#ffff5b,stroke:#333,stroke-width:2px;
        classDef red fill:#ff5b5b,stroke:#333,stroke-width:2px;
        classDef blue fill:#5bd5ff,stroke:#333,stroke-width:2px;

           10[?] --&gt; 00[?]
           10[?] --&gt; 01[?]
           11[?] --&gt; 02[?]
           11[?] --&gt; 03[?]
           12[?] --&gt; 04[?]
           12[?] --&gt; 05[?]
           13[?] --&gt; 06[?]
           13[?] --&gt; 07[?]

           20 --&gt; 10[0]
           20 --&gt; 11[4]
           21[?] --&gt; 12[11]
           21[?] --&gt; 13[16]

           30 --&gt; 20[0]
           30[0] --&gt; 21[11]

          class 20,21 red
          class 11,13 orange
          class 10,12 blue
    </div>

    <div class="mermaid">
      graph TD;
        classDef orange fill:#ffad5b,stroke:#333,stroke-width:2px;
        classDef yellow fill:#ffff5b,stroke:#333,stroke-width:2px;
        classDef red fill:#ff5b5b,stroke:#333,stroke-width:2px;
        classDef blue fill:#5bd5ff,stroke:#333,stroke-width:2px;

           10[?] --&gt; 00[0]
           10[?] --&gt; 01[3]
           11[?] --&gt; 02[4]
           11[?] --&gt; 03[11]
           12[?] --&gt; 04[11]
           12[?] --&gt; 05[15]
           13[?] --&gt; 06[16]
           13[?] --&gt; 07[22]

           20 --&gt; 10[0]
           20 --&gt; 11[4]
           21[?] --&gt; 12[11]
           21[?] --&gt; 13[16]

           30 --&gt; 20[0]
           30[0] --&gt; 21[11]

        class 10,11,12,13 red
        class 01,03,05,07 orange
        class 00,02,04,06 blue
    </div>
  </div>
</div>

<p>Voila! For a sketch of the proof of why this works, consider a node in a pre-order traversal.</p>

<ul>
  <li>It is either a left child, or a right child (or the root).</li>
  <li>If it is a left child, then it and its parent have the same number of prior leaf nodes, so it can get its value just by copying its parent’s value.</li>
  <li>If it is a right child, then its prior leaves come from two possible regions:
    <ul>
      <li>all the leaves under its left sibling, i.e. those visited after the parent but before the right child (the sibling’s up-sweep value)</li>
      <li>all leaves that come entirely “before” the parent node (the parent’s own down-sweep value)</li>
    </ul>
  </li>
</ul>
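<p>Putting the up sweep and the two down-sweep rules together, here is a minimal array-based sketch of the whole Blelloch prescan (hypothetical code, assuming the input length is a power of two and the operator is associative):</p>

```python
def blelloch_prescan(op, identity, inputs):
    a = list(inputs)
    n = len(a)  # assumed to be a power of two

    # Up sweep: a parallel reduction, keeping the intermediate sums in place.
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):  # parallelizable level
            a[i] = op(a[i - step], a[i])
        step *= 2

    # Down sweep: the root has no prior leaves, so it gets the identity.
    a[n - 1] = identity
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):  # parallelizable level
            left_sum = a[i - step]       # up-sweep value of the left child
            parent = a[i]                # down-sweep value of this node
            a[i - step] = parent         # left child copies its parent
            a[i] = op(parent, left_sum)  # right child: parent, then left subtree
        step //= 2
    return a
```

<p>On the running example <code class="language-plaintext highlighter-rouge">[3, 1, 7, 0, 4, 1, 6, 3]</code> with addition, this returns <code class="language-plaintext highlighter-rouge">[0, 3, 4, 11, 11, 15, 16, 22]</code>, the prescan from earlier.</p>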

<h2 id="a-binary-associative-operator-for-mamba">A binary associative operator for Mamba</h2>

<p>The above examples used “sum”, but the Blelloch parallel scan works for any binary associative operator. For Mamba, we can define an associative operator that, when given the proper inputs, will compute the prescan of state vectors in parallel!</p>

<h3 id="prerequisites">Prerequisites</h3>

<p>What we’ll discuss here is contextualized by Mamba, but it is actually valid for any first-order recurrence of the following form<sup id="fnref:form" role="doc-noteref"><a href="#fn:form" class="footnote" rel="footnote">13</a></sup>:</p>

\[h_t= \begin{cases}
      b_0 &amp; t = 0 \\
      (a_t\otimes h_{t-1}) \oplus b_t &amp; t&gt;0 \\
   \end{cases}\]

<p>Where \(\oplus\) and \(\otimes\) meet the following criteria:</p>

<ul>
  <li>\(\oplus\) must be associative, i.e. \((x \oplus y) \oplus z = x \oplus (y \oplus z)\)
    <ul>
      <li>Notice that vector-vector addition satisfies this!</li>
    </ul>
  </li>
  <li>\(\otimes\) must be semiassociative, i.e. there exists a binary associative operator \(\odot\) such that \(x \otimes (y \otimes z) = (x \odot y) \otimes z\)
    <ul>
      <li>Notice that \(\odot\) as matrix-matrix multiplication and \(\otimes\) as matrix-vector multiplication satisfies this!</li>
    </ul>
  </li>
  <li>\(\otimes\) distributes over \(\oplus\): \(x \otimes (y \oplus z) = (x \otimes y) \oplus (x \otimes z)\)
    <ul>
      <li>Notice that above matrix/vector addition/multiplication operators satisfy this!</li>
    </ul>
  </li>
</ul>

<h3 id="defining-the-operator">Defining the operator</h3>

<p>We massage our inputs into the sequence of \(c_t \equiv [a_t, b_t] \text{ for } t = 1, 2, \ldots, \mathtt{L}\), where</p>

\[\begin{aligned}
a_t &amp;= \mathbf{\overline{A}} \\
b_t &amp;= \overline{\mathbf{B}(\mathbf{x}_t)}\,\mathbf{x}_t \\
\end{aligned}\]

<p>And define a new operator \(\bullet\) as follows:</p>

\[\begin{aligned}
c_i \bullet c_j &amp;\equiv [c_{j,a} c_{i,a}, \,  c_{j,a} c_{i,b}  + c_{j,b}] \\
&amp;\equiv [a_j a_i, \,  a_j b_i  + b_j]

\end{aligned}\]

<p>The “identity” for this operator is \(c_0 = \left[\mathbf{\overline{A}}, \, \mathbf{h}_0 \right]\), where \(\mathbf{h}_0\) is the initial state vector, before having seen any inputs. This is analogous to \(0\) for the \(\text{add}\) operator.</p>

<p>We then apply the operator in parallel with the Blelloch (pre)scan algorithm, and the outputs at the second index will be the desired \(\mathbf{h}_t\) results, computed in an efficient manner!</p>
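<p>Before the full tensor version below, here is a tiny scalar sketch (hypothetical code; the real model uses matrices for \(a_t\) and vectors for \(b_t\), but scalars show the mechanics) checking that folding with the operator reproduces the recurrence \(h_t = a_t h_{t-1} + b_t\):</p>

```python
from functools import reduce

def bullet(c_i, c_j):
    # c = (a, b); c_i • c_j = (a_j * a_i, a_j * b_i + b_j)
    a_i, b_i = c_i
    a_j, b_j = c_j
    return (a_j * a_i, a_j * b_i + b_j)

a_s = [0.5, 2.0, 1.5]
b_s = [1.0, 3.0, 2.0]

# Sequential recurrence h_t = a_t * h_{t-1} + b_t, starting from h_0 = 0.
h = 0.0
for a, b in zip(a_s, b_s):
    h = a * h + b

# Folding left-to-right from the identity-like element (1, h_0) agrees.
c = reduce(bullet, zip(a_s, b_s), (1.0, 0.0))
assert c[1] == h  # both equal 9.5
```

<p>Because the operator is associative, this left-to-right fold can be regrouped into a balanced tree, which is exactly what the Blelloch scan exploits.</p>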

<h3 id="some-proofs">Some proofs</h3>

<p>From the definition of the operator and how we set up \(a_t\) and \(b_t\), the second component is essentially guaranteed to compute the proper state vectors. Recall the first line of the Mamba SSM:</p>

\[\mathbf{h}_{t} = \mathbf{\overline{A}} \mathbf{h}_{t-1}+ \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t\]

<p>A sketch of the high level intuition here is:</p>

<ol>
  <li>We can show that the \(\bullet\) operator with our specified initializations computes \(\mathbf{h}_t\) with a sequential scan.</li>
  <li>We can show that the \(\bullet\) operator is <strong><em>associative</em></strong>, so that we may use the <strong><em>Blelloch parallel scan</em></strong> algorithm in particular.</li>
</ol>

<h4 id="proof-of-part-1">Proof of part 1</h4>

<ol>
  <li>We initialize \(b_0 = \mathbf{h}_0\).</li>
  <li>For time \(t \ge 1\), if \(b_{t-1}\) is equal to \(\mathbf{h}_{t-1}\), then  \(a_t b_{t-1}  + b_t = \mathbf{h}_t\). This is because we have set \(a_t = \mathbf{\overline{A}}\) and \(b_t = \overline{\mathbf{B}(\mathbf{x}_t)}\mathbf{x}_t\)  for all \(t\).</li>
  <li>Via induction, since \(b_0 = \mathbf{h}_0\) is initialized correctly and subsequent \(b_t = \mathbf{h}_t\) are computed correctly if provided \(b_{t-1} = \mathbf{h}_{t-1}\), then scanning with this operator and these initializations correctly computes the desired state vectors.</li>
</ol>

<h4 id="proof-of-part-2">Proof of part 2</h4>

<p>This, unfortunately, is simply a wall of algebra:</p>

\[\begin{aligned}
&amp;\text{Apply the definition of } \bullet \text{:} \\
(c_i \bullet c_j) \bullet c_k &amp;= [c_{j,a} \odot c_{i,a}, \; (c_{j,a} \otimes c_{i,b}) \oplus c_{j,b}] \bullet c_k \\

&amp;\text{Apply the definition of } \bullet \text{ again:} \\
&amp;= [ c_{k,a} \odot (c_{j,a} \odot c_{i,a} ) , \;   (c_{k,a} \otimes ((c_{j,a} \otimes c_{i,b}) \oplus c_{j,b})) \oplus c_{k,b}] \\

&amp;\text{Associativity of } \odot \text{:} \\
&amp;= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \;   (c_{k,a} \otimes ((c_{j,a} \otimes c_{i,b}) \oplus c_{j,b})) \oplus c_{k,b}] \\

&amp;\otimes \text{ distributes } c_{k,a} \text{ over } \oplus \text{:} \\
&amp;= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \;   ((c_{k,a} \otimes (c_{j,a} \otimes c_{i,b})) \oplus (c_{k,a} \otimes c_{j,b})) \oplus c_{k,b}] \\

&amp;\text{Associativity of } \oplus \text{:} \\
&amp;= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \;   (c_{k,a} \otimes (c_{j,a} \otimes c_{i,b})) \oplus ((c_{k,a} \otimes c_{j,b}) \oplus c_{k,b})] \\

&amp;\text{Semiassociativity of } \otimes \text{:} \\
&amp;= [(c_{k,a} \odot c_{j,a}) \odot c_{i,a}, \;   ((c_{k,a} \odot c_{j,a}) \otimes c_{i,b}) \oplus ((c_{k,a} \otimes c_{j,b}) \oplus c_{k,b})] \\

&amp;\text{Apply the operator definition:} \\
&amp;= c_i \bullet [c_{k,a} \odot c_{j,a}, \;  ( c_{k,a} \otimes c_{j,b} ) \oplus c_{k,b}] \\

&amp;\text{Apply the operator definition again:} \\
(c_i \bullet c_j) \bullet c_k &amp;= c_i \bullet (c_j \bullet c_k)
\end{aligned}\]

<h3 id="sanity-checking">Sanity checking</h3>

<p>The above seems to make sense, but perhaps you prefer <code class="language-plaintext highlighter-rouge">python</code> to \(\LaTeX\) (I wouldn’t blame you).</p>

<p>As a first sanity check, if you’ve gotten this far, now is perhaps a good time to start perusing some reference implementations. E.g., in the author’s CUDA code, you can see the above operator implemented quite literally at <a href="https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/csrc/selective_scan/selective_scan_common.h#L113"><code class="language-plaintext highlighter-rouge">selective_scan_common.h:113</code></a>.</p>

<p>But also, let’s write our own “unit tests” as an additional sanity check. FYI, this will leverage <code class="language-plaintext highlighter-rouge">jax.lax.associative_scan</code>, a batteries-included implementation<sup id="fnref:implementation" role="doc-noteref"><a href="#fn:implementation" class="footnote" rel="footnote">14</a></sup> of Blelloch’s algorithm.</p>

<h4 id="miscellaneous-setup">Miscellaneous setup</h4>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Various imports
</span><span class="kn">from</span> <span class="nn">einops</span> <span class="kn">import</span> <span class="n">einsum</span>
<span class="kn">import</span> <span class="nn">jax</span>
<span class="c1"># Jax.lax already has a convenient parallel scan implementation.
</span><span class="kn">import</span> <span class="nn">jax.lax</span> <span class="k">as</span> <span class="n">lax</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">jnp</span>

<span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>

<span class="n">B</span> <span class="o">=</span> <span class="mi">1</span>  <span class="c1"># batch size
</span><span class="n">L</span> <span class="o">=</span> <span class="mi">8192</span> <span class="c1"># context length
</span><span class="n">N</span> <span class="o">=</span> <span class="mi">64</span>  <span class="c1"># hidden state size
</span><span class="n">D</span> <span class="o">=</span> <span class="mi">2</span>  <span class="c1"># num in channels
</span><span class="n">V</span> <span class="o">=</span> <span class="mi">1</span>  <span class="c1"># num out channels
</span>
<span class="c1"># Gets the various fake x_t inputs.
</span><span class="k">def</span> <span class="nf">generate_random_xs</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">num_inputs</span><span class="o">=</span><span class="n">L</span><span class="p">,</span> <span class="n">num_channels</span><span class="o">=</span><span class="n">D</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
    <span class="n">xs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">lognormal</span><span class="p">(</span><span class="n">subkey</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_inputs</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">xs</span>

<span class="c1"># Gets various fake A matrices. This is actually constant in the paper,
# but it doesn't have to be.
</span><span class="k">def</span> <span class="nf">generate_random_As</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">num_inputs</span><span class="o">=</span><span class="n">L</span><span class="p">,</span> <span class="n">state_size</span><span class="o">=</span><span class="n">N</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
    <span class="n">As</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">lognormal</span><span class="p">(</span><span class="n">subkey</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_inputs</span><span class="p">,</span> <span class="n">state_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">As</span>

<span class="c1"># Gets various fake B(x_t) matrices.
</span><span class="k">def</span> <span class="nf">generate_random_Bxs</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">num_inputs</span><span class="o">=</span><span class="n">L</span><span class="p">,</span> <span class="n">state_size</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">num_channels</span><span class="o">=</span><span class="n">D</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
    <span class="n">Bxs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">lognormal</span><span class="p">(</span><span class="n">subkey</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_inputs</span><span class="p">,</span> <span class="n">state_size</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">Bxs</span>

<span class="c1"># Gets the b_t term.
</span><span class="k">def</span> <span class="nf">get_bs</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">Bxs</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">einsum</span><span class="p">(</span><span class="n">Bxs</span><span class="p">,</span> <span class="n">xs</span><span class="p">,</span> <span class="s">"l n d, l d -&gt; l n"</span><span class="p">)</span>

<span class="c1"># Jax plays nicest with jnp.arrays, so we'll stuff the values inside a
# single array and just unpack things here. I suppose I could use PyTrees
# but please forgive a bit of laziness/hackiness on my part.
</span><span class="k">def</span> <span class="nf">extract</span><span class="p">(</span><span class="n">c</span><span class="p">,</span> <span class="n">state_size</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">c</span><span class="p">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span>
    <span class="k">assert</span> <span class="n">c</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">state_size</span> <span class="o">*</span> <span class="n">state_size</span> <span class="o">+</span> <span class="n">state_size</span>
    <span class="k">return</span> <span class="p">(</span>
        <span class="n">c</span><span class="p">[:</span><span class="n">state_size</span> <span class="o">*</span> <span class="n">state_size</span><span class="p">].</span><span class="n">reshape</span><span class="p">((</span><span class="n">state_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)),</span>
        <span class="n">c</span><span class="p">[</span><span class="o">-</span><span class="n">state_size</span><span class="p">:].</span><span class="n">reshape</span><span class="p">((</span><span class="n">state_size</span><span class="p">,))</span>
    <span class="p">)</span>
</code></pre></div></div>

<h4 id="the-operator-implementation-and-test-logic">The operator implementation and test logic</h4>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">operator</span><span class="p">(</span><span class="n">c_prev</span><span class="p">,</span> <span class="n">c_curr</span><span class="p">,</span> <span class="n">num_inputs</span><span class="o">=</span><span class="n">L</span><span class="p">,</span> <span class="n">state_size</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">num_channels</span><span class="o">=</span><span class="n">D</span><span class="p">):</span>
    <span class="n">prev_a</span><span class="p">,</span> <span class="n">prev_b</span> <span class="o">=</span> <span class="n">extract</span><span class="p">(</span><span class="n">c_prev</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
    <span class="n">curr_a</span><span class="p">,</span> <span class="n">curr_b</span> <span class="o">=</span> <span class="n">extract</span><span class="p">(</span><span class="n">c_curr</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span>
        <span class="n">jnp</span><span class="p">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">curr_a</span> <span class="o">@</span> <span class="n">prev_a</span><span class="p">),</span> 
        <span class="n">jnp</span><span class="p">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">curr_a</span> <span class="o">@</span> <span class="n">prev_b</span> <span class="o">+</span> <span class="n">curr_b</span><span class="p">)</span>
    <span class="p">])</span>
<span class="n">vectorized_operator</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">operator</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">out_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

<span class="c1"># Actually generate some fake test data.
</span><span class="n">key</span><span class="p">,</span> <span class="n">xs</span> <span class="o">=</span> <span class="n">generate_random_xs</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="n">key</span><span class="p">,</span> <span class="n">Bxs</span> <span class="o">=</span> <span class="n">generate_random_Bxs</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="n">key</span><span class="p">,</span> <span class="n">As</span> <span class="o">=</span> <span class="n">generate_random_As</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>

<span class="n">bs</span> <span class="o">=</span> <span class="n">get_bs</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">Bxs</span><span class="p">)</span>
<span class="n">cs</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">As</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span> <span class="o">*</span> <span class="n">N</span><span class="p">),</span> <span class="n">bs</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># %%timeit results on a freebie Google Colab VM: 
# 283 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
</span><span class="n">lax_scanned</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">associative_scan</span><span class="p">(</span><span class="n">vectorized_operator</span><span class="p">,</span> <span class="n">cs</span><span class="p">)[:,</span> <span class="o">-</span><span class="n">N</span><span class="p">:]</span>

<span class="k">def</span> <span class="nf">naive_scan_hs</span><span class="p">(</span><span class="n">h_0</span><span class="p">,</span> <span class="n">As</span><span class="p">,</span> <span class="n">Bxs</span><span class="p">,</span> <span class="n">xs</span><span class="p">):</span>
    <span class="n">output</span> <span class="o">=</span> <span class="p">[</span><span class="n">h_0</span><span class="p">]</span>
    <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">bx</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bxs</span><span class="p">,</span> <span class="n">xs</span><span class="p">):</span>
        <span class="n">b</span> <span class="o">=</span> <span class="n">einsum</span><span class="p">(</span><span class="n">bx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="s">"n d, d -&gt; n"</span><span class="p">)</span>
        <span class="n">output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">a</span> <span class="o">@</span> <span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">output</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>

<span class="c1"># %%timeit results on a freebie Google Colab VM:
# 3.34 s ± 313 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
</span><span class="n">naive_hs</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">vstack</span><span class="p">(</span>
    <span class="n">naive_scan_hs</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,)),</span> <span class="n">As</span><span class="p">,</span> <span class="n">Bxs</span><span class="p">,</span> <span class="n">xs</span><span class="p">)</span>
<span class="p">)</span>

<span class="c1"># The following returns Array(True, dtype=bool)! Which means that
# we're getting identical results, (allowing for some floating point
# imprecision), regardless of if we're using the naive iterative
# implementation, or the fast parallel implementation.
</span><span class="n">jnp</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">naive_hs</span><span class="p">,</span> <span class="n">lax_scanned</span><span class="p">)</span>
</code></pre></div></div>

<p>Hooray, we’re getting identical results as well as a much faster wall time. These trends only get more extreme as you leverage more capable hardware (the above results were on a freebie Google Colaboratory VM, i.e. not hardware accelerated and using only two CPU cores).</p>

<h2 id="closing-thoughts">Closing thoughts</h2>

<p>Whew, that was a lot. Let’s do a quick recap of what we’ve covered:</p>

<h3 id="summary-of-above-topics">Summary of above topics</h3>

<ul>
  <li>Mamba is a state space model and is, at its core, recurrent/sequential.</li>
  <li>Because Mamba relies on a fixed-size state representation, it is efficient at inference time, unlike transformers.</li>
  <li>Mamba, like FlashAttention, maximizes efficiency in a real-world wall-clock sense by doing as much work as possible in SRAM.</li>
  <li>Mamba can be computed in parallel via the Blelloch parallel scan algorithm, making it efficient at training time, unlike RNNs.</li>
  <li>Mamba’s parameters are selective, making it far better at modeling long-term dependencies than vanilla RNNs.</li>
</ul>

<p>In particular, we paid extra attention to</p>

<ol>
  <li>Mamba’s specific SSM formulation.</li>
  <li>How linear RNNs, Mamba included, can be computed efficiently.</li>
</ol>

<h3 id="remaining-topics">Remaining topics</h3>

<p>There is a bunch of other important material that I’ve glossed over here. Conceptually, I believe the selective SSM formulation and the scan implementation are the most important contributions, but a thorough discussion would also cover</p>

<ul>
  <li>The hardware-aware computation.<sup id="fnref:hardware-aware" role="doc-noteref"><a href="#fn:hardware-aware" class="footnote" rel="footnote">15</a></sup></li>
  <li>Mamba’s fused “one block” architecture.</li>
  <li>Other fairly standard optimization tricks, e.g. recomputation, kernel fusion, etc.</li>
  <li>Theoretical ties between heuristic gating and selectivity.</li>
</ul>

<p>And of course, I provided a bit of code to sanity check the discussion around parallel computation of the state vectors, but this is not a full implementation 🙂. Perhaps I’ll get around to writing a more thorough treatise in the future, but for now, I hope you’ve found this interesting!</p>

<h3 id="musings-on-breadth">Musings on breadth</h3>

<p>Mamba is one of the most impressive papers I’ve read in years. It’s particularly impressive because the core concepts (selectivity and efficient implementation) are simple (where simple ≠ easy) yet effective.</p>

<p>In fact, throughout the paper there are few individual insights that demand galaxy-brain intellect. For example:</p>

<ul>
  <li>Non-LTI SSMs are really just the more general version of SSMs you might learn about in a control theory class.</li>
  <li>There have been other examples of fused architectures.</li>
  <li>The Blelloch scan algorithm is old, originally published in 1990, and is something you’re very likely to encounter if you’ve taken some parallel programming courses. There’s even a <a href="https://www.youtube.com/watch?v=mmYv3Haj6uc">Udacity course</a> that touches on this!</li>
  <li>The “selectivity” functions are very simple, comprising just matmuls, a softplus, and a broadcast.</li>
  <li>I am inexperienced at writing kernel code, but I assume that, like any other engineering skill, it is very doable with some practice.</li>
  <li>There is a bit of insight that comes with knowing how hardware accelerators are built. Becoming an expert at this is, like all things, very difficult. But you can get some pretty useful knowledge by doing some high bang-for-buck reading, e.g. <a href="https://horace.io/brrr_intro.html">Making Deep Learning Go Brrrr From First Principles</a>.</li>
  <li>Legendre polynomials, à la HiPPO initialization, could very feasibly appear in one’s undergraduate coursework if they enjoy math (e.g., <a href="https://math.illinois.edu/resources/syllabus-math-442">Math 442</a> at my alma mater).</li>
</ul>

<p>But on the other hand, combining a useful knowledge of all of the above and distilling that into a cohesive and novel body of research? Super cool. I suppose that’s why both authors are younger than me but also tenure-track professors at top institutions.</p>

<h2 id="assorted-references">Assorted references</h2>

<ul>
  <li><a href="https://stacks.stanford.edu/file/druid:mb976vf9362/gu_dissertation-augmented.pdf">https://stacks.stanford.edu/file/druid:mb976vf9362/gu_dissertation-augmented.pdf</a></li>
  <li><a href="https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf">https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf</a></li>
  <li><a href="https://srush.github.io/annotated-s4/">https://srush.github.io/annotated-s4/</a></li>
  <li><a href="https://arxiv.org/abs/2312.00752">https://arxiv.org/abs/2312.00752</a></li>
  <li><a href="https://arxiv.org/abs/2111.00396">https://arxiv.org/abs/2111.00396</a></li>
</ul>

<h2 id="footnotes">Footnotes</h2>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:static" role="doc-endnote">
      <p>I.e. post-training, the parameters involved in the updates are always the same. <a href="#fnref:static" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:expensive" role="doc-endnote">
      <p>There are some tricks like KV caching which help here, but they unfortunately do not change asymptotic behavior. <a href="#fnref:expensive" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:effective" role="doc-endnote">
      <p>Not to say that scaling is easy either; HPC is a fiendishly difficult thing to get right. It’s just a very simple concept to understand: <strong><em>use moar computer</em></strong> <a href="#fnref:effective" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:recurrence" role="doc-endnote">
      <p>At the risk of offending proper mathematicians, you can kinda think of a first-order differential equation as an infinitesimally precise recurrence. <a href="#fnref:recurrence" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:imho" role="doc-endnote">
      <p>You should take any of my opinions with a large grain of salt. I, for one, have not been awarded any ICLR Outstanding Paper Awards 🙂 <a href="#fnref:imho" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:simple" role="doc-endnote">
      <p><a href="https://i.imgur.com/Jjb1WsZ.png">“Simple”</a> 🫠 <a href="#fnref:simple" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:hippo" role="doc-endnote">
      <p>IMHO, this is where most people get stuck when trying to understand S4. I mean, how often does the typical computer scientist encounter Legendre polynomials? <a href="#fnref:hippo" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:size" role="doc-endnote">
      <p>It is technically possible but quite silly to use a value smaller than \(\mathtt{D}\). Remember that the state itself is only \(\mathtt{N}\) dimensional, i.e. \(\mathbf{h}_t \in \mathbb{R}^\mathtt{N}\). It would be very challenging for a model to compress multiple \(\mathbb{R}^{\mathtt{D}}\) inputs into a single \(\mathbb{R}^{\mathtt{N}}\) state if \(\mathtt{N} &lt; \mathtt{D}\). <a href="#fnref:size" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:work-efficient" role="doc-endnote">
      <p>There are other parallel scan implementations, e.g. the <a href="https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel">Hillis-Steele scan</a>. In a nutshell, the Blelloch scan is more “work-efficient” yet less “step-efficient”. At a high level, when the amount of work to do exceeds the amount of available parallelism, the Blelloch implementation is more efficient. This matches the typical situation in most ML problems, where even a lot of parallel compute is dwarfed by the scale of data/input demands. <a href="#fnref:work-efficient" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:kernel" role="doc-endnote">
      <p>This is where the GPU <em>kernel</em> comes in, as in “kernel fusion”, a scary name for a simple thing. Also, is there a word in STEM more overloaded than kernel? Kernel trick, GPU kernel, Linux kernel, convolution kernel, etc. Enough to warrant a fairly long <a href="https://en.wikipedia.org/wiki/Kernel">Kernel (disambiguation)</a> Wikipedia article. Maybe <em>entropy</em> or <em>set</em> are more overloaded… <a href="#fnref:kernel" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:kvcache" role="doc-endnote">
      <p>I mean, to an extent; the KV-cache is still a thing. <a href="#fnref:kvcache" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:recent" role="doc-endnote">
      <p>Guy Blelloch published his work in 1988 and 1990, building upon work by Hillis and Steele from 1986. In doing so, he explicitly highlighted the capability of the parallel scan to efficiently compute certain first-order (and higher-order) recurrences.<br /><br />To my knowledge, Eric Martin and Chris Cundy were the first to publish work connecting the Blelloch parallel scan algorithm to linear RNNs in early 2018. The publication was successful by usual metrics, appearing in ICLR and garnering 43 citations as of early 2024, but it was perhaps overshadowed at the time by the waning popularity of recurrent architectures following the then-recent introduction of the transformer architecture. <a href="#fnref:recent" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:form" role="doc-endnote">
      <p>Note that this is slightly different from what you will find in the OG Blelloch paper, since the first-order recurrence referred to in that paper swaps the position of \(h_t\) and \(a_t\). All following proofs have been modified accordingly. Rest assured, they are the same in spirit, with only some very minor algebra changes. <a href="#fnref:form" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:implementation" role="doc-endnote">
      <p>I’m a bit curious about whether the authors considered using JAX, its <code class="language-plaintext highlighter-rouge">lax</code> module and corresponding <code class="language-plaintext highlighter-rouge">associative_scan</code> implementation with <a href="https://jax.readthedocs.io/en/latest/pallas/index.html">Pallas</a> whenever necessary. I suppose the implementation of the Blelloch scan is not actually <a href="https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda#:~:text=39.2%20Implementation">that many lines of code</a>, so perhaps it just comes down to preference. <a href="#fnref:implementation" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:hardware-aware" role="doc-endnote">
      <p>I.e., do as much as possible in SRAM, since moving data between SRAM and HBM is slow. <a href="#fnref:hardware-aware" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>James Chen</name></author><category term="jekyll" /><category term="update" /><summary type="html"><![CDATA[In this post, I attempt to provide a walkthrough of the essence of the Mamba state space model architecture, occasionally sacrificing some rigor for intuition and overall pedagogical friendliness.]]></summary></entry><entry><title type="html">Nano Perceiver</title><link href="https://jmschndev.github.io/jekyll/update/2023/08/13/nano-perceiver.html" rel="alternate" type="text/html" title="Nano Perceiver" /><published>2023-08-13T19:50:08+00:00</published><updated>2023-08-13T19:50:08+00:00</updated><id>https://jmschndev.github.io/jekyll/update/2023/08/13/nano-perceiver</id><content type="html" xml:base="https://jmschndev.github.io/jekyll/update/2023/08/13/nano-perceiver.html"><![CDATA[<h2 id="tldr">Tl;dr</h2>

<p>The Perceiver family of models from DeepMind decouples context length from memory and compute requirements. Perceiver AR extends this with support for autoregressive generation. It also has a refreshingly simple implementation, since at its core it is just a small variation on top of an otherwise standard decoder-only transformer.</p>

<p>I’ve provided a lightweight implementation <a href="https://github.com/jmschndev/nano-perceiver/tree/main">here</a> and provide additional context in this post.</p>

<h2 id="background">Background</h2>

<h3 id="notation">Notation</h3>

<p>Let’s set some consistent notation:</p>

<ul>
  <li>For inputs, consider
    <ul>
      <li>Index (i.e., first) dimensionality as \(M\), also known as context size</li>
      <li>Channel (i.e., second) dimensionality as \(C\)</li>
    </ul>
  </li>
  <li>For a transformer model, consider:
    <ul>
      <li>A model with \(L\) transformer “blocks”, each with
        <ul>
          <li>an attention layer</li>
          <li>a relatively shallow MLP</li>
        </ul>
      </li>
    </ul>
  </li>
</ul>

<p>For example, these inputs could be</p>

<ul>
  <li>\(M\) token embeddings, each with an embedding size of \(C\)</li>
  <li>\(M\) raw pixels from a color image, with \(C = 3\) for RGB.</li>
</ul>

<h3 id="transformers-and-scale">Transformers and scale</h3>

<p>Transformers are increasingly powerful and expressive with increasing training data and parameter count, as demonstrated by large transformers like GPT-4 and PaLM. However, transformers scale poorly w.r.t. both compute and memory as context size increases.</p>

<h4 id="self-attention-is-quadratic">Self-attention is quadratic</h4>

<p>The most important operation in a transformer model is the <em>attention</em> operation, hence the seminal paper’s title being <a href="https://arxiv.org/abs/1706.03762">Attention Is All You Need</a>.</p>

<p>If you’re unfamiliar with attention, then there are an inordinate number of intuitive explanations out there, but here’s my abbreviated take that highlights the scaling: at its core, self-attention tries to determine, for each input token, how it relates to every other input token. This is basically a doubly nested for-loop:</p>

<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">self_attention</span><span class="p">(</span><span class="n">some_inputs</span><span class="p">):</span>
    <span class="n">scores</span> <span class="o">=</span> <span class="p">[</span>
        <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">some_inputs</span><span class="p">))]</span>
        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">some_inputs</span><span class="p">))</span>
    <span class="p">]</span>
    <span class="k">for</span> <span class="n">r</span><span class="p">,</span> <span class="n">query_tok</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">some_inputs</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">c</span><span class="p">,</span> <span class="n">key_tok</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">some_inputs</span><span class="p">):</span>
            <span class="n">score_qk</span> <span class="o">=</span> <span class="n">relation</span><span class="p">(</span><span class="n">get_query_vector</span><span class="p">(</span><span class="n">query_tok</span><span class="p">),</span> <span class="n">get_key_vector</span><span class="p">(</span><span class="n">key_tok</span><span class="p">))</span>
            <span class="n">scores</span><span class="p">[</span><span class="n">r</span><span class="p">][</span><span class="n">c</span><span class="p">]</span> <span class="o">=</span> <span class="n">score_qk</span>
    <span class="p">...</span>  <span class="c1"># Normalize, combine with value vectors, etc
</span></code></pre></div></div>

<p>This is obviously quadratic.</p>

<p>If you’re familiar with basic matrix math, you’ll observe that you can rewrite this more compactly with some matrix operations. I.e. given \(Q, K, V \in \mathbb{R}^{M \times C}\), the attention operation is:</p>

\[\text{softmax} \left( \frac { QK^T }{\sqrt C} \right) V\]

<p>Here you can see that attention scales quadratically via observing that \(QK^T \in \mathbb{R}^{M \times M}\).</p>
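<p>For concreteness, the loop above collapses into a few vectorized lines. This is a toy sketch of my own (made-up sizes, single head, no masking), written with NumPy rather than PyTorch purely to keep it self-contained:</p>

```python
import numpy as np

M, C = 8, 4  # toy context and channel sizes
rng = np.random.default_rng(0)
Q = rng.standard_normal((M, C))
K = rng.standard_normal((M, C))
V = rng.standard_normal((M, C))

scores = Q @ K.T / np.sqrt(C)  # shape (M, M): the quadratic-in-M part
# Row-wise softmax (subtracting the row max for numerical stability)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
out = weights @ V              # shape (M, C)
```

<p>The \((M, M)\) <code class="language-plaintext highlighter-rouge">scores</code> matrix is where both the quadratic compute and the quadratic memory come from.</p>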

<h4 id="but-lets-not-forget-about-linear-scaling-either">But let’s not forget about “linear” scaling either</h4>

<p>A standard transformer has \(L\) blocks, which means both the self-attention and the subsequent MLP are run \(L\) times, bringing overall complexity to \(O(M^2L)\) and \(O(ML)\) for the attention and MLP operations respectively.</p>

<p>Suppose by some miracle we are able to reduce self-attention’s complexity down to \(O(M)\), perhaps via some <a href="https://arxiv.org/abs/2004.05150">clever approximations</a>. Throwing the math under the rug of Big-O notation, this would imply that a transformer model scales with \(O(ML)\). This is of course better than \(O(M^2L)\), but for large M (e.g. long context models), this is still rather poor scaling.</p>

<p>This incentivizes model designers to come up with clever and bespoke ways of reducing context length (e.g. the <a href="https://arxiv.org/abs/2010.11929">ViT</a>, which famously converts an image into patches first). These are effective, but it’s not always clear how to handle other domains, and it’s tremendously easy to get input sizes well into the hundreds of thousands, if not larger. For example, 10 seconds of time-domain audio (\(10s \times 44.1 \text{kHz} = 441 \text{ thousand}\)) or 10 seconds of low resolution (e.g. 240p) grayscale video without any audio (\(\text{240 px high} \times \text{362 px wide} \times 24 \text{Hz} \times 10s \approx 20 \text{ million}\)).</p>

<p>Therefore, there’s a desire to not only handle the quadratic nature of self-attention, but also to further decouple context length from computation (past processing the inputs, of course).</p>

<h2 id="perceiver-architecture">Perceiver architecture</h2>

<p>The crux of the Perceiver architecture is to use <strong>cross</strong>-attention for the first block. The only difference between cross-attention and self-attention is that in cross-attention, the query vectors attend to key and value vectors that may come from a different underlying input<sup id="fnref:shape" role="doc-noteref"><a href="#fn:shape" class="footnote" rel="footnote">1</a></sup>:</p>

\[Q \in \mathbb{R} ^ {N \times C}; \; K, V \in \mathbb{R} ^{M \times C}\]

<p>Consider \(N\) to be a hyperparameter, selected such that \(N \ll M\). Now \(QK^T \in \mathbb{R}^{N \times M}\); as in, the first attention layer is now an \(O(NM)\) operation! And, since the output latents are of length \(N\), each subsequent self-attention is \(O(N^2)\), and each MLP layer is also only \(O(N)\) now.</p>
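<p>Here is a shape-only NumPy sketch of that first cross-attention block, with toy sizes of my own choosing; the \(N\) queries are random stand-ins for however the model actually produces them:</p>

```python
import numpy as np

M, N, C = 1024, 64, 32  # hypothetical context, latent, and channel sizes; N << M
rng = np.random.default_rng(0)
Q = rng.standard_normal((N, C))  # stand-in queries
K = rng.standard_normal((M, C))  # keys come from the full input
V = rng.standard_normal((M, C))  # values come from the full input

scores = Q @ K.T / np.sqrt(C)    # shape (N, M): O(N*M), not O(M^2)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)  # row-wise softmax
latents = weights @ V            # shape (N, C)
```

<p>Every later block then runs self-attention over the \((N, C)\) latents, so only this first layer ever touches all \(M\) inputs.</p>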

<h3 id="getting-the-queries">Getting the queries</h3>

<p>This is all well and good, but it depends on our ability to actually get reasonable queries.</p>

<p>For the initial query, both the original Perceiver and Perceiver IO used a fixed latent, analogous to the initial hidden state in an RNN. Perceiver AR uses an arguably simpler approach, where the initial query just comes from the last \(N\) inputs. This way, you can still apply causal masking for autoregressive generation!</p>

<p><img src="/assets/images/perceiver_ar.png" alt="Perceiver AR diagram" /></p>

<p>Since \(N\) is now a hyperparameter choice, if one uses \(N=M\), then this is actually equivalent to a standard decoder-only transformer, which I find particularly elegant.</p>

<h3 id="just-show-me-the-code">Just show me the code</h3>

<p>Conceptually this makes sense, so what do we actually need to modify? As you can see from the above diagram, there’s not too much to change. You need to</p>

<ol>
  <li>Ensure your targets are the same size as your queries.</li>
  <li>Explicitly choose queries as the tail of your full context input.</li>
  <li>Ensure your “triangular” causal attention masking is shifted accordingly.</li>
</ol>

<h4 id="targets">Targets</h4>

<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">data</span><span class="p">[</span><span class="n">i</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">block_size</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ix</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span>
    <span class="p">[</span><span class="n">data</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">block_size</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">query_size</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">block_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ix</span><span class="p">]</span>
<span class="p">)</span>
</code></pre></div></div>

<p>For a standard transformer, <code class="language-plaintext highlighter-rouge">query_size</code> is identical to <code class="language-plaintext highlighter-rouge">block_size</code>.</p>
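<p>A toy trace of that slicing, with hypothetical numbers standing in for real data (integers playing the role of token ids, so token \(i+1\) always follows token \(i\)):</p>

```python
data = list(range(10))          # toy "token" stream
block_size, query_size = 5, 2   # hypothetical sizes
i = 0                           # one sampled start index

x = data[i : i + block_size]                                    # [0, 1, 2, 3, 4]
y = data[i + block_size + 1 - query_size : i + block_size + 1]  # [4, 5]
# y holds next-token targets for only the last `query_size` positions
# of x: token 3 -> 4, token 4 -> 5.
```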

<h4 id="attention-module">Attention module</h4>

<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs_q</span><span class="p">,</span> <span class="n">inputs_kv</span><span class="p">):</span>
    <span class="p">...</span>
    <span class="c1"># Causal masking.
</span>    <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span>
        <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">q_time</span><span class="p">,</span> <span class="n">kv_time</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">attention</span><span class="p">.</span><span class="n">device</span><span class="p">),</span>
        <span class="n">diagonal</span><span class="o">=</span><span class="n">kv_time</span> <span class="o">-</span> <span class="n">q_time</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">attention</span> <span class="o">=</span> <span class="n">attention</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="o">~</span><span class="n">mask</span><span class="p">.</span><span class="nb">bool</span><span class="p">(),</span> <span class="nb">float</span><span class="p">(</span><span class="s">"-inf"</span><span class="p">))</span>
</code></pre></div></div>

<p>For a standard transformer, the <code class="language-plaintext highlighter-rouge">diagonal</code> argument is just <code class="language-plaintext highlighter-rouge">0</code>, the default value.</p>
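<p>To see what that <code class="language-plaintext highlighter-rouge">diagonal</code> offset does, here is the mask for a toy case, with <code class="language-plaintext highlighter-rouge">np.tri</code> standing in for <code class="language-plaintext highlighter-rouge">torch.tril(torch.ones(...), diagonal=k)</code> just to keep the sketch dependency-free:</p>

```python
import numpy as np

q_time, kv_time = 3, 5  # toy sizes: 3 queries over 5 keys/values
# Entry (i, j) is 1 iff j <= i + (kv_time - q_time): query i, which sits
# at absolute position (kv_time - q_time) + i in the full context, may
# attend to keys up to and including its own position.
mask = np.tri(q_time, kv_time, k=kv_time - q_time)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]
```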

<h4 id="perceiver-block">Perceiver block</h4>

<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
    <span class="n">inputs_q</span><span class="p">,</span> <span class="n">inputs_kv</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="o">-</span><span class="bp">self</span><span class="p">.</span><span class="n">query_size</span> <span class="p">:,</span> <span class="p">:],</span> <span class="n">x</span>
    <span class="n">normed_q</span><span class="p">,</span> <span class="n">normed_kv</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">ln1</span><span class="p">(</span><span class="n">inputs_q</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">ln1</span><span class="p">(</span><span class="n">inputs_kv</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">inputs_q</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">attn</span><span class="p">(</span><span class="n">inputs_q</span><span class="o">=</span><span class="n">normed_q</span><span class="p">,</span> <span class="n">inputs_kv</span><span class="o">=</span><span class="n">normed_kv</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">mlp</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">ln2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<p>For a standard transformer, the main <code class="language-plaintext highlighter-rouge">Block</code>’s <code class="language-plaintext highlighter-rouge">forward</code> function only passes the input to attention, since it’s doing <strong>self</strong>-attention, but since we’re now doing <strong>cross</strong>-attention on inputs, we need to handle <code class="language-plaintext highlighter-rouge">inputs_q</code> and <code class="language-plaintext highlighter-rouge">inputs_kv</code> separately.</p>

<p>The attention operations in the middle of the transformer are self-attentions over the \(\mathbb{R} ^ {N \times C}\) latent space, so for these layers <code class="language-plaintext highlighter-rouge">inputs_q == inputs_kv</code>, since taking the last <code class="language-plaintext highlighter-rouge">self.query_size</code> elements of the latent just returns the full latent anyway.</p>

<h4 id="is-that-all">Is that all?</h4>

<p>And… that’s about it. Feel free to see the full repo <a href="https://github.com/jmschndev/nano-perceiver/tree/main">here</a>.<sup id="fnref:assumptions" role="doc-noteref"><a href="#fn:assumptions" class="footnote" rel="footnote">2</a></sup> The repo takes inspiration (and code) from Karpathy’s <a href="https://github.com/karpathy/ng-video-lecture">NanoGPT repo</a>, and reuses the core logic around a simple training script over a single plain text file Shakespeare “corpus”. I encourage you to mess around with some of the parameters in the <code class="language-plaintext highlighter-rouge">train.py</code> script. I was able to train surprisingly long-context models, with the caveat that I was running locally on a 2019 MacBook, but of course you’re free to use a proper accelerator of your choice.</p>

<p>If you want to see a more fully fleshed out implementation, then you’ll be pleased to know that DeepMind has actually open sourced a repo <a href="https://github.com/google-research/perceiver-ar">here</a>. It’s a research codebase so it definitely was not designed for pedagogical friendliness, but it is supremely flexible<sup id="fnref:jax" role="doc-noteref"><a href="#fn:jax" class="footnote" rel="footnote">3</a></sup>.</p>

<h2 id="notes">Notes</h2>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:shape" role="doc-endnote">
      <p>The original Perceiver paper goes further and uses \(Q \in \mathbb{R} ^ {N \times D}\), which requires further projections. Fun fact, the <a href="https://github.com/deepmind/deepmind-research/blob/f5de0ede8430809180254ee957abf36ed62579ef/perceiver/perceiver.py#L81">original implementation</a> refers to this projection as <code class="language-plaintext highlighter-rouge">conv_1d</code> even though it’s just a <code class="language-plaintext highlighter-rouge">Linear</code>. I’m not exactly sure why they’d call it this; perhaps due to other patterns in computer vision where a 1x1 convolution has been used to reduce channel size? <a href="#fnref:shape" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:assumptions" role="doc-endnote">
      <p>This repo doesn’t experiment with any more sophisticated optimizers, nonlinearities, or further-optimized attention chunking/computation. This repo also assumes you are not meddling with channel dimensions, e.g. not projecting into larger or smaller channel sizes. <a href="#fnref:assumptions" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:jax" role="doc-endnote">
      <p>It’s also implemented in JAX, which is a pro or con depending on whether or not you’re a Googler (jk, kind of 🙂). <a href="#fnref:jax" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>James Chen</name></author><category term="jekyll" /><category term="update" /><summary type="html"><![CDATA[Tl;dr]]></summary></entry><entry><title type="html">Filling in the middle for great good</title><link href="https://jmschndev.github.io/jekyll/update/2023/07/05/fim.html" rel="alternate" type="text/html" title="Filling in the middle for great good" /><published>2023-07-05T19:50:08+00:00</published><updated>2023-07-05T19:50:08+00:00</updated><id>https://jmschndev.github.io/jekyll/update/2023/07/05/fim</id><content type="html" xml:base="https://jmschndev.github.io/jekyll/update/2023/07/05/fim.html"><![CDATA[<h2 id="tldr">Tl;dr</h2>

<p>My favorite research findings are ones that make me go “Why didn’t I think of that?” due to a paradoxical combination of simplicity and cleverness. The OpenAI paper “<a href="https://arxiv.org/abs/2207.14255">Efficient Training of Language Models to Fill in the Middle</a>” (FIM) is the most recent thing I’ve read that makes me feel this way.</p>

<h2 id="language-modeling-101">Language modeling 101</h2>

<p>For the uninitiated, modern language model (pre)training involves a very small number of steps:</p>

<ol>
  <li>Download the internet.</li>
  <li>Choose a model architecture.</li>
  <li>Given an imperfect view of some tokens, predict a perfect view.</li>
</ol>

<p>I’m being a bit facetious here because each step is complex, and also because this doesn’t discuss things like fine-tuning, inference, etc… But for <em>pretraining</em>, this is a reasonable approximation.</p>

<p>Some research focuses on step 1, e.g. how to construct the most useful training dataset, given the expected tradeoffs between quality and quantity. An enormous amount of research focuses on step 2, e.g. how to architect your model to improve its performance in any number of ways (memory efficiency, modeling distant relationships, etc). And some research focuses on step 3, e.g. how to set up your task. BART has a denoising approach (input text is corrupted), decoder-only transformers tend to do next token prediction (i.e., all tokens are visible except the last one), etc.</p>

<p><a href="https://arxiv.org/abs/2207.14255">Efficient Training of Language Models to Fill in the Middle</a> tackles step 3. It maintains the standard “next token” autoregressive task of most decoder-only transformers, but with the simple twist that some fraction of those tokens have their order modified.</p>

<h2 id="fim-in-a-nutshell">FIM in a nutshell</h2>

<p>For simplicity, we’ll assume that we’re doing whitespace/punctuation tokenization.</p>
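<p>For grounding, here’s a toy version of such a tokenizer. This is purely illustrative (real models use subword tokenizers like BPE), and <code class="language-plaintext highlighter-rouge">simple_tokenize</code> is a name of my own invention:</p>

```python
import re

def simple_tokenize(text):
    # Split into word tokens and single punctuation characters.
    return re.findall(r"\w+|[^\w\s]", text)

simple_tokenize("What I cannot create, I do not understand")
# → ['What', 'I', 'cannot', 'create', ',', 'I', 'do', 'not', 'understand']
```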

<p>Now consider the sentence “What I cannot create, I do not understand”. Under the typical autoregressive training setup, your input and target would look something like</p>

<pre><code class="language-txt">Input: ["&lt;bos&gt;", "What", "I", "cannot", "create", ",", "I", "do", "not", "understand"]
Target: ["What", "I", "cannot", "create", ",", "I", "do", "not", "understand", "&lt;eos&gt;"]
</code></pre>

<p>Where <code class="language-plaintext highlighter-rouge">&lt;bos&gt;</code> and <code class="language-plaintext highlighter-rouge">&lt;eos&gt;</code> are special tokens representing the beginning and end of an input sequence. For FIM, the model maintains the same mechanics around training and optimization, but the data is transformed a bit: it’s chunked into a prefix, middle, and suffix, represented with corresponding special tokens:</p>

<pre><code class="language-txt">Input: ["&lt;pre&gt;", "What", "I", "&lt;suf&gt;", "do", "not", "understand", "&lt;mid&gt;", "cannot", "create", ",", "I"]
Target: ["What", "I", "&lt;suf&gt;", "do", "not", "understand", "&lt;mid&gt;", "cannot", "create", ",", "I", "&lt;eot&gt;"]
</code></pre>

<p>This way, during inference time, you can prompt a model to fill in the middle very naturally by just providing context up to <code class="language-plaintext highlighter-rouge">&lt;mid&gt;</code>.  And… that’s all folks.</p>
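<p>To make the transformation concrete, here’s a minimal sketch of that rearrangement in Python. The sentinel strings and the <code class="language-plaintext highlighter-rouge">fim_transform</code> name are stand-ins for the dedicated special tokens a real tokenizer would use, and the paper has more to say about how the cut points should be chosen:</p>

```python
import random

def fim_transform(tokens, rng=None):
    """Rearrange a token list into prefix/suffix/middle order.

    The "<pre>"/"<suf>"/"<mid>"/"<eot>" strings stand in for real
    sentinel tokens; this is an illustrative sketch, not the paper's code.
    """
    rng = rng or random.Random(0)
    # Pick two distinct cut points at random to define the middle span.
    i, j = sorted(rng.sample(range(len(tokens) + 1), 2))
    prefix, middle, suffix = tokens[:i], tokens[i:j], tokens[j:]
    return ["<pre>", *prefix, "<suf>", *suffix, "<mid>", *middle, "<eot>"]
```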

<p>Well, not exactly. There are other details and context which, of course, you can find in the paper. But the crux of the paper really is that simple!</p>

<h2 id="putting-my-pm-hat-on">Putting my PM hat on</h2>

<p>Chatbots are a natural extension of autoregressive models, and relatively convenient for the engineers implementing them. But not every product or task is most naturally represented as an extension of “predict the most probable next word”. Compared to an off-the-shelf autoregressively pretrained LM, filling in the middle seems like a more intuitive way to generate text for certain types of products, e.g.:<sup id="fnref:difficulty" role="doc-noteref"><a href="#fn:difficulty" class="footnote" rel="footnote">1</a></sup></p>

<ul>
  <li>A docstring, given a function definition and a function body.</li>
  <li>A blog article’s content, given its title and a conclusion paragraph.</li>
  <li>A road trip itinerary, given that it starts in California and ends in New York City.</li>
</ul>

<p>Admittedly, there is a grab bag of overlapping techniques to address the problem of alignment. That is, there is ongoing research into turning something that predicts the most <em>probable</em> next token into something that is <em>useful</em>. These techniques may include standard supervised fine-tuning, instruction tuning, RLHF<sup id="fnref:rlhf" role="doc-noteref"><a href="#fn:rlhf" class="footnote" rel="footnote">2</a></sup>, or simply prompt engineering.<sup id="fnref:prompt" role="doc-noteref"><a href="#fn:prompt" class="footnote" rel="footnote">3</a></sup></p>

<p>But there is no reason these techniques wouldn’t work with a FIM pretrained model. To top it off, FIM also seems to grant these new middle-filling capabilities without harming overall model performance, making it just about the closest thing I’ve seen to a <a href="https://en.wikipedia.org/wiki/There_ain%27t_no_such_thing_as_a_free_lunch">free lunch</a> in a little while.</p>

<h2 id="notes">Notes</h2>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:difficulty" role="doc-endnote">
      <p>To be fair, it seems that filling in the middle may be a fundamentally more challenging problem than left-to-right (i.e. typical) sampling, perhaps because generated text needs to flow naturally with both a prefix and suffix, as opposed to just a prefix. <a href="#fnref:difficulty" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:rlhf" role="doc-endnote">
      <p>RLHF comes with its own grab bag of challenges, such as detecting/preventing over-optimization of the reward model, the high VRAM requirements of actor-critic policy-gradient methods, and the overarching sample inefficiency of basically all RL techniques. <a href="#fnref:rlhf" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:prompt" role="doc-endnote">
      <p>“Prompt engineering” always felt like kind of a bizarre phrase to me. “Engineering” ideally implies some sort of systematic understanding and predictability, but prompt engineering seems more like an art at best and voodoo magic at worst. <a href="#fnref:prompt" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>James Chen</name></author><category term="jekyll" /><category term="update" /><summary type="html"><![CDATA[Tl;dr]]></summary></entry><entry><title type="html">Training a chatbot to talk like me</title><link href="https://jmschndev.github.io/jekyll/update/2023/06/26/lowering-the-floor.html" rel="alternate" type="text/html" title="Training a chatbot to talk like me" /><published>2023-06-26T19:50:08+00:00</published><updated>2023-06-26T19:50:08+00:00</updated><id>https://jmschndev.github.io/jekyll/update/2023/06/26/lowering-the-floor</id><content type="html" xml:base="https://jmschndev.github.io/jekyll/update/2023/06/26/lowering-the-floor.html"><![CDATA[<p>There has been much well-deserved attention paid towards the latest advances in machine learning these days. I feel like I see a new paper or model every week that promises the Earth, moon, and stars.</p>

<p>Perhaps it’s a new approach that will finally™ solve the problem of quadratic scaling of transformers w.r.t. context length, be it via <a href="https://arxiv.org/pdf/2305.07185.pdf">clever tweaks inspired by convolutions</a>, <a href="https://arxiv.org/pdf/2302.10866.pdf">literally using convolutions</a>, <a href="https://arxiv.org/pdf/2205.14135.pdf">more clever utilization of accelerators</a>, or various <a href="https://arxiv.org/pdf/2305.19370.pdf">memory</a> <a href="https://arxiv.org/pdf/2103.03206.pdf">bottlenecks</a>.<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> Perhaps it’s any number of new models that have been fine-tuned by hobbyists, perhaps using leaked LLaMA weights or ChatGPT/ShareGPT data.<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup></p>

<p>But there is another thing that hasn’t gotten as much mainstream attention: just how <em>easy</em> it has become to experiment with some seriously advanced models, models that would quite recently have been state of the art and required non-trivial capital to train. Of course, researchers have been publishing models and code for a while now, but the current state of affairs with easy-to-use APIs, reasonably good documentation, and emphasis placed on community interaction and contribution? That feels rather new.<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup></p>

<p>As an example, I wanted to walk through a small language model that I trained for my own amusement.</p>

<h2 id="goal">Goal</h2>

<p>I wanted to train a model that could talk a bit more informally, perhaps a bit more like how I text with friends. It’s no secret that LLMs tend to output text that errs on the formal/verbose side. My suspicion is that this is due to some combination of</p>

<ul>
  <li>Instructions to human annotators when generating text data, e.g. for supervised fine tuning datasets</li>
  <li>Instructions to human annotators when ranking text data, i.e. generating reward-model data for RLHF by ranking which outputs are better aligned with user preferences.</li>
  <li>Preference for “high quality” text sources during training, e.g. Wikipedia, news articles, or well upvoted comments on Reddit.<sup id="fnref:reddit" role="doc-noteref"><a href="#fn:reddit" class="footnote" rel="footnote">4</a></sup></li>
</ul>

<p>My hope (<a href="https://openai.com/research/improving-language-model-behavior">inspired by this paper</a>) was that it would take a relatively minimal amount of fine tuning to get a language model to chat more like me. That is, talk a bit more informally, use different punctuation (e.g. newlines instead of periods), etc.</p>

<h2 id="training-data">Training data</h2>

<p>Getting training data for this was fairly simple. I’ve been using Facebook Messenger for a long time now, and Facebook provides a convenient way to <a href="https://www.facebook.com/dyi">download all of your data</a> as a bunch of JSONs. Then it’s just a bit of Python to parse the messages, yielding the ones where I’m responding. That generator then can be used directly to create a Dataset, via Hugging Face’s dataset API; specifically <a href="https://huggingface.co/docs/datasets/v2.13.1/en/package_reference/main_classes#datasets.Dataset.from_generator"><code class="language-plaintext highlighter-rouge">from_generator</code></a>.</p>

<p>To be precise, the training data was in the format of “prompts” and “responses”, where prompts were contiguous blocks of messages from anybody that wasn’t me, joined with a pipe char (<code class="language-plaintext highlighter-rouge">|</code>). Responses were of the same format, just comprising messages that I sent. I used a pipe char to avoid any preprocessing shenanigans regarding newlines that some tokenizers may perform (e.g., transforming newlines into spaces). I used a character that wouldn’t normally come up in texting, since substituting other punctuation (e.g. a period) can change perceived tone.<sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">5</a></sup></p>
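<p>That grouping-and-pairing step can be sketched roughly as follows. The export schema here (each message dict having <code class="language-plaintext highlighter-rouge">sender_name</code> and <code class="language-plaintext highlighter-rouge">content</code> keys) is from memory, so treat the field names as assumptions:</p>

```python
def to_pairs(messages, me="James Chen"):
    """Turn an ordered message list into prompt/response training pairs.

    Illustrative sketch; assumes each message dict has "sender_name"
    and "content" keys, as in Facebook's JSON export at the time.
    """
    # Group contiguous messages into runs by whether I sent them.
    runs = []
    for m in messages:
        mine = m["sender_name"] == me
        if runs and runs[-1][0] == mine:
            runs[-1][1].append(m["content"])
        else:
            runs.append((mine, [m["content"]]))
    # Pair each run from other people with my run that follows it,
    # joining messages with a pipe to preserve message boundaries.
    for (a_mine, a), (b_mine, b) in zip(runs, runs[1:]):
        if not a_mine and b_mine:
            yield {"prompt": "|".join(a), "response": "|".join(b)}
```

<p>A thin wrapper around a generator like this is what gets handed to <code class="language-plaintext highlighter-rouge">from_generator</code>.</p>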

<h2 id="model">Model</h2>

<p>On Hugging Face, there are many language models appropriate for conversational interaction. I chose to run things using a <a href="https://huggingface.co/facebook/blenderbot-400M-distill">400M-parameter distilled BlenderBot</a>, which uses a standard seq2seq (i.e. encoder-decoder) transformer. The <a href="https://arxiv.org/abs/2004.13637">paper</a> came out in 2020, which is comparatively old, but these models are convenient since they’ve already been fine-tuned on conversational prompts. In particular, the 400M-parameter model isn’t so large that one needs to start thinking about model parallelism yet, and it has the added benefit of knowledge distillation from the bigger versions. In other words, it ought to be fine for some quick, just-for-fun hacking.</p>

<h2 id="training">Training</h2>

<p>I used a bone-stock <a href="https://www.pytorchlightning.ai/index.html">PyTorch Lightning</a> training loop, which is almost a one-liner. Of course one can choose to implement their own checkpointing, looping over epochs, etc., but why bother reinventing the wheel, especially for a one-off just-for-fun training run?</p>

<h3 id="compute">Compute</h3>

<p>As far as compute goes, my 2019-era MacBook is woefully underpowered, but Colab Pro is cheap and good enough here. An instance with 50GB of RAM and a GPU with 16GB of VRAM, albeit an old one, was plenty for my purposes and cost roughly $0.20/hr.<sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">6</a></sup></p>
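<p>For the curious, the back-of-the-envelope arithmetic behind that figure, using the credit numbers from the footnote:</p>

```python
# 100 compute credits cost $10, and a high-RAM T4 instance ran
# around 2 credits per hour (rough; Colab doesn't publish exact rates).
dollars_per_credit = 10 / 100
credits_per_hour = 2.0
cost_per_hour = dollars_per_credit * credits_per_hour  # ≈ $0.20/hr
```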

<h3 id="parameter-efficiency">Parameter efficiency</h3>

<p>I didn’t want to fine-tune the entire model, since that would’ve taken a while with the admittedly slower NVIDIA T4 that I was using. However, the <a href="https://huggingface.co/blog/peft">peft</a> library makes it surprisingly easy to leverage SOTA fine-tuning methods. For me, using <a href="https://arxiv.org/abs/2303.10512">AdaLoRA</a>, an improvement on low-rank adaptation that was published only in March, was three operations: an import, a config initialization, and then an assignment.</p>

<p>One quirk of the library is that it has a hard-coded mapping from transformer model architectures (as strings!) to the modules that should actually be adapted via AdaLoRA; see the source <a href="https://github.com/huggingface/peft/blob/86290e9660d24ef0d0cedcf57710da249dd1f2f4/src/peft/utils/other.py#L246C1-L246C54">here</a>.<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">7</a></sup> These mappings don’t work out of the box for BlenderBot, but you can just inspect the module names<sup id="fnref:tin" role="doc-noteref"><a href="#fn:tin" class="footnote" rel="footnote">8</a></sup> and then it’s no problem:</p>

<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Imitator</span><span class="p">(</span><span class="n">pl</span><span class="p">.</span><span class="n">LightningModule</span><span class="p">):</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="n">pretrained_model</span> <span class="o">=</span> <span class="n">BlenderbotForConditionalGeneration</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"facebook/blenderbot-400M-distill"</span><span class="p">)</span>
        <span class="n">peft_config</span> <span class="o">=</span> <span class="n">AdaLoraConfig</span><span class="p">(</span>
            <span class="n">peft_type</span><span class="o">=</span><span class="s">"ADALORA"</span><span class="p">,</span>
            <span class="n">task_type</span><span class="o">=</span><span class="n">TaskType</span><span class="p">.</span><span class="n">SEQ_2_SEQ_LM</span><span class="p">,</span>
            <span class="n">inference_mode</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
            <span class="n">r</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
            <span class="n">lora_alpha</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
            <span class="n">target_modules</span><span class="o">=</span><span class="p">[</span><span class="s">"q_proj"</span><span class="p">,</span> <span class="s">"k_proj"</span><span class="p">,</span> <span class="s">"v_proj"</span><span class="p">,</span> <span class="s">"out_proj"</span><span class="p">,</span> <span class="s">"fc1"</span><span class="p">,</span> <span class="s">"fc2"</span><span class="p">],</span>
            <span class="n">lora_dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">get_peft_model</span><span class="p">(</span><span class="n">pretrained_model</span><span class="p">,</span> <span class="n">peft_config</span><span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>

<h2 id="amusing-non-scientific-results">Amusing non-scientific results</h2>

<h3 id="example-1">Example 1</h3>

<h4 id="prompt">Prompt</h4>

<pre><code class="language-txt">&gt; dude i need your advice
&gt; she said she wanted to "take a break"
</code></pre>

<h4 id="jamesbot-response">JamesBot response</h4>

<pre><code class="language-txt">&gt; LOL
&gt; i'm not sure what to say to that
&gt; i don't really know what advice to give
</code></pre>

<p>Alright, perhaps a bit unempathetic and not a very specific response, but it definitely talks more like me. In particular, the model’s learned that I tend to split thoughts into distinct messages, often without ending punctuation.</p>

<h3 id="example-2">Example 2</h3>

<h4 id="prompt-1">Prompt</h4>

<pre><code class="language-txt">&gt; dude
&gt; gels are such a game changer
&gt; felt amazing at mile 15 today
</code></pre>

<h4 id="jamesbot-response-1">JamesBot response</h4>

<pre><code class="language-txt">&gt; oh wow
&gt; that's a lot of running
&gt; i'm impressed
&gt; i can't run a mile
</code></pre>

<p>This isn’t perfectly accurate since I can indeed run a mile and also advocate for gels<sup id="fnref:gels" role="doc-noteref"><a href="#fn:gels" class="footnote" rel="footnote">9</a></sup> on runs longer than 90 minutes, but this is plausibly something I would’ve said before getting into running.</p>

<h2 id="the-punchline">The punchline</h2>

<p>The punchline isn’t that I took a pre-trained model and then fine-tuned it on another dataset; after all, that’s been done before.</p>

<p>The punchline is that, not including the script I used to parse my Facebook messages, this was only ~50 lines of code. Code that, by my assessment, is rather explicit and not reminiscent of code golf.</p>

<p>It’s incredible that the open source ecosystem has advanced to the stage where you can experiment with very modern techniques (transformers, parameter-efficient fine-tuning, etc.) in just O(dozens) of LoC.</p>

<p>This lowers the barrier of entry for not only hobbyists and enthusiasts, but also for professionals who have requirements that aren’t met by the existing ecosystem of model inference APIs, or simply prefer driving stick.</p>

<h2 id="notes">Notes</h2>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>What’s even more fascinating than this research outright, is just how much other research continues to be done on top of bone stock transformers, even pretty hype research. E.g., DeepMind’s <a href="https://www.deepmind.com/publications/a-generalist-agent">Gato</a> was trained over a decoder-only transformer “for simplicity and scalability”. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>What’s also interesting here is that both of these seem to be in a legal gray area: LLaMA weights were leaked, and OpenAI’s terms of service prohibit its products from being used to train competitor models. Despite that, both are incredibly popular approaches. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>This is perhaps biased by my experience, which has been mostly with heavy duty infrastructure that originated from within Alphabet. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:reddit" role="doc-endnote">
      <p>Some may not consider Reddit comments to be “high quality”, but it’s important to compare them to internet text <em>en masse</em>. Seriously, take a look at some examples from these canonical large web scrapes. There’s an amusing amount of SEO spam and websites that just aren’t even parsed correctly, e.g. “This website requires JavaScript …”. <a href="#fnref:reddit" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:4" role="doc-endnote">
      <p>Here’s a <a href="https://www.nytimes.com/2021/06/29/crosswords/texting-punctuation-period.html">fun article</a> from the New York Times discussing this in more detail. <a href="#fnref:4" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:5" role="doc-endnote">
      <p>This is an estimate, since the Colab pricing model depends on “compute credits” per hour, and it’s not entirely clear how those rates are calculated. Regardless, you can get 100 credits per month for $10, and high-ram instances with an NVIDIA T4 were consistently around 2.xx credits per hour. <a href="#fnref:5" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:6" role="doc-endnote">
      <p>Quite a few of these are also commented out, for reasons that aren’t super clear to me. Perhaps it’s the library maintainers being strict about tests? <a href="#fnref:6" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:tin" role="doc-endnote">
      <p>This is easy, via <a href="https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.modules"><code class="language-plaintext highlighter-rouge">nn.Module.modules()</code></a>. It’s always nice when something does exactly what it says on the tin. <a href="#fnref:tin" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:gels" role="doc-endnote">
      <p>For the uninitiated, energy gels are portable and easy-to-digest carbs for endurance athletes. <a href="#fnref:gels" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>James Chen</name></author><category term="jekyll" /><category term="update" /><summary type="html"><![CDATA[There has been much well-deserved attention paid towards the latest advances in machine learning these days. I feel like I see a new paper or model every week that promises the Earth, moon, and stars.]]></summary></entry></feed>