arcsin-santo.tech

Speed Up Training using the PyTorch Reference API

Introduction

Training speed matters because compute costs money, especially when renting GPUs by the hour. Fortunately, PyTorch ships optimized implementations of common algorithms, and its documentation makes them straightforward to adopt. Following that documentation, I decided to experiment with FlashAttention.

One such speedup is FlashAttention, an algorithm that seems to be used in every modern LLM. I was curious how it would perform on a smaller scale, so I implemented it in Andrej Karpathy's 'makemore' tutorial to see what would happen (for context, the makemore series teaches the foundations of building language models from scratch).

The model is essentially a name generator that produces new, similar-sounding names. Karpathy starts with a Bigram model, which predicts the next letter based only on the previous letter, ignoring all other context. When trained, the Bigram model produces text, but it lacks context. This is where a stronger model like the Transformer comes in, using more of the sequence as context when making predictions.

[Figures: names generated by the Bigram model; names generated by the Transformer model]

Comparing the Bigram and Transformer models on names, it's clear the Transformer-generated names sound more "namey". However, I noticed that the attention code is still a manual implementation.

By manual, I mean it does not use the optimized kernels found in modern libraries. Today, FlashAttention likely powers most Transformers in practice, and PyTorch provides an API for it. But before seeing how it replaces the manual implementation, let's take a closer look inside.

I. Attention inside Transformer

The fundamental building block of the Transformer model is the attention mechanism.

It allows the model to focus on different parts of the input sequence and establish context.

As proposed in the original paper 'Attention Is All You Need', each token of the input sequence is projected into a Query, a Key, and a Value vector. We can think of them as: the Query is what a token is looking for, the Key is what each token offers for matching, and the Value is the information a token carries.

The procedure to follow is:

  1. Queries, Keys, and Values are projected from the same input X.
  2. Attention scores are computed as the dot product QKᵀ.
  3. The scores are scaled by 1/√d_k and normalized with softmax.
  4. The resulting attention weights are applied to the Value matrix V to produce the output.

This attention mechanism is the fundamental building block that lets the model capture context and relationships. However, steps 2 to 4 materialize a large T×T score matrix that becomes costly (slow and memory-intensive) as sequences grow, affecting both training and inference in larger LLMs.
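The four steps above can be sketched in a few lines of NumPy. This is an illustrative single-head version; the function and weight names are mine, not from the makemore code:

```python
import numpy as np

def attention(X, Wq, Wk, Wv):
    """Minimal single-head attention following steps 1-4."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                 # 1. project Q, K, V from the same X
    scores = Q @ K.T                                 # 2. dot-product scores, shape (T, T)
    scores /= np.sqrt(K.shape[-1])                   # 3a. scale by 1/sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # 3b. softmax over each row
    return weights @ V                               # 4. weighted sum of the values
```

The (T, T) scores matrix in step 2 is exactly the object that becomes expensive as sequences grow.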

II. Upgrade Attention to use FlashAttention

This is when FlashAttention is introduced.

FlashAttention works by tiling the attention computation and fusing GPU operations, so queries, keys, and values are processed in chunks directly in fast on-chip memory instead of writing the full score matrix out to GPU memory. Without diving too deep into the details, it fuses most of the steps in the procedure for greater efficiency. But how do we change the original attention code to use it?
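To give a flavor of the tiling idea, here is a minimal NumPy sketch of the online-softmax trick FlashAttention builds on: keys and values are consumed block by block, while a running row-max and normalizer are corrected on the fly, so the full (T, T) score matrix never has to exist at once. This is my own illustrative simplification (no masking, no actual GPU memory management), and tiled_attention is a name I made up:

```python
import numpy as np

def tiled_attention(Q, K, V, block=4):
    """Attention computed over key/value blocks with online softmax."""
    T, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)        # running (unnormalized) output
    m = np.full(T, -np.inf)     # running row-wise max of scores
    l = np.zeros(T)             # running softmax denominator
    for s in range(0, T, block):
        Kb, Vb = K[s:s+block], V[s:s+block]
        S = (Q @ Kb.T) * scale                       # partial scores (T, block)
        m_new = np.maximum(m, S.max(axis=1))
        p = np.exp(S - m_new[:, None])
        correction = np.exp(m - m_new)               # rescale old stats to new max
        l = l * correction + p.sum(axis=1)
        O = O * correction[:, None] + p @ Vb
        m = m_new
    return O / l[:, None]                            # normalize at the end
```

In exact arithmetic this gives the same result as the dense softmax(QKᵀ/√d)V, which is why the real kernel can skip materializing the score matrix.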

Code Dive

...
           # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
           att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
           att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
           att = F.softmax(att, dim=-1)
           y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
           y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

Original Attention Code in the Transformer

In the original makemore code, you can see the @ symbol, Python's matrix-multiplication operator. The code follows the steps of the procedure directly.
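As a quick aside, @ dispatches to __matmul__ and is the same operation as torch.matmul on tensors (or np.matmul in NumPy). A tiny NumPy example:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(6).reshape(3, 2)

# '@' calls __matmul__, identical to np.matmul
assert np.array_equal(a @ b, np.matmul(a, b))  # [[10, 13], [28, 40]]
```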

       y = F.scaled_dot_product_attention(
           q, k, v,
           attn_mask=None,
           dropout_p=0.0,
           is_causal=True
       )

Attention Code replaced with PyTorch's scaled_dot_product_attention API

Changes I’ve made: The scaled_dot_product_attention function handles the causal mask internally, so I removed the causal masking code.

From what I understand, there are different ways to build the causal mask manually, such as using torch.triu, which prevents tokens from attending to future positions. The mask works by setting the blocked connections to -inf before softmax so their attention weights become zero. Since scaled_dot_product_attention handles causal masking internally, I removed att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) from Karpathy's code and instead set is_causal=True to achieve the same behavior more cleanly.
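As a sketch of that manual approach (using NumPy's np.triu as a stand-in for torch.triu), the mask marks every position above the diagonal as "future" and sets its score to -inf, so softmax assigns it zero weight:

```python
import numpy as np

T = 4
# True above the diagonal = positions in the future of each row's token
future = np.triu(np.ones((T, T), dtype=bool), k=1)

scores = np.zeros((T, T))            # dummy pre-softmax scores
scores[future] = -np.inf             # block future connections
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# row i now attends only to positions 0..i,
# e.g. row 0 -> [1, 0, 0, 0] and row 1 -> [0.5, 0.5, 0, 0]
```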

I set attn_mask=None here. This might be confusing since I just mentioned a causal mask, but attn_mask is a separate parameter for supplying a custom mask. The original code didn't use one, so I left it as None.

Karpathy also omitted dropout in his implementation since the models are small. The API's dropout_p parameter covers this, so I set it to 0.

According to PyTorch's API documentation, scaled_dot_product_attention automatically selects the most efficient implementation available for the given inputs. However, I wanted to explicitly use FlashAttention, so I pinned the backend:

from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

Comparing Performance

Then I ran a benchmark to confirm that switching from the manual attention implementation to scaled_dot_product_attention yields a performance improvement. With the same Transformer model and training parameters, the version using the PyTorch API for FlashAttention averages ~5.81 ms per step, compared to ~6.77 ms per step for the manual implementation: a ~16–17% speedup.
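As a sketch, here is one way such a per-step timing can be done (avg_step_ms is an illustrative helper of mine; when benchmarking on a GPU you would also call torch.cuda.synchronize() before each timestamp so queued kernels are actually counted):

```python
import time

def avg_step_ms(step_fn, warmup=10, iters=100):
    """Average wall-clock time of step_fn() in milliseconds."""
    for _ in range(warmup):            # warm-up runs: caches, autotuning, etc.
        step_fn()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    return (time.perf_counter() - start) / iters * 1e3
```

You would pass the training step (forward + backward + optimizer update) as step_fn and compare the two attention variants under identical settings.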

Conclusion

Admittedly, the speedup isn't necessary for a dataset this small. Even without this change, I had no trouble training on my laptop. We are only training on names, not the whole internet; in the end these are short sequences. Attention is O(n²) in sequence length, and n is tiny here, so the benefits of FlashAttention are limited. On a larger scale, though, where compute and memory usage balloon as in larger LLMs, making this change can be impactful.
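To put the O(n²) point in numbers, here is a quick back-of-the-envelope calculation (attn_matrix_mib is my own illustrative helper) of the fp32 memory needed just for one attention score matrix:

```python
def attn_matrix_mib(T, heads=1, batch=1, bytes_per_el=4):
    """Memory of the (T, T) attention score matrix in MiB (fp32 by default)."""
    return batch * heads * T * T * bytes_per_el / 2**20

small = attn_matrix_mib(32)     # name-length sequences: ~0.004 MiB
large = attn_matrix_mib(8192)   # LLM-scale context: 256 MiB per head, per batch element
```

Quadratic growth means an 8192-token context needs roughly 65,000× the score-matrix memory of a 32-token one, which is the regime where avoiding that matrix pays off.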

To conclude, it's a small change in the grand scheme of things. However, it adds up, especially as models scale. I learned a lot when trying out this experiment, and it serves as a reminder that sometimes the best optimizations can be found in the documentation of the library that you're already using.