Introduction

In “ML at Scale: Data Parallelism”, we learned how to run many copies of a model in parallel. But what do we do once models get so big that they can’t fit on one accelerator anymore? We start to split them across multiple machines.

In “ML at Scale: Tensor Parallelism”, we’ll learn how to shard a model through tensor parallelism. In this post, we’ll talk about pipeline parallelism: how to slice a model into sequential sections and then “pipeline” data through them.

Conceptual Diagram

Starting Simple: Naive Pipeline Parallelism

Imagine no data parallelism; we have one batch that we’re sending through one model copy. But the model is so big that we have to split it sequentially across four machines in a pipeline parallel fashion. So machine_0 has the first quarter of the model, machine_1 has the next quarter, and so on.

How would this work?

First, all machines would need to perform the forward pass in sequence, since we need the output of the last model chunk to be able to calculate a loss.

Once we have the loss, we need to calculate gradients. For this, each chunk of the model needs its own activations and the gradients flowing back from the stage after it. Therefore, gradients are calculated from the last chunk all the way back to the first chunk.
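To make this concrete, here is a minimal single-process sketch of the naive scheme. The four “machines” are just loop iterations, each model chunk is a single scalar weight, and `forward_chunk` / `backward_chunk` are hypothetical stand-ins for the real sharded computation and cross-machine communication.

```python
import numpy as np

# Toy model: each "machine" owns one chunk, here a single scalar weight
# applied elementwise (a stand-in for a real shard of layers).
weights = [np.float64(w) for w in (0.5, 1.5, 2.0, 0.8)]  # one per machine

def forward_chunk(k, x):
    """Forward pass of machine k's model chunk."""
    return weights[k] * x

def backward_chunk(k, x_in, grad_out):
    """Backward pass of machine k's chunk: returns (grad wrt weight, grad wrt input)."""
    grad_w = np.sum(grad_out * x_in)   # dL/dw_k
    grad_in = grad_out * weights[k]    # dL/dx, passed back to the previous machine
    return grad_w, grad_in

batch = np.array([1.0, 2.0, 3.0, 4.0])
target = np.array([2.0, 4.0, 6.0, 8.0])

# Forward: machine_0 -> machine_1 -> machine_2 -> machine_3, strictly in order.
activations = [batch]
for k in range(4):
    activations.append(forward_chunk(k, activations[-1]))

# Loss is computed on the last machine (mean squared error here).
loss = np.mean((activations[-1] - target) ** 2)
grad = 2 * (activations[-1] - target) / len(batch)  # dL/d(output)

# Backward: machine_3 -> machine_2 -> machine_1 -> machine_0.
grads_w = [None] * 4
for k in reversed(range(4)):
    grads_w[k], grad = backward_chunk(k, activations[k], grad)

print(loss, grads_w)
```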

Let’s visualize this with a diagram.

Pipeline Parallelism Visualization Flow

As displayed in GPipe [1], here is another way of visualizing it:

GPipe Demonstration GPipe Time Series Demo

The Problem with Naive Pipeline Parallelism

Think about this: machine_1, machine_2, and machine_3 sit completely idle while machine_0 executes its forward or backward pass. This is very bad! Given four machines, you want them to be performing meaningful work as often as possible.

So how can we overlap work between the machines? In the current regime, we can’t: each time step depends on the output of the previous time step. Instead, we have to change the setup by breaking our task down into smaller chunks.

True Pipelining with Microbatches

Motivation behind Microbatches

A minibatch is the section of a batch given to one data parallel machine. With data parallelism = 1, minibatch and batch are the same.

Imagine you have a minibatch of size 4. This means the model performs the forward and backward pass on all four datapoints as one batch. At the end, we average the resulting gradients with all the other data parallel ranks (when data parallelism = 1, there are no other ranks) and then perform an optimizer step.

What if, then, we broke this minibatch into microbatches? Instead of sending in the data as one batch, we send it through one piece at a time and average the gradients together at the end. Mathematically, this is no different from batching the datapoints together, but there is a critical advantage: the forward and backward passes of microbatch 1 do not depend on the forward and backward passes of microbatch 2. So if we figure out how to overlap these across our four machines, we can get performance gains!
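To see that microbatching doesn’t change the math, here is a small numpy sketch: a toy linear model with a summed squared error loss (with a mean loss you would rescale by the microbatch sizes), comparing the full-minibatch gradient against gradients accumulated one datapoint at a time.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=3)                 # toy linear model
X = rng.normal(size=(4, 3))            # minibatch of 4 datapoints
y = rng.normal(size=4)

def grad(w, X, y):
    """Gradient of the summed squared error sum_i (x_i . w - y_i)^2 wrt w."""
    residual = X @ w - y
    return 2 * X.T @ residual

# One shot: forward/backward on the whole minibatch at once.
g_full = grad(w, X, y)

# Microbatched: one datapoint at a time, accumulating gradients.
g_accum = np.zeros_like(w)
for i in range(4):
    g_accum += grad(w, X[i:i+1], y[i:i+1])

print(np.allclose(g_full, g_accum))  # True: the two are mathematically identical
```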

Pipelining Microbatches

So how can we pipeline microbatches? Let’s think it through from first principles and see if a pattern emerges.

First, constraints. All datapoints must move from machine_0 to machine_1 to machine_2 to machine_3. This is because model_chunk_0 is stored on machine_0, and we need the output of model_chunk_0 to run model_chunk_1 on machine_1, and so on.

With these constraints, let’s start simple:

  • First, datapoint_0 goes through machine_0. Nothing else can happen at this time.
  • Then model_chunk_0’s output of datapoint_0 goes through machine_1.
    • Now machine_0 is idle, so we can put datapoint_1 through machine_0.
  • Next, model_chunk_1’s output of datapoint_0 goes through machine_2.
    • Now machine_1 is idle, so we can put model_chunk_0’s output of datapoint_1 through machine_1.
    • Now machine_0 is idle, so we start moving datapoint_2 through it.

Do you see the pattern? Here is a helpful visualization:

Microbatch Movement

Over the entire forward pass, it looks something like this (from GPipe):

GPipe Forward Pass

Now that all the forward passes are complete, we need to do the same for the backward passes, but this time we start from the last machine and work our way backwards. Since all the forward passes have concluded, we could start from any of the four datapoints’ losses, but to stay in line with the literature, let’s say we start the backward pass on the last datapoint.

The entire flow then looks something like this!

GPipe Full Bubble
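If you’d like to see the timetable without squinting at the figure, here is a small sketch that prints the same schedule. It assumes each forward and backward micro-step takes exactly one unit of time and ignores communication, so it is a toy scheduler for intuition, not a real one.

```python
K, M = 4, 4  # pipeline stages (machines) and microbatches

# schedule[t][k] = what machine k does at timestep t (None = idle, i.e. bubble)
total_steps = 2 * (M + K - 1)
schedule = [[None] * K for _ in range(total_steps)]

# Forward: microbatch m enters machine k at timestep m + k.
for m in range(M):
    for k in range(K):
        schedule[m + k][k] = f"F{m}"

# Backward: starts at timestep M + K - 1, right after the last forward
# finishes on the last machine; runs last machine -> first machine,
# last microbatch first.
start = M + K - 1
for m in reversed(range(M)):
    for k in reversed(range(K)):
        t = start + (M - 1 - m) + (K - 1 - k)
        schedule[t][k] = f"B{m}"

for t, row in enumerate(schedule):
    print(t, [cell or "." for cell in row])
```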

Pipeline Bubble

What is this “bubble” in the diagram above? Let me explain. Once the first datapoint is through all the forward pass steps, it is ready to start the backward pass. But the backward pass goes last machine → first machine, and the last machine cannot be free until all the forward passes are complete.

Then, we run the backward pass from the last datapoint to the first, which means that the first datapoint waits a long time before its backward pass can be fired.

Bubble Annotation Explained

This is known as the bubble in pipeline parallelism: the unavoidable cost we pay to keep all gradient updates synchronous and semantically identical to training without pipelining.

This is a design decision. PipeDream [2] employs asynchronous updates to get rid of the bubble problem. As we’ll see later, GPipe accepts the bubble as necessary, but shows it can be made negligible by choosing the pipeline depth and number of microbatches appropriately.

The Math Behind Gradient Descent

How precisely do we perform the backward pass for a pipelined execution with $K$ stages? With a batched single-device execution, we would get the loss per data point $L(x_i)$, sum them all up, perform backpropagation and get the updates to each of our weights.

In our case, only the last device knows the loss per microbatch, $L(b_i)$ (the sum of losses of all data points in that microbatch). However, this gets resolved nicely. All the last device needs to do is compute gradients with respect to its own portion of the weights and with respect to the input of its stage of the pipeline, $\frac{dL}{dF_K}$. The only information it then needs to pass back to the previous stage is precisely that input gradient!

Why? $F_K$ can be viewed as the output of the sub-model held by the previous pipeline stage, so we can independently run backpropagation for stage $K - 1$ using that gradient, and repeat the process until every stage has its updates!

The last remaining complexity in this approach is deciding when to actually apply the updates we have calculated. We only want to update the model once the entire minibatch is done, so we play our classic trick of accumulating gradients across all the microbatches until we have everything we need for one optimizer step.
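Here is a toy numpy sketch of this whole loop under the same assumptions as before: each stage is a single matrix, the “pipeline” runs in one process, and the only thing handed backward between stages is the gradient with respect to that stage’s input, with gradients accumulated across microbatches before a single optimizer step.

```python
import numpy as np

rng = np.random.default_rng(1)
K = 4                                   # pipeline stages
dims = [8, 8, 8, 8, 1]                  # toy stage widths; last stage outputs a scalar
stages = [rng.normal(scale=0.1, size=(dims[k], dims[k + 1])) for k in range(K)]
grad_accum = [np.zeros_like(W) for W in stages]

def stage_forward(k, x):
    return x @ stages[k]

def stage_backward(k, x_in, grad_out):
    """Given dL/d(output of stage k), return dL/dW_k and dL/d(input of stage k)."""
    grad_W = x_in.T @ grad_out          # gradient for this stage's weights
    grad_in = grad_out @ stages[k].T    # the only thing sent back to stage k-1
    return grad_W, grad_in

microbatches = [rng.normal(size=(2, dims[0])) for _ in range(4)]
targets = [rng.normal(size=(2, 1)) for _ in range(4)]

for x, y in zip(microbatches, targets):
    # Forward pass through the pipeline, remembering each stage's input.
    inputs = [x]
    for k in range(K):
        inputs.append(stage_forward(k, inputs[-1]))
    # Only the last stage sees the loss; it produces dL/dF_K.
    grad = 2 * (inputs[-1] - y)         # gradient of a summed squared error
    # Backward pass, stage K-1 down to stage 0.
    for k in reversed(range(K)):
        gW, grad = stage_backward(k, inputs[k], grad)
        grad_accum[k] += gW             # gradient accumulation across microbatches

# One optimizer step after all microbatches are done (plain SGD here).
lr = 1e-2
for k in range(K):
    stages[k] -= lr * grad_accum[k]
```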

Rematerialization

There is one caveat to what we’ve described so far. In order to calculate the gradients for each microbatch, the device has to store the activations that were generated from that particular microbatch.

In a setting without pipelining, where we save everything, we would need to store 1) the entire minibatch worth of data itself, 2) every activation for every layer for every minibatch entry, and 3) a gradient for every layer. Let the minibatch size be $N$ and the total number of layers in the model be $L$. Then we’re storing $O(N + N * L + L) = O(N*L)$ information.

Even with pipeline parallelism across machines, this can become costly. Imagine the first device in the pipeline: it’s the first to begin generating activations and the last to receive the backward gradients, so depending on how many microbatches you have and the pipeline length, it could have to store many sets of activations. To be more precise, we’re talking roughly $O(\frac{L}{K} * N)$ activations. Is there a way to cut down on this?

Yup, otherwise we wouldn’t be talking about it this way. Here’s a crazy thought: the only information we truly need to keep is the input activation to the stage. If, after running the stage’s layers to get the output we pass to the next stage, we just discarded all the intermediate activations, we would save memory, and we could keep the initial activation so that we can recalculate them later when we need them.

Let’s get mathy. Suppose that, for a given stage $k$, the input activation of the $i$th microbatch to that stage is $F_{i,k}$. We calculate the stage’s output $F_{i,k+1}$ using our shard of the model, pass it on to the next stage, and delete everything except for $F_{i,k}$. Then, when we receive $\frac{dL}{dF_{i,k+1}}$, we recalculate every activation between $F_{i,k}$ and $F_{i,k+1}$, perform the backward pass, and then properly delete all of the activations associated with microbatch $i$.

Thus, each stage only needs to store one input activation per microbatch, plus the fully rematerialized activations of the single microbatch it is currently running backward on, bringing our memory complexity down to $O(N + \frac{N}{M} * \frac{L}{K})$. The additive $N$ is because we still keep one boundary activation per data point across the whole minibatch. The memory savings can be significant, but they come at the cost of recomputing the stage’s forward pass during each microbatch’s backward pass.
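A minimal sketch of this idea for a single stage, again in plain numpy with toy linear layers: only the stage-input activation is kept per microbatch, and the intermediates are rematerialized when the gradient arrives from the next stage.

```python
import numpy as np

rng = np.random.default_rng(2)
# One pipeline stage that owns L/K = 3 toy linear layers.
layers = [rng.normal(scale=0.1, size=(8, 8)) for _ in range(3)]

def stage_forward(x, keep_intermediates):
    """Run this stage's layers, optionally keeping every intermediate activation."""
    acts = [x]
    for W in layers:
        acts.append(acts[-1] @ W)
    return acts if keep_intermediates else acts[-1]

saved_inputs = {}  # microbatch id -> stage-input activation F_{i,k} (only thing kept)

def forward_microbatch(i, x):
    saved_inputs[i] = x                                   # remember only the stage input
    return stage_forward(x, keep_intermediates=False)     # intermediates are discarded

def backward_microbatch(i, grad_out):
    """Called when dL/dF_{i,k+1} arrives from the next stage."""
    acts = stage_forward(saved_inputs.pop(i), keep_intermediates=True)  # rematerialize
    grads_W, grad = [], grad_out
    for W, x_in in zip(reversed(layers), reversed(acts[:-1])):
        grads_W.append(x_in.T @ grad)     # dL/dW for this layer
        grad = grad @ W.T                 # dL/d(input), flows to the layer before
    return list(reversed(grads_W)), grad  # grad is what gets passed back to stage k-1

# Usage: forward two microbatches, then backward as their gradients arrive.
out0 = forward_microbatch(0, rng.normal(size=(2, 8)))
out1 = forward_microbatch(1, rng.normal(size=(2, 8)))
grads1, gin1 = backward_microbatch(1, np.ones_like(out1))
grads0, gin0 = backward_microbatch(0, np.ones_like(out0))
```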

How to Divide Layers into Stages?

Typically, when dividing a model, we use a very simple “chop it into equal pieces” strategy. However, not all layers are created equal! Splitting a model into equal-sized chunks might not be optimal for performance. Pipeline parallelism libraries let you specify how many layers land on which accelerator, which helps design around this issue.

However, there is still a lack of literature on choosing the optimal split across pipeline stages. Works like Alpa [4] investigate this problem further, but future work could focus on developing closed-form solutions and estimators for computing better distributions of models across machines. On large-scale training runs, this could save millions of dollars!
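As a toy illustration (not how any particular library does it), here is a tiny heuristic that assigns layers to stages by balancing estimated per-layer costs rather than layer counts; the costs themselves are made up.

```python
def split_layers(costs, num_stages):
    """Assign each layer to a stage based on where its cumulative cost falls,
    so every stage ends up with roughly an equal share of the total cost.
    A toy heuristic for illustration only."""
    total = sum(costs)
    stages = [[] for _ in range(num_stages)]
    running = 0.0
    for i, c in enumerate(costs):
        # Use the midpoint of this layer's cost interval to pick its stage.
        stage = min(int((running + c / 2) / total * num_stages), num_stages - 1)
        stages[stage].append(i)
        running += c
    return stages

# Example: 8 layers where the middle ones are much more expensive.
layer_costs = [1, 1, 4, 4, 4, 4, 1, 1]
print(split_layers(layer_costs, num_stages=4))
# [[0, 1, 2], [3], [4], [5, 6, 7]]: per-stage costs 6/4/4/6,
# versus 2/8/8/2 for a naive "two layers per machine" split.
```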

Pipelining Performance Gains

Space Complexity Optimization

As we discussed previously, without pipeline parallelism, for a minibatch of $N$ and a model of $L$ layers, the machine would be storing $O(N + N * L + L) = O(N*L)$ information.

Without rematerialization, we’re using $O(N + N * \frac{L}{K})$ per machine. With rematerialization, we only store the input activations to each model chunk, discard the rest, and rematerialize them as necessary.

This means that, for a given machine at a given time, we only have in memory the activations of the microbatch, not the minibatch! Let the number of microbatches be $M$, which means the size of each of the microbatches would be $N/M$.

Thus, the space complexity with rematerialization of activations is $O(N + \frac{N}{M} * \frac{L}{K})$.
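To get a feel for the difference, here is a rough back-of-the-envelope comparison that treats these big-O expressions as exact unit counts; the sizes are hypothetical, and real numbers depend on layer widths and dtypes.

```python
# Back-of-the-envelope activation counts (treating the big-O bounds as exact
# unit counts, purely for intuition).
N, L, K, M = 128, 48, 4, 16   # minibatch size, layers, pipeline stages, microbatches

single_device   = N * L                     # every activation for every datapoint
pipelined       = N + N * (L // K)          # one stage's layers, whole minibatch
pipelined_remat = N + (N // M) * (L // K)   # one stage's layers, one microbatch

print(single_device, pipelined, pipelined_remat)  # 6144, 1664, 224
```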

Size of the Bubble

As discussed before, with pipeline parallelism we have bubbles where the devices are just idling. The devices hit hardest by this are the first and the last.

Suppose there are $K$ stages in the pipeline. How long do the first and last devices wait? The last device idles for $K - 1$ stages before receiving its first activation, and after processing its last backward pass, waits $K - 1$ stages for the gradients to propagate through the previous stages before it can do its weight update. The first device, after finishing the forward pass of the last microbatch, waits $K - 1$ stages for that microbatch to reach the last device, and then another $K - 1$ stages for the first backward gradient to propagate back to it. Thus, symmetrically, both the first and last devices are idle for $2(K - 1)$ steps, or $O(K)$.

How does this compare to the overall runtime of the pipeline? In general, calculating the critical path of a pipeline is challenging, but for the simple pipelines we’ve been discussing we can calculate it handily. Let’s break the pipeline down into six periods:

  1. A low utilization period in which the pipeline is getting filled. Ends when the last device gets its first activation.
  2. A high utilization period in which every device is doing useful work. Ends when the first device finishes the last forward pass.
  3. A low utilization period that ends when the last device finishes the last forward pass.
  4. A low utilization period that ends when the first device finishes the first backward pass.
  5. A high utilization period that ends when the last device finishes the last backward pass.
  6. A low utilization period that ends when the first device finishes the last backward pass.

Forward and Backward with Regions Shaded

Notice a pattern? We basically have two pipelines: one for the forward pass (the first three periods) and one for the backward pass (the last three). Thus, the analysis comes down to how long it takes to fill and drain the pipeline, and how long it spends in the high utilization period. As we’ve discussed, filling and draining each take $K - 1$ steps. The high utilization period runs from the moment the last device gets its first activation until the first device finishes the forward pass of the last microbatch, so it lasts $M - K + 1$ steps.

The six periods then take:

  1. $K - 1$
  2. $M - K + 1$
  3. $K - 1$
  4. $K - 1$
  5. $M - K + 1$
  6. $K - 1$

Forward and Backward with Regions Shaded and Annotated

Summing these gives $4(K - 1) + 2(M - K + 1) = 2(M + K - 1)$, so the overall running time is $O(M + K)$. How does this compare to the bubble? The trick here is that $K$ is limited by the number of devices you have, so if you can increase $M$, the fraction of time spent in the bubble, i.e. $O(\frac{K}{M + K})$, decreases! In practice, this is extremely challenging to do, because your data is typically split across several data parallel pipelines, and we want to keep the global batch size small to keep learning efficiency high.

Empirically, GPipe [1] finds that making the number of microbatches at least 4x the number of pipeline stages makes the bubble negligible (since the green section of the visual becomes much longer than the red sections).
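You can sanity-check this rule of thumb by plugging numbers into the bubble fraction, again assuming every forward and backward micro-step takes one unit of time.

```python
def bubble_fraction(K, M):
    """Fraction of timesteps a device spends idle, assuming each forward and
    backward micro-step takes one unit of time: the whole schedule takes
    2(M + K - 1) steps and each device idles for 2(K - 1) of them."""
    return 2 * (K - 1) / (2 * (M + K - 1))

for M in (4, 8, 16, 32):
    print(f"K=4, M={M}: bubble = {bubble_fraction(4, M):.0%}")
# K=4, M=4:  43%
# K=4, M=8:  27%
# K=4, M=16: 16%   (the 4x rule of thumb)
# K=4, M=32: 9%
```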

Conclusion

Almost every model you’ve seen from a large model training company (Google, OpenAI, Anthropic, Mistral) uses 3D parallelism: data parallelism, tensor parallelism, and pipeline parallelism all together. This allows them to train very large models as fast as possible, and now you know how to do this as well!

References

[1] GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism

[2] PipeDream: Generalized Pipeline Parallelism for DNN Training

[3] XLNet: Generalized Autoregressive Pretraining for Language Understanding

[4] Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning

[5] TeraPipe: Token-Level Pipeline Parallelism for Training Large-Scale Language Models