Tl;dr

My favorite research findings are ones that make me go “Why didn’t I think of that?” due to a paradoxical combination of simplicity and cleverness. The OpenAI paper “Efficient Training of Language Models to Fill in the Middle” (FIM) is the most recent thing I’ve read that makes me feel this way.

Language modeling 101

For the uninitiated, modern language model (pre)training boils down to a very short list of steps:

  1. Download the internet.
  2. Choose a model architecture.
  3. Given an imperfect view of some tokens, predict a perfect view.

I’m being a bit facetious here because each step is complex, and also because this doesn’t discuss things like fine-tuning, inference, etc… But for pretraining, this is a reasonable approximation.

Some research focuses on step 1, e.g. how to construct the most useful training dataset, given the expected tradeoffs between quality and quantity. An enormous amount of research focuses on step 2, e.g. how to architect your model to improve its performance in any number of ways (memory efficiency, modeling distant relationships, etc.). And some research focuses on step 3, e.g. how to set up your training task: BART uses a denoising approach (the input text is corrupted and the model learns to reconstruct it), decoder-only transformers tend to do next-token prediction (i.e., each token is predicted from only the tokens before it), and so on.

“Efficient Training of Language Models to Fill in the Middle” tackles step 3. It maintains the standard “next token” autoregressive task of most decoder-only transformers, but with the simple twist that, for some fraction of training examples, the tokens have their order rearranged.

FIM in a nutshell

For simplicity, we’ll assume that we’re doing whitespace/punctuation tokenization.

Now consider the sentence “What I cannot create, I do not understand”. Under the typical autoregressive training setup, your input and target would look something like

Input: ["<bos>", "What", "I", "cannot", "create", ",", "I", "do", "not", "understand"]
Target: ["What", "I", "cannot", "create", ",", "I", "do", "not", "understand", "<eos>"]

Where <bos> and <eos> are special tokens representing the beginning and end of an input sequence. For FIM, the model keeps the same training and optimization mechanics, but the data is transformed a bit: the text is chunked into a prefix, middle, and suffix, the middle is moved to the end, and each chunk is marked with a corresponding special token:

Input: ["<pre>", "What", "I", "<suf>", "do", "not", "understand", "<mid>", "cannot", "create", ",", "I"]
Target: ["What", "I", "<suf>", "do", "not", "understand", "<mid>", "cannot", "create", ",", "I", "<eot>"]

This way, during inference time, you can prompt a model to fill in the middle very naturally by just providing context up to <mid>. And… that’s all folks.

Well, not exactly. There are other details and context which, of course, you can find in the paper. But the crux of the paper really is that simple!
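
If it helps to see the transformation mechanically, here’s a minimal Python sketch under the same toy assumptions as the example above (whitespace tokens and <pre>/<suf>/<mid>/<eot> sentinels). The function name, the fim_rate argument, and the exact sentinel strings are mine, not the paper’s; a real implementation operates on tokenizer output and applies the transform to only a fraction of documents.

import random

def apply_fim(tokens, fim_rate=0.5, rng=random):
    """Maybe rearrange a token list into prefix | suffix | middle order."""
    # Hypothetical sketch: names and sentinel strings are illustrative, not the paper's.
    if rng.random() > fim_rate or len(tokens) < 3:
        # Leave this document as ordinary left-to-right training data.
        return ["<bos>"] + tokens + ["<eos>"]

    # Pick two split points, yielding non-empty prefix / middle / suffix spans.
    i, j = sorted(rng.sample(range(1, len(tokens)), 2))
    prefix, middle, suffix = tokens[:i], tokens[i:j], tokens[j:]

    # Move the middle to the end, marked off by sentinel tokens, so the usual
    # next-token objective learns to generate the middle after seeing <mid>.
    return ["<pre>", *prefix, "<suf>", *suffix, "<mid>", *middle, "<eot>"]

tokens = ["What", "I", "cannot", "create", ",", "I", "do", "not", "understand"]
sequence = apply_fim(tokens, fim_rate=1.0)
inputs, targets = sequence[:-1], sequence[1:]  # the same shift-by-one as before

Everything downstream of this transform (batching, loss, optimizer) is untouched, which is a big part of the appeal.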

Putting my PM hat on

Chatbots are a natural extension of autoregressive models, and relatively convenient for the engineers implementing them. But not every product or task is most naturally represented as an extension of “predict the most probable next word”. Compared to an off-the-shelf autoregressively pretrained LM, filling in the middle seems like a more intuitive way to generate text for certain types of products, e.g.:1

  • A docstring, given a function definition and a function body (sketched in code after this list).
  • A blog article’s content, given its title and a conclusion paragraph.
  • A road trip itinerary, given that it starts in California and ends in New York City.
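
To make the first bullet concrete, here’s roughly what a fill-in-the-middle prompt for a docstring could look like. The sentinel strings again mirror the toy example above rather than any particular model’s actual special tokens, and generate() is a stand-in for whatever sampling call your model exposes.

# Hypothetical prompt assembly; the sentinels and generate() are placeholders.
prefix = 'def haversine(lat1, lon1, lat2, lon2):\n    """'
suffix = '"""\n    # ...compute the great-circle distance in km...\n    return distance'

prompt = f"<pre>{prefix}<suf>{suffix}<mid>"

# The model continues from <mid>, i.e. it writes the docstring text that
# belongs between the signature and the body, stopping once it emits <eot>.
# docstring = generate(model, prompt, stop="<eot>")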

Admittedly, there is a grab bag of overlapping techniques to address the problem of alignment. That is, there is ongoing research into turning something that predicts the most probable next token into something that is useful. These techniques include standard supervised fine-tuning, instruction tuning, RLHF2, or simply prompt engineering.3

But there is no reason these techniques wouldn’t work with a FIM-pretrained model. To top it off, FIM seems to grant these new middle-filling capabilities without harming overall model performance, making it just about the closest thing to a free lunch I’ve seen in a while.

Notes

  1. To be fair, it seems that filling in the middle may be a fundamentally more challenging problem than left-to-right (i.e. typical) sampling, perhaps because generated text needs to flow naturally with both a prefix and suffix, as opposed to just a prefix. 

  2. RLHF comes with its own grab bag of challenges, such as detecting/preventing over-optimization of the reward model, the high VRAM requirements of actor-critic policy gradient methods, and the overarching sample inefficiency of basically all RL techniques. 

  3. “Prompt engineering” always felt like kind of a bizarre phrase to me. “Engineering” ideally implies some sort of systematic understanding and predictability, but prompt engineering seems more like an art at best and voodoo magic at worst.