推广 热搜:   公司  中国  行业  快速  企业  设备  上海  未来  技术 

【BBuf的CUDA笔记】十四,OpenAI Triton入门笔记三 FusedAttention

   日期:2025-01-02     移动:http://www78564.xrbh.cn/mobile/quote/28716.html

继续Triton的学习,这次来到 https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html 教程。也就是如何使用Triton来实现FlashAttention V2。对于FlashAttention和FlashAttention V2网上已经有非常多的介绍了,大家如果感兴趣的话我推荐FlashAttention V1看 《图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑》https://zhuanlan.zhihu.com/p/669926191 这篇文章的讲解 以及 FlashAttention V2 看 《图解大模型计算加速系列:Flash Attention V2,从原理到并行计算》 https://mp.weixin.qq.com/s/5K6yNj23NmNLcAQofHcT4Q ,原理和公式推导都非常清晰,不过想一口气读完还是要花一些精力的。同时你也可以在 https://github.com/BBuf/how-to-optim-algorithm-in-cuda 找到更多相关资料(此外Meagtron-LM,DeepSpeed等训练Infra框架的迅速跟进也说明了FlashAttention这个系列工作影响之大),例如:

在这里插入图片描述

这篇文章主要的问题是读懂如何使用Triton来实现FlashAttention V2的前向,所以我不会去复述FlashAttention的公式细节,而是从更加工程的角度来说FlashAttention Forward的代码应该如何实现,我在这个过程中也会提供FlashAttention V1/V2 Forward的一个最简Python实现来非常直观的把握代码流程,在这个基础上才会展开对Triton FlashAttention实现的解读,让我们开始吧。(后续如果有精力也会写一下Backward的实现

FlashAttention V1/V2的paper链接为:https://arxiv.org/abs/2205.14135 和 https://tridao.me/publications/flash2/flash2.pdf 。 本文涉及到的实验代码见我的个人仓库:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/triton ,也欢迎大家点star。

跑了一下 https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html 这个教程里的FlashAttention V2的BenchMark。

对于Batch=4,Head=48,HeadDim=64,causal=True的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=64,causal=False的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=64,causal=True的Flash Attention V2 Backward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

在这组配置下Triton在各种Sequence Length下都实现了比cutlass更优的性能,然后在Triton的kernel实现里面有,也就是说Triton的实现需要注意力头的隐藏层维度在[16, 32, 64, 128]里,我这里再测一组16的看下表现。

对于Batch=4,Head=48,HeadDim=16,causal=True的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=16,causal=False的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=16,causal=True的Flash Attention V2 Backward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

这一组case下虽然Forward Pass还是Triton更快,但是Backward Pass却是cutlass更快了。

另外之前在Triton的issue里面还刷如果HeadDim=128,Triton的Bakcward会比cutlass慢更多:https://github.com/openai/triton/issues/1975 ,参数设置为 这里也测试一下:

在这里插入图片描述

反向的耗时对比图:

在这里插入图片描述

结果很神奇,这个反向耗时的差距非常大,且Triton的速度远好于Cutlass的实现,并且随着序列长度的增加Triton的反向的耗时竟然是接近横定的。。保险起见还是建议大家用官方FlashAttention库提供的实现,我现在使用的Triton版本为2.1.0。

从FlashAttention的paper里面截一下标准Attention流程:

在这里插入图片描述

我这里再描述一下流程,首先从HBM中加载

,

矩阵,接着执行

的计算,并将结果

写回HBM;然后将

再从HBM中读取出来,执行

的计算,再将

写回HBM;然后将

从HBM中读取出来,执行

的计算,最后把结果写回HBM中。对于,

,他们的维度都是

,中间变量

的维度都是

。这里还有个问题就是对于S和P可能还会有一些其它的操作比如Mask和Dropout,所以上面也提到了有不少的fuse kernel的工作,比如把softmax和mask fuse起来。最后,这里的softmax是PyTorch的softmax算子,也是safe softmax的实现,safe的意思就是在naive softmac的基础上对指数上的每个原始输入值都减掉所有原始输入值中的最大值。具体请参考下面的图片,来源于 https://arxiv.org/pdf/2205.14135.pdf :

在这里插入图片描述

对于safe softmax来说,所有的值都减掉了输入向量里面的最大值,保证了指数部分的最大值是0,避免了数值溢出。

为了验证正确性,我写了一个脚本,这个地方以经典的GPT2为例,然后硬件以A100为例 。这里的

分别设置成1024和64,那么Q,K,V的shape都是

,S和P的维度都是

代码实现具体如下:

测试可以正确通过,也说明了PyTorch的算子的确是用safe softmax的方法来实现的。

FlashAttention V1通过分块计算的方法,将Q、K和V切块成很多小块,然后将这些切分后的小块放进SRAM(shared memory)中执行计算,最后再写回HBM中。算法流程如下:

在这里插入图片描述

如果你想完全搞清楚这个伪代码的来龙去脉推荐看 https://zhuanlan.zhihu.com/p/669926191 这篇文章,但是从源码实现的角度来看,有了这个伪代码已经接近够了。只需要知道这些看似奇奇怪怪的公式是因为在分块遍历的时候每次计算的是一部分token,而自注意力机制要计算的最终结果是所有token间的,所以从局部到整体的更新就会用到在线的softmax算法以及在线更新最后的输出。这也是上面那堆复杂的公式由来。

我这里尝试用Python来模拟一下这个算法的流程,实现之后对Triton的实现会有帮助,因为从前面几节Triton的教程来看,相比于单纯的Python实现Triton kernel只是多了一个块级别的kernel启动过程而已。沿用上一节GPT2的设置,

分别设置成1024和64,那么Q,K,V的shape都是

,注意在FlashAttention里面就没有全局的S和P了。假设硬件是A100,A100的Shared Memory大小为192KB=196608B,那么可以计算出这里Flash Attention的分块大小,也就是上面的伪代码的第一行。

然后伪代码的第2行初始化了一个全0的输出矩阵

,shape的大小也是

,同时初始化了一个

矩阵,维度大小都是

,不过

被初始化为全0矩阵,

被初始化为负无穷大。

接下来可以根据上面的参数直接计算出

,对应伪代码的第3行,

接下来的伪代码解析我直接放到下面的Python实现里,每一行代码都可以对应到上面的伪代码:

需要说明的是在上面的Attention Forward Pass流程中没有考虑到Dropout以及Mask的操作,如果考虑这两个操作整体的流程有一些变化,具体如Flash Attention V1的paper里的Algorithm2所示:

在这里插入图片描述

相比于Algorithm1,多了Mask和Dropout的操作,其它的没有变化。

如果你想很清晰的了解FlashAttention V2背后的改进原理请阅读 《图解大模型计算加速系列:Flash Attention V2,从原理到并行计算》 https://mp.weixin.qq.com/s/5K6yNj23NmNLcAQofHcT4Q 。我这里只做一个简单的原理解析,重点是关注代码层面相比于FlashAttention V1 Forward Pass的变化,并基于FlashAttention V1的版本实现FlashAttention V2 Forward Pass。

有了上一节代码的铺垫,Flash Attention V1 Forward Pass其实可以抽象为下面的图(从上面的《图解大模型计算加速系列:Flash Attention V2,从原理到并行计算》文章copy来的):

在这里插入图片描述

这个图和我们的Flash Attention V1实现是完全对应的,需要注意的是图中有6个O的小块,但实际上横着的O只有一个并且是逐步更新的,这里是为了体现分块的思想才画出来的。

这里以

为例子,我们可以看到

共用了

,FlashAttention V2基于这个观察调整了Flash Attention V1的循环顺序,现在外层循环遍历Q不就可以避免重复访问Q了吗?调整训练的顺序只是FlashAttention V2的操作之一,另外两个比较重要的操作是对计算公式进行了改写尽量减少non-matmul FLOPs,具体来说在计算局部attention时,先不考虑softmax的分母以及将rescale的时机后移,只能感叹作者大佬的数学太强,具体的大家可以参考一下《FlashAttention2详解(性能比FlashAttention提升200%)》https://zhuanlan.zhihu.com/p/645376942 这篇文章的Algorthm的解释。此外,Paper中还提了一个重要的并行性方面的改进,即加入了序列并行,具体说来 FlashAttention V1 在 batch 和 heads 两个维度上进行了并行化,使用一个线程块来处理一个注意力头,总共需要的线程块的数量等于batch和注意力头的乘积。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为这几乎可以有效利用GPU上所有计算资源。但是在处理长序列输入(目前训练100k,200k的长文本模型需求逐步增长)时,由于内存限制,通常会减小batch和注意力头数量,这样GPU并行化程度就降低了。基于此,FlashAttention-2在序列长度这一维度上进行并行化,显著提升了GPU的并行度并提升了性能。这些改进我们都可以在 https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py 这个Triton实现中找到,留在下一节细讲。

这里仍然是贴出Flash AttentionV2的算法伪代码,并且使用Python来模拟一下流程。

在这里插入图片描述

对应的python代码以及流程如下,由于这里只考虑了forward pass所以代码里只计算了Attention的输出O没有计算logsumexp L(这个是给backward pass用的):

然后FlashAttention V2里面还有两节和GPU并行性相关的话,在对Triton实现的解读之前我先把这两节翻译一下。

在这里插入图片描述

翻译:FlashAttention V1在batch和heads两个维度上进行了并行化:使用一个thread block来处理一个attention head,总共需要thread block的数量等于batch size × number of heads。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为几乎可以有效利用GPU上所有计算资源。

但是在处理长序列输入时,由于内存限制,通常会减小batch size和head数量,这样并行化成都就降低了。因此,FlashAttention V2还在序列长度这一维度上进行并行化,显著提升了计算速度。此外,当batch size和head数量较小时,在序列长度上增加并行性有助于提高GPU占用率。

Forward pass 这里大概就是说,FlashAttention V1伪代码中有两个循环,K,V在外循环j,Q在内循环i。FlashAttention V2将Q移到了外循环i,K,V移到了内循环 j,由于改进了算法使得warps之间不再需要相互通信去处理,所以外循环可以放在不同的 thread block 上。这个交换的优化方法是由Phil Tillet在Triton提出并实现的,也就是下一节要解读的Triton代码了。我们会看到它启动kernel的时候线程网格有两个维度,其中一个维度是序列长度,另外一个维度是batch和注意力头数的乘积。

在这里插入图片描述

翻译:paper 的3.2节讨论了如何分配thread block,然而在每个thread block内部,我们也需要决定如何在不同的warp之间分配工作。我们通常在每个thread block中使用4或8个warp,如Figure3所示。

FlashAttention forward pass. 这里精简一下,如Figure3所示,外循环对K,V在输入序列N上遍历,内循环对Q在N上遍历。对于每个块,FlashAttention V1将K和V分别分成4个warp,并且所有的warp都可以访问Q。K的warp乘以Q得到S的一部分

,然后

经过局部softmax后还需要乘以V的一部分得到

。但是,每次外循环

都要更新一次

(对上一次的

先rescale再加上当前的值),这就导致每个warp需要从HBM里面频繁读写

来累计最后的结果,这种方案也被称为"Split-K"方案,整体是低效的,因为所有warp都需要从HBM频繁读写中间结果

。FlashAttention V2 将Q移到了外循环i,K,V移到了内循环j,并将Q分为4个warp,所有warp都可以访问K,V。这样做的好处是,原来FlashAttention每次内循环i++会导致

也变换(而

需要通过HBM读写),现在每次内循环j++处理的都是

,此时

是存储在SRAM上的,代价远小于HBM。

有了上面的铺垫,就可以直接来看Triton的实现了,这里只关注 Forward Pass部分,Triton的核心计算逻辑在下面的这个函数:

需要说明的是这个函数负责的是一小块Q(入参中的q)和KV的计算,代码中的for循环对应的就是伪代码中的对KV的循环,而Q的循环实际上是体现在triton kernel启动的设置,见下面的代码和注释:

这里的其实就是对Q进行分块,需要说明的是这个地方输入的Q,K,V的形状是(Batch, NHeads, Seq, HeadDim),所以这里启动的线程网格有2个维度都是有值的,除了x维度为,它的y维度则为的乘积(这里的x是在序列维度上切分也导致了后面构造内存指针的时候有一个特殊的参数)。也就是说这里的Block数量其实是比较多的,更容易让GPU的SM用满,这个启动方式和FlashAttention V2 paper中提到的启动方式是一致的,具体请看上一节的末尾翻译部分。至于,我们在计算的时候使用多少个warp,这个也是和Paper的设置保持一致,一般是用4个,只有针对H100才用8个。另外就是由于现在的Q,K,V形状和paper中的

不一样,所以分块的个数也是不一样的,这里是写死了分块数:

最后还有一个要解析,内容如下:

需要特别注意的是这段代码最后的epilogue部分就对应了FlashAttention V2伪代码中的12行以后的内容,根据softmax的分母部分较正输出。此外,Triton的实现里面考虑了一些paper里面没有的东西比如,,对的结果应用了减掉m,使得整个实现看起来要复杂不少,但整体的算法逻辑和并行设置和paper还是一致的。

本文地址:http://www78564.xrbh.cn/quote/28716.html    迅博思语 http://www78564.xrbh.cn/ , 查看更多

特别提示:本信息由相关用户自行提供,真实性未证实,仅供参考。请谨慎采用,风险自负。


相关最新动态
推荐最新动态
点击排行
网站首页  |  二维码  |  关于我们  |  联系方式  |  使用协议  |  版权隐私  |  网站地图  |  排名推广  |  广告服务  |  积分换礼  |  网站留言  |  RSS订阅  |  违规举报  |  粤ICP备2023022329号