Introduction
You might have heard the news of AI companies acquiring increasingly large clusters of GPUs to train models: Meta announced that they have 600K GPUs, Sam Altman is reportedly trying to raise $7T to build an NVIDIA competitor, and NVIDIA’s market capitalization keeps surpassing the GDP of entire countries.
But how are these companies using all these GPUs? How are large models with billions of parameters trained in a distributed fashion? There are two families of parallelism techniques that most modern training setups use to train these models at scale: split your data across machines (data parallelism) or split your model across machines (model parallelism).
Today, we dive deep into how data parallelism works.
Imagine: 1 GPU…
First, let’s all agree on how single-GPU training works.
**Step 1: Input data.** This is some vector of integers $x$.
**Step 2: Forward pass.** $x$ is put through the model, resulting in some output $\hat{y}$. This is the forward pass $f(x) = \hat{y}$.
**Step 3: Loss calculation.** We have some ground truth label $y$ and take the loss between them: $L(y, \hat{y})$.
**Step 4: Backward pass.** We calculate the derivative of the loss with respect to all of the model weights: $g(w) = \frac{dL}{dw}\ \forall w$.
**Step 5: Optimizer step.** Finally, we update the weights by the value of the gradient times some constant factor $c$ (the learning rate): $w_{new} = w - c \cdot g(w)$. These are the new (and hopefully better!) weights of our model.
Repeat until GPT. That’s it!
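To make this concrete, here is a minimal sketch of this loop in PyTorch; the model, data, and learning rate are placeholders chosen purely for illustration.

```python
import torch
import torch.nn as nn

# Placeholder model, data, and learning rate, purely for illustration.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # lr plays the role of c

for _ in range(100):
    x = torch.randn(10)          # Step 1: input data
    y = torch.randn(1)           # ground truth label
    y_hat = model(x)             # Step 2: forward pass, f(x) = y_hat
    loss = loss_fn(y_hat, y)     # Step 3: loss calculation, L(y, y_hat)
    optimizer.zero_grad()
    loss.backward()              # Step 4: backward pass, computes dL/dw for every w
    optimizer.step()             # Step 5: optimizer step, w_new = w - c * g(w)
```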
Scaling Up with Data
As far as possible, we want to scale vertically, making our single-GPU setup chunkier by increasing the GPU’s memory, power, etc. But this alone is not a good solution, because:
- How chunky we can make machines is constrained by physical limits.
- Chunky machines are extremely expensive.
- Chunky machines would be a single point of failure for large-scale model training.
So when we can no longer scale vertically, we need to figure out how to scale horizontally. In all the parallelism strategies we will investigate, the goal is to shard something across multiple machines, either the data sequences or the model weights. We need to do this in a way that preserves ML semantics; it makes no sense to use many more resources only to end up with a much worse model.
The idea behind data parallelism is this: if we copy our neural network onto $N$ GPUs, assign each of them a portion of the data to work on, and then coalesce all $N$ models together somehow, we would be processing data at a rate $N$ times greater than before!
How Data Parallelism Training Works
Here’s how one loop of this process happens:
- Each machine chooses 1 datapoint and performs the forward and backward pass, calculating the gradients for its datapoint.
- At the end of the backward step, the machines communicate their gradients to each other and average all $N$ gradients.
- Everyone conducts the optimizer step using these averaged gradients.
This means:
- Each of the data parallel workers gets to contribute equally to the overall gradient.
- The model weights are never out of sync, since the optimizer step is identical on all machines.
- In effect, we have trained on $N$ data points before running the optimizer step. Thus, we are mathematically equivalent to a model trained on a single machine with a batch size of $N$.
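As a rough sketch, one such step could be written with torch.distributed as below; this assumes the process group has already been initialized and that each worker already holds its own datapoint. The function name and arguments are illustrative, not a fixed API.

```python
import torch.distributed as dist

def data_parallel_step(model, loss_fn, optimizer, x, y):
    """One data-parallel step: local forward/backward, then gradient averaging."""
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()                            # gradients for this worker's datapoint

    world_size = dist.get_world_size()         # N
    for param in model.parameters():
        # Sum this parameter's gradient across all N workers, then average.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= world_size

    optimizer.step()                           # identical update on every worker
```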
In the Wild: Optimization Strategies for Data Parallelism
Data parallelism in concept is a simple and elegant idea with one elephant-sized problem: it makes communication cost the bottleneck of training. The description above, implemented naively, would slow training down to the point that you might be better off just training on the one GPU. We need to find techniques that ease this bottleneck. Below we’ve listed a few!
Batching
How Does Batching Work in a Single-GPU Regime?
In practice, you typically don’t do the optimizer step after each data point. Instead, you process a batch of $K$ data points, doing the forward and backward pass on each one, and accumulate an average gradient. Then, you do a single optimizer step from the average gradient. More concretely, you have a batch $\{x_1, \dots, x_K\}$, and then compute:
$\forall i$ in the batch:
- Forward Pass. $f(x_i) = \hat{y}_i$
- Loss Calculation. $L(y_i, \hat{y}_i)$
- Backward Pass. $g_i(w) = \frac{dL(y_i, \hat{y}_i)}{dw}\ \forall w$
Once the batch is processed:
- Gradient Averaging. $g(w) = \frac{1}{K}\sum_i g_i(w)$
- Optimizer Step. $w_{new} = w - c \cdot g(w)$
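Here is a sketch of that per-datapoint accumulation in PyTorch. In practice the whole batch is usually pushed through the model as one tensor, but this version mirrors the math above; the function and variable names are illustrative.

```python
def batched_step(model, loss_fn, optimizer, batch):
    """Accumulate gradients over K datapoints, then take one optimizer step."""
    optimizer.zero_grad()
    K = len(batch)
    for x_i, y_i in batch:
        loss = loss_fn(model(x_i), y_i)   # forward pass and loss for datapoint i
        (loss / K).backward()             # .backward() adds into .grad, so scaling
                                          # each loss by 1/K accumulates the average
    optimizer.step()                      # one update from the averaged gradient
```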
A common confusion: this is not necessarily going to lead to the same model as if we had just done the optimizer step after each data point. Hence, we say that a model trained with this scheme is not mathematically equivalent to a model that does the optimizer step after each data point. So why do this? The primary benefit is to save on computation, because we are not running the optimizer step as often.
Additional benefits and drawbacks of batching are an active area of research and remain complex to fully understand. For our purposes, we will focus on the commonly used mini-batch gradient descent strategy, where the batch size is greater than one but less than the size of the full dataset.
Batches in Data Parallel Training
Batching takes on an additional purpose for data parallel training. Suppose we have a per-device batch size of $M$. If each device computes $M$ gradients and averages them into one before synchronizing, then we have decreased our communication cost by a factor of $M$!
Note, however, that this does actually change the model we are training. If each device is training with a batch size of $M$, and we have $N$ servers, then we can define our global batch size as $N \cdot M$. This is the amount of data we train on globally per communication round, and it is equivalent to training a model on a single GPU with a batch size of $N \cdot M$.
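In PyTorch this is roughly what training with torch.nn.parallel.DistributedDataParallel looks like, with each rank processing a local batch of $M$ samples and gradients synchronized once per step rather than once per sample. Process-group setup is omitted, and the model, shapes, and hyperparameters are placeholders.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already been called on
# each of the N ranks, and that each rank is pinned to its own GPU.
model = DDP(nn.Linear(10, 1).cuda())      # DDP averages gradients across ranks
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

M = 32                                    # per-device batch size (illustrative)
x = torch.randn(M, 10).cuda()             # this rank's local batch
y = torch.randn(M, 1).cuda()

optimizer.zero_grad()
loss_fn(model(x), y).backward()           # one gradient sync per step, not per sample
optimizer.step()                          # effective global batch size is N * M
```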
Bucketing Gradients
The bottleneck in data parallel training is the communication cost of sharing and averaging the gradients before every step. For this reason, we want to maximize bandwidth utilization during gradient communication. If we were to send each gradient over as soon as we calculated it, the bandwidth would be underutilized. If we were to wait for all gradients to be calculated before communicating, we would waste precious time: we fail to utilize network bandwidth during computation, and waste compute power while communicating!
Thus, we utilize a strategy between the two extremes: create “buckets” of gradients that get communicated together. The buckets are created such that the parts of the model that go through the backward pass first are bucketed together, those that calculate their gradients next are bucketed together, and so on until the last bucket holds the last gradients to be calculated. This allows us to interleave the backward step and communication, saving a lot of time.
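Below is a conceptual sketch of bucketed, overlapped communication using per-parameter gradient hooks and asynchronous all-reduce. Real implementations (such as PyTorch’s DDP bucketing) are considerably more sophisticated; the bucket size and names are illustrative, and the final division by the number of workers is left to the caller.

```python
import torch.distributed as dist

BUCKET_SIZE = 25 * 1024 * 1024  # ~25 MB of gradients per bucket (illustrative)

def install_bucketed_allreduce(model):
    """Launch an async all-reduce whenever a bucket's worth of gradients is ready."""
    bucket, handles = [], []

    def flush():
        for grad in bucket:
            # async_op=True lets the rest of the backward pass overlap with comms.
            handles.append(dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True))
        bucket.clear()

    def on_grad_ready(grad):
        bucket.append(grad)
        if sum(g.numel() * g.element_size() for g in bucket) >= BUCKET_SIZE:
            flush()
        return grad

    # Parameters near the output finish their gradients first, so they fill the
    # earliest buckets and their communication starts while earlier layers are
    # still computing.
    for param in model.parameters():
        param.register_hook(on_grad_ready)

    # The caller must flush the final (partial) bucket and wait() on all handles
    # before running the optimizer step.
    return handles, flush
```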
Other Data Parallelism Strategies
Parameter Averaging
In this article, we discussed a very particular form of data parallelism, one that averages the gradients. However, there is an alternative strategy we could employ: parameter averaging! The core idea: train $N$ data parallel models on some subset of the data, then average their params together! Depending on how frequently you average your params, this can be even faster than gradient-based data parallelism.
So why don’t we do this? The intuition here is that the average of $N$ different local minima is not guaranteed to be a local minimum.
As mentioned previously, there is mathematical equivalence between gradient-based data parallelism and single-GPU batching. This is no longer true with param averaging. Since modern optimization strategies also maintain running statistics of previous gradients (such as momentum and second-moment estimates), the $N$ data parallel models can drift substantially away from each other. These ML semantics pitfalls make ML scientists wary of using parameter averaging.
Parameter Servers
What are parameter servers? The idea is that “Both data and workload are distributed into client nodes, while server nodes maintain globally shared parameters” [4]. This seems like a great idea, since we can have shared parameters and wouldn’t need to worry about mathematical equivalence anymore! The drawback of this approach, however, is that it makes communications even more of a bottleneck, since param servers need to talk to their workers as well as other param servers. Parameter servers have gone out of fashion simply because they could not keep up with the increasing scale of ML models.
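As a toy, single-process sketch of the push/pull pattern described in [4] (no real networking or sharding; the class and method names are made up for illustration):

```python
import torch

class ParameterServer:
    """Toy server node: holds the globally shared parameters."""
    def __init__(self, params, lr=0.01):
        self.params = {name: p.clone() for name, p in params.items()}
        self.lr = lr

    def push(self, grads):
        # A worker pushes its gradients; the server updates the shared weights.
        for name, g in grads.items():
            self.params[name] -= self.lr * g

    def pull(self):
        # A worker pulls the latest shared weights before its next computation.
        return {name: p.clone() for name, p in self.params.items()}

# A worker loop would be: weights = server.pull(); compute gradients locally on
# its shard of the data; server.push(grads); repeat.
```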
The Limits of Data Parallelism
With data parallelism, we have increased our data processing throughput. However, this scaling does not last forever: there are diminishing gains as one increases the number of GPUs, and the speedup saturates at a certain point. Furthermore, data parallelism does not let us scale the actual model size. In other words, our model is still restricted to fitting within the memory of a single GPU. For reference, a high-end H100 from NVIDIA has 80 GB of memory [5].
GPT-3, a model four years old at this point, has around 350 GB of weights, more than four times that! How can we possibly support efficient training of such large models? We will have to parallelize the model!
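As a rough back-of-the-envelope check (175B parameters for GPT-3, 2 bytes per parameter for fp16 weights, optimizer state ignored):

```python
params = 175e9               # GPT-3 parameter count
bytes_per_param = 2          # fp16 weights
weights_gb = params * bytes_per_param / 1e9
print(weights_gb)            # 350.0 -> ~350 GB of weights alone
print(weights_gb / 80)       # ~4.4 -> over four H100s' worth of memory
```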
References
[1] A Comprehensive Guide of Distributed Data Parallel (DDP)
[2] Data Parallel Distributed Training — Neural Network Libraries 1.39.0 documentation
[3] Intro Distributed Deep Learning
[4] Parameter Server for Distributed Machine Learning
[5] NVIDIA H100 Tensor Core GPU Datasheet