Introduction
Last time, we talked about data parallelism, a strategy that splits the training data across multiple machines for large model training. Today, we'll extend this idea and ask how we can split the model itself across multiple machines. We need to start doing this when our models get too big (billions of parameters do not fit on a single GPU).
To start with, we could think of a simple way to do so: split the model into two halves, place each half on its own machine, and treat the two machines as one unit. Problem solved!
What you have invented is known as pipeline parallelism. It has its own set of complications, discussed in its own post. Today, we’re talking about tensor parallelism. Very generally and imprecisely, the idea of tensor parallelism is something like this:
Put another way, if you imagine a language model as a series of steps $s_1, \dots, s_n$, pipeline parallelism puts $s_1, \dots, s_{i-1}$ on one machine and $s_i, \dots, s_n$ on another machine. Tensor parallelism, on the other hand, is about how we can split a single step $s_i$ onto multiple machines.
So let’s get into it!
Simple Tensor Sharding
Here we’ll discuss how you can shard a matrix for multiplication and parallelize it. Imagine a simple multiplication of the shape below:
Our core insight here is that each column of $Y$ depends only on $A$ and the corresponding column of $B$, and thus the columns can be computed in parallel.
Thus, we can create an equivalent multi-machine version of the matrix multiply that shards $B$ column-wise, calculates the individual $y_i$, and collects $Y$ at the end.
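To make this concrete, here's a tiny numpy sketch of the idea. The two "machines" are just entries in a Python list, and the shapes are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 6))          # a full copy lives on every machine
B = rng.normal(size=(6, 8))          # the matrix we shard

# Split B column-wise: machine i holds only B_i
B_shards = np.split(B, 2, axis=1)    # [B_1, B_2]

# Each machine independently computes its slice of Y
Y_shards = [A @ B_i for B_i in B_shards]

# Collect (concatenate) the slices to recover the full result
Y = np.concatenate(Y_shards, axis=1)
assert np.allclose(Y, A @ B)
```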
Here I'll remind readers of a bitter lesson we learned last time: communication cost is the boogeyman of large model training. If we were to do this (shard, compute, and collect) for every operation in a large language model, the tensor-parallel version of the model would be far slower than if we hadn't parallelized at all, due to all the communication overhead! Let's illustrate why by expanding this example.
Column-parallel Sharding + Row-parallel Sharding
Suppose now we have a similar computation to before, but with two sequential matrix multiplies, i.e. $Y = ABC$. If we follow the same trick and shard $B$ column-wise, we'll get screwed!
Each server has just one piece of $Y$, so how can we possibly do the next calculation? Here, we use a most amazing fact: notice that if we shard the $C$ matrix by its rows, the calculations can still be done independently!
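Here's the same toy sketch extended to $Y = ABC$ (again just simulating the machines with a list): $B$ is split by columns, $C$ by rows, and each machine produces a partial result that only needs to be summed once, at the very end:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 6))
B = rng.normal(size=(6, 8))
C = rng.normal(size=(8, 5))

B_shards = np.split(B, 2, axis=1)    # column-parallel: B = [B_1  B_2]
C_shards = np.split(C, 2, axis=0)    # row-parallel:    C = [C_1; C_2]

# Each machine computes (A @ B_i) @ C_i with no communication in between
partials = [(A @ B_i) @ C_i for B_i, C_i in zip(B_shards, C_shards)]

# A single sum (an all-reduce, in practice) recovers the full product
Y = sum(partials)
assert np.allclose(Y, A @ B @ C)
```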
Utilizing tensor parallelism well comes down to minimizing these synchronization points so that our GPUs can keep chugging along on the computation. In practice, this kind of optimization is what makes applying tensor parallelism tricky, and frequently the right thing to do is model-specific. So we have to be thoughtful about which parts of the LLM to apply tensor parallelism to, and how…
Tensor Parallelism in LLMs
In LLMs, there are two elements that we can parallelize: multi-head attention and the MLP.
Let’s discuss the MLP first. An MLP performs the following mathematical operation:
$$ Z = \text{Dropout}(B\,\text{ReLU}(AX)) $$
Now, you could take the initial sharding concept and claim that we should shard A and B. Good! That would be correct. But then what do we do about the ReLU and the Dropout? We could synchronize the activations before each of them. Something like this:
Two phases of sharding and un-sharding! Ah, but there's a slightly more optimal answer still: the ReLU operation is elementwise, so it does not need its input to be coalesced. Because of this, we can remove the sync in the middle and maintain our sharded values throughout, only syncing right before the dropout at the end!
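Here's a hedged numpy sketch of that idea. With the convention $Z = \text{Dropout}(B\,\text{ReLU}(AX))$ used above, keeping everything sharded between the two matrix multiplies amounts to splitting $A$ by rows and $B$ by columns (dropout is left out of the correctness check, since it would be applied after the sync):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 4))          # input activations
A = rng.normal(size=(64, 16))         # first MLP weight
B = rng.normal(size=(16, 64))         # second MLP weight

def relu(t):
    return np.maximum(t, 0.0)

# Shard A by rows and B by columns across two simulated machines
A_shards = np.split(A, 2, axis=0)     # A = [A_1; A_2]
B_shards = np.split(B, 2, axis=1)     # B = [B_1  B_2]

# ReLU is elementwise, so each machine applies it to its own shard:
# no sync is needed between the two matrix multiplies.
partials = [B_i @ relu(A_i @ X) for A_i, B_i in zip(A_shards, B_shards)]

# One sync (all-reduce) right before dropout recovers B @ ReLU(A @ X)
Z = sum(partials)                      # dropout would be applied here
assert np.allclose(Z, B @ relu(A @ X))
```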
A similar idea could be applied to self-attention. The attention operation looks something like this:
$$ \text{Self-Attention} = \text{Dropout}(\text{Softmax}(QK^T))V $$
You could imagine parallelizing the attention operation… but this would mean we would need to communicate before the softmax / dropout. This ends up being worse than if we hadn’t sharded the queries and keys at all.
So, what to do? Readers familiar with LLMs might remember that the attention typically used is multi-head attention. What does this mean? Multi-head attention means we calculate many self-attentions in parallel, then take a learned weighted combination of those attentions. That is to say, for multi-head attention
$$ Z = W \, [h_1, h_2, h_3, \dots] \quad \text{where } h_{i} = \text{Self-Attention}(Q_i, K_i, V_i) $$
such that all $Q_i, K_i, V_i$ are independently, randomly initialized. This is a form of ensemble learning within language models, and as you might start to realize, the heads are inherently parallel.
So all we need to do is shard the $W$ matrix such that we can calculate $z_1, z_2…$ independently!
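Here's a toy numpy version of that, with one head per simulated machine. Note I'm writing the output projection on the right of the concatenated heads (the transpose of the formula above, which is the more common convention in code), and I omit the dropout to keep it short. Each machine produces a partial $z_i$, and one sum at the end gives $Z$:

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d_head, d_model, n_heads = 8, 4, 8, 2        # toy sizes, one head per machine

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(Q, K, V):
    return softmax(Q @ K.T) @ V                   # dropout omitted for clarity

# Per-head projections: head i lives entirely on machine i
Q = [rng.normal(size=(seq, d_head)) for _ in range(n_heads)]
K = [rng.normal(size=(seq, d_head)) for _ in range(n_heads)]
V = [rng.normal(size=(seq, d_head)) for _ in range(n_heads)]

# Output projection W, sharded so machine i holds the rows matching head i
W = rng.normal(size=(n_heads * d_head, d_model))
W_shards = np.split(W, n_heads, axis=0)

# Each machine computes its head and its partial output, with no communication
partials = [self_attention(Q[i], K[i], V[i]) @ W_shards[i] for i in range(n_heads)]

# One all-reduce (sum) at the end produces the full Z
Z = sum(partials)
heads = np.concatenate([self_attention(Q[i], K[i], V[i]) for i in range(n_heads)], axis=1)
assert np.allclose(Z, heads @ W)
```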
This is how we end up parallelizing attention in the language model as well. You might be wondering now what the $f$s and $g$s are supposed to represent — how does the splitting and coalescing work?
Synchronization Strategies and Backprop
Now that we have a strategy, let us go back and fill in the details of how we can synchronize and perform the backpropagation, i.e. what the nebulous $f$ and $g$ are. We'll see that they need to have different behavior during the forward and backward passes.
Forward Pass
Let's think about what the synchronization functions $f$ and $g$ need to do in the forward pass. We want $f(x) = x$, which means that actually we don't need $f$ to do anything! There is no communication there.
We want a $g$ such that $g(Z_1, Z_2) = Z_1 + Z_2 = Z$, which means that we do want a communication, and in particular we want one that will sum the two matrices together. Here, we can have $g$ be the All-reduce collective communication operation. We won't get into the details of how it works (it depends on the implementation!), but All-reduce takes each server's version of the activations, sums (reduces) them together, and ensures that every server has a copy of the complete data by the end of it. Thus, it will sum the two $Z$ matrices together to get the final result.
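As a tiny sketch of the forward behavior (the servers and the All-reduce are simulated here; a real implementation would use a collective library): $f$ hands the same input to everyone with no communication, and $g$ sums the partial results so that every server ends up with the full $Z$:

```python
import numpy as np

def f_forward(x, n_workers=2):
    # f is the identity: no communication, every server sees the same full X
    return [x for _ in range(n_workers)]

def g_forward(partials):
    # g is an all-reduce: sum the partial results, every server gets the full Z
    total = sum(partials)
    return [total.copy() for _ in partials]

# Toy check: two servers hold partial Z_1 and Z_2; after g, both hold Z_1 + Z_2
Z_1, Z_2 = np.ones((2, 2)), 2 * np.ones((2, 2))
Z_on_each_server = g_forward([Z_1, Z_2])
assert all(np.allclose(z, Z_1 + Z_2) for z in Z_on_each_server)
```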
Backward Pass
Now let's see what $f$ and $g$ need to be while doing backpropagation. Suppose that we have the gradients of the loss with respect to the components of $Z$, i.e. we have $\frac{dL}{dZ_{mn}}$. The next step for the two servers is to respectively calculate the gradients of the loss with respect to $Z_1$ and $Z_2$. Take $Z_1$. The gradient for a particular element of this matrix will be $\frac{dL}{dZ_{1, ij}} = \sum_{m,n}\frac{dL}{dZ_{mn}} \cdot \frac{dZ_{mn}}{dZ_{1, ij}}$. Likewise, for $Z_2$ the gradients will be $\frac{dL}{dZ_{2, ij}} = \sum_{m,n}\frac{dL}{dZ_{mn}} \cdot \frac{dZ_{mn}}{dZ_{2, ij}}$. Since $Z = Z_1 + Z_2$, these are just $\frac{dL}{dZ_{ij}}$, so all we need to pass backwards is the gradient for $Z$, which means that the two machines don't actually need to synchronize here! Thus, $g$ is the identity function in the backward pass.
What about $f$? Before $f$, each server has its own version of $\frac{dL}{dX_{ij}}$. How do we resolve this? The key here is that each server is computing the gradient of the loss with respect to $X$ through only half of the overall computation, so we can just sum the two together. Here's the math behind it:
Let's consider just the gradient coming from the Value matrix for now. Server 1 calculates its contribution $\frac{dL}{dX_{ij}} = \sum_{m,n}\frac{dL}{dV_{1, mn}} \cdot \frac{dV_{1, mn}}{dX_{ij}}$. Likewise, server 2 computes $\frac{dL}{dX_{ij}} = \sum_{m,n}\frac{dL}{dV_{2, mn}} \cdot \frac{dV_{2, mn}}{dX_{ij}}$.
But remember, $V = [V_1\ V_2]$. If we had the entire $V$ matrix on one machine, what we would do is calculate the gradient of $X$ as $\sum_{m,n}\frac{dL}{dV_{mn}} \cdot \frac{dV_{mn}}{dX_{ij}}$. In other words, we would just add together the gradient contributions from every index of $V$. Therefore, we can simply sum the gradients of $X$ from both servers to get the gradient with respect to the whole thing. That lets us use the same communication operation as before, the All-reduce, to sum the gradients together and pass them to the previous layer.
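Here's a small numpy check of that claim, using a hypothetical value projection $W_v$ (so $V = X W_v$, column-sharded across the two servers) and a made-up upstream gradient: summing the two servers' partial $\frac{dL}{dX}$ recovers exactly the unsharded gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
X  = rng.normal(size=(4, 8))     # input activations
Wv = rng.normal(size=(8, 6))     # value projection, V = X @ Wv
G  = rng.normal(size=(4, 6))     # pretend upstream gradient dL/dV

# Unsharded reference: dL/dX = dL/dV @ Wv^T
grad_full = G @ Wv.T

# Column-shard Wv across two "servers": V = [V_1  V_2]
Wv_1, Wv_2 = np.split(Wv, 2, axis=1)
G_1,  G_2  = np.split(G,  2, axis=1)   # matching slices of dL/dV

# Each server computes its own partial dL/dX ...
grad_1 = G_1 @ Wv_1.T
grad_2 = G_2 @ Wv_2.T

# ... and an all-reduce (sum) recovers the full gradient of X
assert np.allclose(grad_1 + grad_2, grad_full)
```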
Data and Tensor Parallelism
Hierarchical Relationship
The typical strategy for combining data parallelism and tensor parallelism is to keep them as separate levels: we abstract each set of tensor-parallel ranks as one big server, which then acts as a single data-parallel rank. This is better understood through a visualization:
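Concretely, here's a sketch of the bookkeeping for 8 GPUs with tensor-parallel size 2: every pair of consecutive ranks forms one tensor-parallel group, and data parallelism runs across those groups. (The sizes and the rank-numbering convention are just assumptions for illustration; frameworks differ on the details.)

```python
# 8 GPUs, tensor-parallel size 2, data-parallel size 4
world_size, tp_size = 8, 2
dp_size = world_size // tp_size

# Each tensor-parallel group acts as "one big server" ...
tp_groups = [list(range(r, r + tp_size)) for r in range(0, world_size, tp_size)]
# ... and data parallelism runs across the groups (one rank from each group).
dp_groups = [list(range(i, world_size, tp_size)) for i in range(tp_size)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(dp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```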
Non-Hierarchical
We don't have to make the lines between the data-parallel and tensor-parallel ranks so sharply defined. Consider the above example, where the two machines are both working on different data and have also sharded an operator between them. At the beginning of the sharded block, they have to run an All-to-all, which essentially shuffles the data so that each machine holds the chunk it is responsible for operating on. Then each runs its portion of the sharded operator. Finally, they run a Reduce-scatter, which merges the separated pieces of data and puts each piece back on its assigned machine. As you might imagine, this is very communication-heavy, and so in practice it is not done as often as the hierarchical approach.