arcsin-santo.tech

What Happens When You Start Adding GPUs? 🔥

Introduction

Open-source models now scale to 100B+ parameters (LLMs, video diffusion transformers, multimodal architectures). Because these models are released with their weights, the question shifts from "what is happening under the hood?" to "how do I actually deploy and serve this efficiently?"

In this article, I walk through an open-source example built on WanVideo, dive into one of its multi-GPU serving strategies, and investigate the trade-off.

Background

Spreading a model's work across multiple GPUs is called parallelism. It is used because a single GPU eventually runs out of memory.

For example, in LLMs, long prompts consume a lot of memory because the model must process and attend over many tokens at once, so the cost grows with sequence length. When a user supplies a very long text prompt, the model has to keep that context in memory while generating an answer, which can trigger the familiar CUDA out of memory error.
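To get a feel for how quickly this grows, here is a back-of-the-envelope sketch. It assumes a naive attention implementation that materializes the full seq_len x seq_len score matrix in fp16 (2 bytes per element). The function name and the example sizes are made up for illustration, and fused kernels such as FlashAttention avoid storing this matrix at all.

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_elem=2):
    """Memory for one layer's attention score matrix in a naive implementation.

    Assumption: the full [batch, heads, seq_len, seq_len] matrix is materialized.
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


# Hypothetical sizes: batch of 1, 32 heads, fp16 scores.
for seq_len in (4_096, 32_768):
    gib = attention_scores_bytes(1, 32, seq_len) / 2**30
    print(f"seq_len={seq_len}: {gib:.1f} GiB")
```

Under these assumptions, making the sequence 8x longer makes the score matrix 64x larger, which is why long sequences are usually what pushes a single GPU over its limit.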

As a general rule, if a model fits comfortably on a single GPU and the sequence length is small, a parallelism strategy is unnecessary [1]. In that case, attention can be optimized efficiently on a single GPU (as discussed for FlashAttention in my previous article).

In-depth Study

Although LLMs use transformer blocks for text, video generators like WanVideo handle much larger inputs, because the data has both spatial (where) and temporal (when) components. This makes video generation a good example of where distributing the work can significantly improve performance as you add GPUs.

Sequence Parallelism (SP)

In a typical single-GPU setup, attention operates on the entire sequence at once. The data is represented as:

[batch_size, seq_len, hidden_size]

You can think of this as a large input holding all the data. The model then creates projections (Q, K, and V), where all computations are performed on a single GPU.

However, when SP is used with N GPUs, the input is split along the sequence dimension: [batch_size, seq_len, hidden_size] becomes [batch_size, seq_len / N, hidden_size] per GPU.

Each GPU processes only its own chunk of seq_len / N tokens. This reduces per-GPU memory and enables even longer sequences [2].
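The shape change can be sketched in a few lines. This is only shape arithmetic, not a real sharding routine: per_gpu_shape is a hypothetical helper, the sizes are invented, and it assumes seq_len divides evenly by the number of GPUs.

```python
def per_gpu_shape(batch_size, seq_len, hidden_size, num_gpus):
    """Shape of the activation shard each GPU holds under sequence parallelism.

    Assumption: seq_len divides evenly by num_gpus (no padding needed).
    """
    if seq_len % num_gpus != 0:
        raise ValueError("seq_len must be divisible by num_gpus in this sketch")
    return (batch_size, seq_len // num_gpus, hidden_size)


full = (1, 32_768, 4_096)  # hypothetical [batch_size, seq_len, hidden_size]
for n in (1, 2, 4, 8):
    print(n, per_gpu_shape(*full, num_gpus=n))
```

Only the middle dimension shrinks. The batch and hidden dimensions stay whole on every GPU, which is what distinguishes this from splitting the batch or the hidden size.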

Once you split along the sequence dimension, tokens depend on other tokens stored on different GPUs. This means that when computing attention over Q, K, and V, devices need to communicate to access each other's data. This approach therefore introduces a trade-off between memory savings and communication overhead.

Real Application in Open-Source Code

Note that attention is a mechanism, while a transformer is a full model architecture built primarily using attention.

I will look at FastVideo, which uses the concepts discussed above. Consider the following code from its modified version of WanVideo:

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Shard with padding support - returns (sharded_tensor, original_seq_len)
        hidden_states, original_seq_len = sequence_model_parallel_shard(hidden_states, dim=1)
        
        current_seq_len = hidden_states.shape[1]
        sp_world_size = get_sp_world_size()
        padded_seq_len = current_seq_len * sp_world_size

a.) Prepare a tensor for SP

We see that sequence_model_parallel_shard prepares a tensor for sequence parallelism (SP), similar to the concept we discussed.

So, one long [batch, seq, hidden] tensor -> each GPU keeps only part of the sequence.
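To make the "shard with padding" idea concrete, here is a single-process sketch using plain Python lists in place of tensors. It is not FastVideo's implementation: shard_with_padding is a hypothetical stand-in, and padding at the end of the sequence with zeros is my assumption.

```python
def shard_with_padding(tokens, rank, world_size, pad_value=0):
    """Pad a token list to a multiple of world_size, then return this rank's slice.

    Hypothetical stand-in for a shard helper; returns (local_shard, original_len).
    Assumption: padding is appended at the end of the sequence.
    """
    original_len = len(tokens)
    remainder = original_len % world_size
    if remainder:
        tokens = tokens + [pad_value] * (world_size - remainder)
    local_len = len(tokens) // world_size
    start = rank * local_len
    return tokens[start:start + local_len], original_len


tokens = list(range(1, 11))  # 10 tokens, not divisible by 4
for rank in range(4):
    shard, original_len = shard_with_padding(tokens, rank, world_size=4)
    print(rank, shard, "padded_seq_len =", len(shard) * 4)
```

Returning the original length alongside the shard mirrors the (sharded_tensor, original_seq_len) pair in the excerpt, and multiplying the local length by the world size recovers the padded length, as the last lines of the excerpt do.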

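Then, the following code shows an all-to-all that trades the sequence split for a head split. Judging by scatter_dim=2 and gather_dim=1, each GPU gives up some of its heads and receives the rest of the sequence, so attention runs over the full sequence for a subset of heads.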

        # Stack QKV
        qkv = torch.cat([q, k, v], dim=0)  # [3*batch, seq_len, num_heads, head_dim]

        # Redistribute heads across sequence dimension
        qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)
        ...
        output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
        ...
        output = sequence_model_parallel_all_to_all_4D(output, scatter_dim=1, gather_dim=2)

b.) Split along sequence dimension

Those sequence_model_parallel_all_to_all_4D calls, however, are the communication overhead. They run once before and once after attention, every time this attention code executes.
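What the first all-to-all moves around can be simulated on one process. The sketch below is my reading of scatter_dim=2 and gather_dim=1, not FastVideo's code: nested lists stand in for per-GPU tensors, head_dim is omitted, and each entry is just a (token_index, head_index) label so the data movement is visible. The real call exchanges tensors between GPUs with a collective operation.

```python
def all_to_all_seq_to_heads(per_rank):
    """Simulate the first all-to-all on one process.

    Input:  per_rank[r][t][h] -> rank r holds its local tokens t with ALL heads h.
    Output: out[r][t][h]      -> rank r holds ALL tokens t with only its local heads.
    """
    world_size = len(per_rank)
    num_heads = len(per_rank[0][0])
    heads_per_rank = num_heads // world_size
    out = []
    for dst in range(world_size):
        lo = dst * heads_per_rank
        hi = lo + heads_per_rank
        # Gather the sequence dimension: walk source ranks in order,
        # keeping only the head slice that belongs to `dst`.
        out.append([token[lo:hi] for src in per_rank for token in src])
    return out


# 2 ranks, 4 tokens total, 2 heads; each entry is labeled (token_index, head_index).
per_rank = [
    [[(0, 0), (0, 1)], [(1, 0), (1, 1)]],  # rank 0: tokens 0-1, heads 0-1
    [[(2, 0), (2, 1)], [(3, 0), (3, 1)]],  # rank 1: tokens 2-3, heads 0-1
]
after = all_to_all_seq_to_heads(per_rank)
print(after[0])  # rank 0: tokens 0-3, head 0 only
print(after[1])  # rank 1: tokens 0-3, head 1 only
```

After the exchange, every rank sees the whole sequence, so ordinary attention can run locally on its heads. The second all-to-all in the excerpt swaps the scatter and gather dimensions, which would undo this layout.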

        # Gather and unpad in one operation
        hidden_states = sequence_model_parallel_all_gather_with_unpad(
            hidden_states, original_seq_len, dim=1)

c.) Gather

Finally, we gather all the split pieces back together so the final layer can process the complete sequence.
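The gather-and-unpad step can be sketched the same way. Again this is a single-process illustration with plain lists, not the library's implementation: all_gather_with_unpad is a hypothetical name, and it assumes the padding sits at the end of the last shard.

```python
def all_gather_with_unpad(shards, original_len):
    """Concatenate per-rank shards in rank order, then drop the padding tail.

    Hypothetical single-process stand-in; assumes padding was appended at the end.
    """
    full = [tok for shard in shards for tok in shard]
    return full[:original_len]


shards = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 0, 0]]  # 10 real tokens + 2 pads
print(all_gather_with_unpad(shards, original_len=10))
```

This is why the original sequence length is carried along from the sharding step: without it, the final layer would also process the padding tokens.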

Conclusion

In conclusion, adding more GPUs helps a lot when you need more memory or performance. We looked at a Sequence Parallelism strategy found in open-source code. Overall, you can expect added complexity, especially as more GPUs are added, and the benefits depend on having a parallelism strategy implemented.

References

  1. Parallelism Scaling: https://docs.vllm.ai/en/stable/serving/parallelism_scaling/

  2. Sequence Parallelism: https://insujang.github.io/2024-01-11/tensor-parallelism-and-sequence-parallelism-detailed-analysis/#attention-layer