新PyTorch API:几行代码实现不同注意力变体,兼具FlashAttention性能和PyTorch灵活性

新PyTorch API:几行代码实现不同注意力变体,兼具FlashAttention性能和PyTorch灵活性

💡 原文中文,约2400字,阅读约需6分钟。
📝

内容提要

PyTorch团队引入了FlexAttention,一个灵活的API,允许用户使用几行PyTorch代码实现多个注意力变体。通过torch.compile将其降低到一个融合的FlashAttention内核中,生成了一个不会占用额外内存且性能可与手写内核相媲美的FlashAttention内核。FlexAttention具有令人惊讶的表达能力,可以满足大多数用户对注意力变体的需求。

🎯

关键要点

  • FlexAttention 是一个灵活的 API,允许用户用几行 PyTorch 代码实现多个注意力变体。

  • 通过 torch.compile,FlexAttention 被降低到一个融合的 FlashAttention 内核,性能可与手写内核相媲美且不占用额外内存。

  • 现有的注意力机制在性能提升的同时失去了灵活性,导致用户面临运行缓慢和 CUDA 内存不足的问题。

  • FlexAttention 允许用户定义 score_mod 函数,以满足对注意力变体的需求。

  • FlexAttention 动态计算偏差值,显著提高内存和性能,支持相对位置编码等变体。

  • FlexAttention 的性能接近手写的 Triton 内核,前向传播实现了 FlashAttention2 性能的 90%,反向传播实现了 85%。

  • 研究者计划改进 FlexAttention 的反向算法,以缩小与 FlashAttention2 的性能差距。

延伸问答

FlexAttention 是什么?

FlexAttention 是一个灵活的 PyTorch API,允许用户用几行代码实现多个注意力变体。

FlexAttention 如何提高性能?

通过 torch.compile,FlexAttention 被降低到一个融合的 FlashAttention 内核,性能可与手写内核相媲美且不占用额外内存。

FlexAttention 支持哪些注意力变体?

FlexAttention 支持因果注意力、相对位置嵌入、滑动窗口注意力等多种注意力变体。

使用 FlexAttention 的好处是什么?

使用 FlexAttention,用户可以灵活定义注意力变体,避免了运行缓慢和 CUDA 内存不足的问题。

FlexAttention 的性能与手写内核相比如何?

FlexAttention 的性能接近手写的 Triton 内核,前向传播实现了 FlashAttention2 性能的 90%,反向传播实现了 85%。

未来对 FlexAttention 的改进计划是什么?

研究者计划改进 FlexAttention 的反向算法,以缩小与 FlashAttention2 的性能差距。

🏷️

标签

➡️

继续阅读