arcsin-santo.tech

What Happens When You Start Adding GPUs?

Introduction

Open-source models now scale to 100B+ parameters (LLMs, video diffusion transformers, multimodal architectures). Because these models are released together with their weights, your first question may have shifted from "what is happening under the hood?" to "how do I actually deploy and serve this efficiently?" I would argue that both questions remain important and should be addressed together.

In this article, I examine FastVideo, an open-source implementation of WanVideo, and explore how understanding the original model's underlying concepts leads to efficient serving when GPUs are added.

Background

Spreading a model's work across several GPUs is called parallelism. It is used because a single GPU eventually runs out of memory.

For example, in LLMs, long prompts consume a lot of memory because the model must process and attend over many tokens at once, so the cost grows with sequence length. When a user supplies a very long text prompt, the model has to keep that context in memory while generating an answer, which can trigger a CUDA out-of-memory error.
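To get a rough feel for how quickly this grows, here is a minimal sketch that estimates the memory needed just for the attention score matrix when it is fully materialized. The helper name, the 32 heads, and the 2-byte (fp16) element size are illustrative assumptions, and optimized kernels avoid storing this matrix in full.

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_elem=2):
    """Bytes needed to hold one [seq_len, seq_len] score matrix per head.

    Assumes naive attention that materializes the full matrix
    (an illustrative worst case, not how every kernel behaves).
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


# Illustrative values only: batch of 1, 32 heads, fp16 elements.
for seq_len in (1_024, 8_192, 65_536):
    gib = attention_scores_bytes(1, 32, seq_len) / 2**30
    print(f"seq_len={seq_len}: {gib} GiB")
```

Under these assumptions, going from 8,192 to 65,536 tokens is 8x more tokens but 64x more memory (4 GiB versus 256 GiB), which is why long sequences hit the memory limit so fast.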

If a model fits comfortably on a single GPU and the sequence length is small, a parallelism strategy is unnecessary [1]. In that case, attention can be optimized efficiently on a single GPU (as discussed for FlashAttention in my previous article).

In-depth Study

LLMs use transformer blocks for text, but video generators like WanVideo handle much larger inputs, because the data has both spatial (where) and temporal (when) components. This makes video generation a good example of where distribution can significantly improve performance when you start adding GPUs.

Sequence Parallelism (SP)

In a typical single-GPU setup, attention operates on the entire sequence at once. The data is represented as:

[batch_size, seq_len, hidden_size]

You can think of this as one large input holding all the data. The model then computes the Q, K, and V projections, and every computation is performed on a single GPU.

However, when SP is used, the input is split along the sequence dimension across N GPUs:

[batch_size, seq_len, hidden_size] becomes [batch_size, seq_len / N, hidden_size] per GPU

Each GPU processes only its chunk of seq_len / N tokens. This reduces per-GPU memory and enables even longer sequences [2].
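The split can be written down as simple shape arithmetic. The sketch below is only an illustration: both helper names are hypothetical, and it assumes the sequence length divides evenly by the number of GPUs.

```python
def sp_local_shape(batch_size, seq_len, hidden_size, num_gpus):
    """Shape of the activation shard each GPU holds under SP."""
    if seq_len % num_gpus != 0:
        raise ValueError("seq_len must be divisible by num_gpus (or padded first)")
    return (batch_size, seq_len // num_gpus, hidden_size)


def sp_token_range(seq_len, num_gpus, rank):
    """Global token indices owned by the GPU with the given rank."""
    chunk = seq_len // num_gpus
    return range(rank * chunk, (rank + 1) * chunk)


# Illustrative sizes, not taken from any particular model.
print(sp_local_shape(1, 32_768, 5_120, num_gpus=4))  # (1, 8192, 5120)
print(sp_token_range(32_768, 4, rank=1))             # range(8192, 16384)
```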

Once you split along the sequence dimension, tokens depend on other tokens stored on different GPUs. This means that when attention is computed over the distributed Q, K, and V, devices need to communicate to access information from one another. SP therefore introduces a trade-off between memory savings and communication overhead.

Real Application in Open-Source Code

Note that attention is a mechanism, while a transformer is a full model architecture built primarily using attention.

I will look at FastVideo, which applies the concepts discussed above to WanVideo. Consider the following code from its modified WanVideo model:

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Shard with padding support - returns (sharded_tensor, original_seq_len)
        hidden_states, original_seq_len = sequence_model_parallel_shard(hidden_states, dim=1)
        
        current_seq_len = hidden_states.shape[1]
        sp_world_size = get_sp_world_size()
        padded_seq_len = current_seq_len * sp_world_size

a.) Prepare a tensor for SP

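We see that sequence_model_parallel_shard prepares a tensor for sequence parallelism (SP), matching the concept discussed above. Because the shard may be padded, the call also returns original_seq_len, and the padded length is recovered as the local sequence length multiplied by the SP world size.

In short, one long [batch, seq, hidden] tensor is split so that each GPU keeps only part of the sequence.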
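To make the padding step concrete, here is a minimal single-process sketch using plain Python lists in place of tensors. The function pad_and_shard and its pad_value parameter are hypothetical stand-ins and not FastVideo's implementation; the sketch only assumes that the sequence is padded up to a multiple of the world size before being cut into equal chunks.

```python
import math


def pad_and_shard(tokens, world_size, rank, pad_value=0):
    """Pad `tokens` to a multiple of `world_size`, then return this rank's chunk.

    Also returns the original length so the padding can be removed later.
    """
    original_len = len(tokens)
    local_len = math.ceil(original_len / world_size)
    padded = tokens + [pad_value] * (local_len * world_size - original_len)
    start = rank * local_len
    return padded[start:start + local_len], original_len


tokens = list(range(10))  # 10 tokens do not divide evenly across 4 GPUs
for rank in range(4):
    shard, original_len = pad_and_shard(tokens, world_size=4, rank=rank)
    print(rank, shard, original_len)
```

With 10 tokens and 4 ranks, each rank holds 3 entries, so the padded length is 3 * 4 = 12, mirroring how padded_seq_len is derived from the local length in the snippet above.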

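Then, the following code redistributes the data with an all-to-all exchange. Before attention, each GPU scatters its heads (scatter_dim=2) and gathers the sequence (gather_dim=1); after attention, the second call reverses the exchange.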

        # Stack QKV
        qkv = torch.cat([q, k, v], dim=0)  # [3*batch, seq_len, num_heads, head_dim]

        # Redistribute heads across sequence dimension
        qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)
        ...
        output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
        ...
        output = sequence_model_parallel_all_to_all_4D(output, scatter_dim=1, gather_dim=2)

b.) Split along sequence dimension

Those sequence_model_parallel_all_to_all_4D calls, however, are the communication overhead. They have to run every time the heads are redistributed across GPUs, once before attention and once after it.
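What such an exchange does to the data layout can be simulated in a single process. The sketch below is an assumption-laden illustration and not FastVideo's code: the function names are hypothetical, nested lists stand in for tensors, and a list index stands in for the GPU rank. Each entry is a (token, head) label so the movement is easy to follow.

```python
def seq_to_heads(shards):
    """shards[rank][local_token][head] -> out[rank][global_token][local_head].

    Every rank gives up part of its heads and receives the full sequence.
    """
    world_size = len(shards)
    heads_per_rank = len(shards[0][0]) // world_size
    out = []
    for dst in range(world_size):
        h0 = dst * heads_per_rank
        full_seq = []
        for src in range(world_size):  # gather sequence chunks in rank order
            for token_row in shards[src]:
                full_seq.append(token_row[h0:h0 + heads_per_rank])  # scatter heads
        out.append(full_seq)
    return out


def heads_to_seq(shards):
    """Inverse exchange: out[rank][local_token][head] with all heads restored."""
    world_size = len(shards)
    chunk = len(shards[0]) // world_size
    out = []
    for dst in range(world_size):
        local = []
        for s in range(dst * chunk, (dst + 1) * chunk):
            row = []
            for src in range(world_size):  # gather head groups in rank order
                row.extend(shards[src][s])
            local.append(row)
        out.append(local)
    return out


# 2 ranks, 4 tokens, 4 heads: rank r starts with tokens [2r, 2r + 1] and all heads.
before = [[[(t, h) for h in range(4)] for t in range(2 * r, 2 * r + 2)]
          for r in range(2)]
during = seq_to_heads(before)   # each rank: all 4 tokens, 2 heads
after = heads_to_seq(during)    # back to the original layout
print(len(during[0]), len(during[0][0]))
print(after == before)
```

In a real multi-GPU run every one of these moves is network or interconnect traffic, which is the cost the two calls around attention pay.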

        # Gather and unpad in one operation
        hidden_states = sequence_model_parallel_all_gather_with_unpad(
            hidden_states, original_seq_len, dim=1)

c.) Gather

To put the pieces back together, we gather all the shards and strip the padding so the final layer can process the complete sequence.
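The gather step can be sketched the same way with plain lists. The helper gather_and_unpad is a hypothetical name used for illustration; it assumes the shards are concatenated in rank order and that any padding sits at the end of the sequence.

```python
def gather_and_unpad(shards, original_len):
    """Concatenate per-rank shards in rank order and drop trailing padding."""
    gathered = [token for shard in shards for token in shard]
    return gathered[:original_len]


# Four shards of length 3; the last one carries two padding entries.
shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 0]]
print(gather_and_unpad(shards, original_len=10))
```

Keeping the original length around from the sharding step is what makes the unpadding a simple slice.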

We can see that an efficient strategy is used when additional GPUs are added. However, this strategy is not provided by default; it has to be built around what is found inside the model, namely attention.

So efficient use of additional GPUs depends on understanding the underlying architecture and how it maps onto the available resources.

References

  1. Parallelism Scaling: https://docs.vllm.ai/en/stable/serving/parallelism_scaling/

  2. Sequence Parallelism: https://insujang.github.io/2024-01-11/tensor-parallelism-and-sequence-parallelism-detailed-analysis/#attention-layer