0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

大模型系列:Flash Attention V2整体运作流程

深度学习自然语言处理 来源:大猿搬砖简记 2024-02-21 11:38 次阅读

大家好,这就为您献上不知鸽了多久的Flash Attention V2原理解读。

在V1的讲解中,我们通过详细的图解和公式推导,一起学习了Flash Attention的整体运作流程。如果大家理解了V1的这块内容,就会发现V2的原理其实非常简单:无非是将V1计算逻辑中的内外循环相互交换,以此减少在shared memory上的读写次数,实现进一步提速。那当你交换了循环位置之后,在cuda层面就可以配套做一些并行计算优化。这就是V2的整体内容。

总结起来一句话:“交换了循环位置“,虽是短短一句话,却蕴含着深深的人生哲理:只要基座选得好,回回都有迭代点,年年勇破okr!

回归正题,本文也分两个部分进行讲解:原理与cuda层面的并行计算。

在阅读本文前,需要先阅读V1的讲解,本文会沿用V1的表达符号及推演思路。

一、Flash Attention V2整体运作流程

1.1 V1的运作流程

我们先快速回顾一下V1的运作流程:以K,V为外循环,Q为内循环。

,遍历:

89ee976a-cfc5-11ee-a297-92fbcf53809c.png

,遍历:

89f50cc6-cfc5-11ee-a297-92fbcf53809c.png

为了帮助大家更好理解v1中数据块的流转过程,在图中我们画了6块O。但实际上最终只有三块O:。

为例,它可理解成是由经过某些处理后汇总而来的。进一步说,

我们在外循环j = 0时,先遍历一次所有的i,在这个阶段中我们产出,并将它和一些别的重要数据写回HBM中

接下来我们进行第二次外循环,即j=1,在这个阶段中我们产出。同时我们把和那些重要的数据从HBM传入shared memory中,然后从shared memory中读取它们,以配合产出最终的

(关于如何得到的细节我们在V1讲解中详细推导过,这里不再赘述)

在这个过程中,你是不是隐隐觉得有些别扭:

其实都和有关系,那我为什么不以Q为外循环,以KV为内循环做遍历呢?这样我不就能避免往shared memory上读写中间结果,从而一次性把乃至最终的给算出来?

同时,softmax这个操作也是在row维度上的,所以我固定Q循环KV的方式,更天然符合softmax的特性。

1.2 V2的运作流程

基于1.1中的思想,我们在V2中将原本的内外循环置换了位置(示意图就不画了,基本可以对比V1示意图想象出来)。我们直接来看V2的伪代码(如果对以下伪代码符号表示或解读有疑惑的朋友,最好先看一下V1的讲解)。

(1)V2 FWD

8a0872e8-cfc5-11ee-a297-92fbcf53809c.png

现在,想象自己固定住了一块Q(i),依此循环K和V的分块(j),在这个想象下我们来解读这份FWD为代码。

第8行,计算分块

第9行:

表示截止到当前分块(包含当前分块)为止的rowmax

表示使用当前每行最大值计算归一化前的(我们在V1中说过,不带波浪号的P表示(s-rowmax)/rowsum的结果,带波浪号表示(s-rowmax))

表示截止到当前分块(包含当前分块为止)的rowsum

第10行:表示截止到当前分块(包含当前分块)为止计算出的O值。由第9和第10行知,当我们固定Q循环KV时,我们每个分块都是用当前最新的rowmax和rowsum计算的,同理对应的也是用当前最新的rowmax和rowsum计算的。这样当我们遍历完所有的KV时,得到的就等于最终全局的结果。相关的证明我们在V1讲解中给过,这里不再赘述,只额外提两点:

可能在有些朋友下载的V2论文中,第十行这里O前面的因子项是,这个公式应该是错误的(大家动手推一下就可知,初次看到时让我困扰了很久)。在作者个人主页的论文链接中,这个typo已经被修正。

你可能已发现这个O的计算中缺少归一化的一项,这一项其实放到了第12行做统一计算。这也是V2优化的一个点:尽量减少非矩阵的计算,因为在GPU中,非矩阵计算比矩阵计算慢16倍。

比起V1,V2中不用再存每一Q分块对应的了。但是在BWD的过程中,我们仍需要来做的重计算,这样才能用链式求导法则把dQ,dK,dV正常算出来。V2在这里用了一个很巧妙的方法,它只存一个东西(代码13行,这样又能进一步减少shared memory的读写):,这个等式中小写的m和l可以理解成是全局的rowmax和rowsum。在接下来BWD的讲解中,我们会来看到这一项的妙用。

(2)V2 BWD

一个建议:如果你在阅读本节中觉得很困惑,一定记得先去看V1的BWD部分,有非常详细的推导介绍。看完再来看本节就很顺畅了。

8a269336-cfc5-11ee-a297-92fbcf53809c.png

我们观察到,在V2 BWD中,内外循环的位置又换回来了,即还是KV外循环,Q内循环,这是为什么呢?

我们知道在BWD的过程中,我们主要是求(为了求它们还需要求中间结果,我们来总结一下这些梯度都需要沿着哪些方向AllReduce:

:沿着i方向做AllReduce,也就是需要每行的结果加总

:沿着i方向做AllReduce,也就是需要每行的结果加总

:沿着j方向做AllReduce,也就是需要每列的结果加总

:只与当前i,j相关

基于此,如果你还是保持Q外循环,KV外循环不变的话,这种操作其实是固定行,遍历列的,那么在这些梯度中,只有从中受益了,K和V的梯度则进入了别扭的循环(也意味着要往shared memory上写更多的中间结果);但如果你采用KV外循环,Q内循环,这样K和V都受益,只有Q独自别扭,因此是一种更好的选择。(S和P的计算不受循环变动影响)。

前面说过,在BWD过程中读写我们要用全局的重新计算,计算公式如下:

但如此一来,我们就要从shared memory上同时读取,似乎有点消耗读写。所以在V2中,我们只存储,然后计算:

很容易发现这两个计算是等价的,但V2的做法节省了读写量

好,现在我们就把V2相对于V1在计算原理上的改进介绍完了。接下来我们总结一下V2相对于V1所有的改进点。

二、V2相对V1的改进点

之所以把这块内容放到“V2整体流程介绍”之后,是想让大家在先理解V2是怎么做的基础上,更好体会V2的优点。

总体来说,V2从以下三个方面做了改进:

置换内外循环位置,同时减少非矩阵的计算量。(这两点我们在第一部分中已给出详细说明)

优化Attention部分thread blocks的并行化计算,新增seq_len维度的并行,使SM的利用率尽量打满。这其实也是内外循环置换这个总体思想配套的改进措施

优化thread blocks内部warp级别的工作模式,尽量减少warp间的通讯和读取shared memory的次数。

第二和第三点都可以归结为是cuda gemm层面的优化,我们马上来细看这两点。

三、V2中的thread blocks排布

//gridDiminV1
//params.b=batch_size,params.h=num_heads
dim3grid(params.b,params.h);

//gridDiminV2
constintnum_m_block=(params.seqlen_q+Kernel_traits::kBlockM-1)/Kernel_traits::kBlockM;
dim3grid(num_m_block,params.b,params.h);

这段代码整合自flash attention github下的cutlass实现,为了方便讲解做了一点改写。

这段代码告诉我们:

在V1中,我们是按batch_size和num_heads来划分block的,也就是说一共有batch_size * num_heads个block,每个block负责计算O矩阵的一部分

在V2中,我们是按batch_size,num_heads和num_m_block来划分block的,其中num_m_block可理解成是沿着Q矩阵行方向做的切分。例如Q矩阵行方向长度为seqlen_q(其实就是我们熟悉的输入序列长度seq_len,也就是图例中的N),我们将其划分成num_m_block份,每份长度为kBlockM(也就是每份维护kBlockM个token)。这样就一共有batch_size * num_heads * num_m_block个block,每个block负责计算矩阵O的一部分。

为什么相比于V1,V2在划分thread block时,要新增Q的seq_len维度上的划分呢?

先说结论,这样做的目的是尽量让SM打满。我们知道block是会被发去SM上执行的。以1块A100 GPU为例,它有108个SM,如果此时我们的block数量比较大(例如论文中所说>=80时),我们就认为GPU的计算资源得到了很好的利用。现在回到我们的输入数据上来,当batch_size和num_heads都比较大时,block也比较多,此时SM利用率比较高。但是如果我们的数据seq_len比较长,此时往往对应着较小的batch_size和num_heads,这是就会有SM在空转了。而为了解决这个问题,我们就可以引入在Q的seq_len上的划分。

看到这里你可能还是有点懵,没关系,我们通过图解的方式,来一起看看V1和V2上的thread block到底长什么样。

3.1 V1 thread block

8a4e28f6-cfc5-11ee-a297-92fbcf53809c.png

假设batch_size = 1,num_heads = 2,我们用不同的颜色来表示不同的head。

我们知道在Multihead Attention中,各个head是可以独立进行计算的,在计算完毕后将结果拼接起来即可。所以我们将1个head划分给1个block,这样就能实现block间的并行计算,如此每个block只要在计算完毕后把结果写入自己所维护的O的对应位置即可。

而每个block内,就能执行V1中的"KV外循环,Q内循环”的过程了,这个过程是由block的再下级warp level层面进行组织,thread实行计算的。这块我们放在第四部分中讲解。

3.2 V2 thread block

8a5b8564-cfc5-11ee-a297-92fbcf53809c.png

现在我们继续假设batch_size = 1,num_heads = 2。

与V1不同的是,我们在Q的seq_len维度上也做了切分,将其分成四份,即num_m_block = 4。所以现在我们共有124 = 8个block在跑。这些block之间的运算也是独立的,因为:

head的计算是独立的,所以红色block和蓝色block互不干扰

采用Q做外循环,KV做内循环时,行与行之间的block是独立的,因此不同行的block互相不干扰。

每个block从Q上加载对应位置的切块,同时从KV上加载head0的切块,计算出自己所维护的那部分O,然后写入O的对应位置。

在这里你可能想问,为什么只对Q的seq_len做了切分,而不对KV的seq_len做切分呢?

在V2的cutlass实现中,确实也提供了对KV的seq_len做切分的方法。但除非你认为SM真得打不满,否则尽量不要在KV维度上做切分,因为如此一来,不同的block之间是没法独立计算的(比如对于O的某一行,它的各个部分来自不同的block,为了得到全局的softmax结果,这些block的结果还需要汇总做一次计算)。

3.3 seq parallel不是V2特有

如果你看过V1的代码,你会发现,其实在V1后期的版本中,也出现了seq维度的并行:

//V1seqparallel:csrc/flash_attn/src/fmha_fwd_launch_template.h
dim3grid(launch_params.params.b,launch_params.params.h,launch_params.params.num_splits);

//nums_splits计算方法
//Findthenumberofsplitsthatmaximizestheoccupancy.Forexample,ifwehave
//batch*n_heads=48andwehave108SMs,having2splits(efficiency=0.89)is
//betterthanhaving3splits(efficiency=0.67).However,wealsodon'twanttoomany
//splitsasthatwouldincurmoreHBMreads/writes.
//Sowefindthebestefficiency,thenfindthesmallestnumberofsplitsthatgets95%
//ofthebestefficiency.
//[2022-11-25]TD:Markthisas"inline"otherwiseweget"multipledefinition"error.
inlineintnum_splits_heuristic_fwd(intbatch_nheads,intnum_SMs,intctas_per_sm,intmax_splits){
floatmax_efficiency=0.f;
std::vectorefficiency;
efficiency.reserve(max_splits);
for(intnum_splits=1;num_splits<= max_splits; num_splits++) {
        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
        float eff = n_waves / ceil(n_waves);
        // printf("num_splits = %d, eff = %f
", num_splits, eff);
        if (eff >max_efficiency){max_efficiency=eff;}
efficiency.push_back(eff);
}
for(intnum_splits=1;num_splits<= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] >0.95*max_efficiency){
//printf("num_splitschosen=%d
",num_splits);
returnnum_splits;
}
}
return1;
}

....
//可以发现num_splits也是由Q的seq_len维度切分来的
launch_params.params.num_splits=num_splits_heuristic_fwd(
launch_params.params.b*launch_params.params.h,dprops->multiProcessorCount,
ctas_per_sm,
/*max_splits=*/std::min(30,(launch_params.params.seqlen_q+M-1/M))
);

上图代码中的num_splits也是在由Q的seq_len维度切分来的。通过这段代码,我猜想作者在V1后期引入seq_len维度切分的原因是:V1也需要解决seq_len过长时,batch_size和num_heads较小而造成SM打不满的问题。

num_splits_heuristic_fwd这个函数的作用概括起来就是,我先提供一连串num_splits值的备选,然后由这个函数计算出每个备选值下SM的利用率。计算完之后,我先找到最高的利用率,然后再找出满足利用率>=0.95 * max(利用率)的那个最小的num_split值,作为最终的选择。

细心的你此时可能已经观察到了,虽然V1也引进过seq parallel,但是它的grid组织形式时(batch_size, num_heads, num_m_blocks),但V2的组织形式是(num_m_blocks, batch_size, num_heads),这种顺序调换的意义是什么呢?

直接说结论,这样的调换是为了提升L2 cache hit rate。大家可以看下3.2中的图(虽然block实际执行时不一定按照图中的序号),对于同一列的block,它们读的是KV的相同部分,因此同一列block在读取数据时,有很大概率可以直接从L2 cache上读到自己要的数据(别的block之前取过的)。

3.4 FWD和BWD过程中的thread block划分

在3.1~3.3中,我们其实给出的是FWD过程中thread block的划分方式,我们知道V2中FWD和BWD的内外循环不一致,所以对应来说,thread block的划分也会有所不同,我们详细来看:

8a5fe514-cfc5-11ee-a297-92fbcf53809c.png

在图中:

worker表示thread block,不同的thread block用不同颜色表示

整个大方框表示输出矩阵O

我们先看左图,它表示FWD下thread block的结构。每一行都有一个worker,它表示O矩阵的每一行都是由一个thread block计算出来的(假设num_heads = 1),这就对应到我们3.1~3.3中说的划分方式。那么白色的部分表示什么呢?我们知道如果采用的是casual attention,那么有一部分是会被mask掉的,所以这里用白色来表示。但这不意味着thread block不需要加载白色部分数据对应的KV块,只是说在计算的过程中它们会因被mask掉而免于计算(论文中的casual mask一节有提过)。

我们再看右图,它表示BWD下thread block的结构,每一列对应一个worker,这是因为BWD中我们是KV做外循环,Q做内循环,这种情况下dK, dV都是按行累加的,而dQ是按列累加的,少数服从多数,因此这里thread_block是按的列划分的。

四、Warp级别并行

8a6a9bda-cfc5-11ee-a297-92fbcf53809c.png

讲完了thread block,我们就可以再下一级,看到warp level级别的并行了。左图表示V1,右图表示V2。不管是V1还是V2,在Ampere架构下,每个block内进一步被划分为4个warp,在Hopper架构下则是8个warp。

在左图(V1)中,每个warp都从shared memory上读取相同的Q块以及自己所负责计算的KV块。在V1中,每个warp只是计算出了列方向上的结果,这些列方向上的结果必须汇总起来,才能得到最终O矩阵行方向上的对应结果。所以每个warp需要把自己算出来的中间结果写到shared memory上,再由一个warp(例如warp1)进行统一的整合。所以各个warp间需要通讯、需要写中间结果,这就影响了计算效率。

在左图(V2)中,每个warp都从shared memory上读取相同的KV块以及自己所负责计算的Q块。在V2中,行方向上的计算是完全独立的,即每个warp把自己计算出的结果写到O的对应位置即可,warp间不需要再做通讯,通过这种方式提升了计算效率。不过这种warp并行方式在V2的BWD过程中就有缺陷了:由于bwd中dK和dV是在行方向上的AllReduce,所以这种切分方式会导致warp间需要通讯。

针对V2 warp切分影响BWD这点,作者在论文中依然给出了“BWD过程相比V1也有提升”的结论,针对这点,我在github issue上找到了一条作者的回复(在“安装报错”组成的issue海洋里捞出的宝贵一条):

8a815a64-cfc5-11ee-a297-92fbcf53809c.png

最关键的可能是第1和第2点,关于第1点,我想作者应该是说,之前需要反复读取KV的数据,现在只用反复读取Q的数据,因此从一定程度上节省了shared memory的读写次数。第2点理解起来有点复杂,个人觉得是将warp处理的tile划分得更像方形。这样做的好处是在做casual mask的时候可以方便写代码大块丢掉被mask掉的tile(见论文casual masking部分),进一步加速计算。第3点是关于一些底层的优化,就不提了。

好!关于V2我们就介绍到这了,写这篇文章的时候,我刚粗过了一遍triton的flash attention实现,以及扫了一下cutlass实现的入口。如果后续有时间,我会出一些源码解读的文章(从cuda gemm -> triton gemm -> triton flash attention,看,又给自己挖了一个坑)。如果出不了,那一定不是我鸽人,那肯定是我不会(没错,就是这样)。

审核编辑:黄飞

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • gpu
    gpu
    +关注

    关注

    27

    文章

    4424

    浏览量

    126737
  • 并行计算
    +关注

    关注

    0

    文章

    27

    浏览量

    9369
  • 大模型
    +关注

    关注

    2

    文章

    1550

    浏览量

    1146

原文标题:图解大模型计算加速系列:Flash Attention V2,从原理到并行计算

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    LT8228从V1到V2是BUCK模式,从V2V1是BOOST模式,请问这个模式是固定的吗?

    看LT8228的示例,从V1到V2是BUCK模式,从V2V1是BOOST模式,请问,这个模式是固定的吗? 假如我输入V1是48
    发表于 01-05 07:10

    使用uio Stick v2升级uIO-Stick_Firmware后无法连接怎么解决?

    使用uio Stick v2(不是uio Stick),通过config wizard for smart power switches工具控制SPOC系列智能功率开关,使用时通过updata
    发表于 01-26 07:06

    硬件平台介绍及使用(SC171开发套件V2

    /27203.html *附件:SC171开发套件V2开箱 _20240408.pdf 硬件 硬件测试使用流程指导_SC171开发套件V2 *附件:硬件测试使用流程指导_SC171
    发表于 04-09 15:34

    [转载] STLINK V2安装使用详解

    STLINK V2安装使用详解1.解压st-link_v2_u*** driver.zip文件。2. 运行解压后的st-link_v2_u***driver.exe文件,安装STLIN
    发表于 05-31 21:50

    求助如下电路V2的输出电压是多少

    两个器件是一样的,都是轨到轨输出运放,V1无疑是2VV2的输出电压很诡异,第一张无二极管的V2输出1.9V多点,第二张有二极管的
    发表于 12-23 21:31

    我可以使用ST-LINK/V2为STM8S系列编程闪存吗?

    大家好,我可以使用''ST-LINK / V2''为STM8S系列编程闪存吗? #ST-LINK / V2以上来自于谷歌翻译以下为原文 Hi to all,Can i use ''ST-LINK/
    发表于 09-17 12:53

    ST-Link/V2与HALT模式不兼容

    大家好,ST-Link / V2工具(所以SWIM接口)与HALT指令不兼容。 因此,为了调试和/或使用STMStudio,我不得不使用这种功能模式。是否有任何解决方案可以优雅地解决问题
    发表于 10-30 17:57

    如何使用Matlab Wrapper与V2进行通信?

    我购买了MKI062V2演示板,我在使用st.com提供的AN3138 Matlab Wrapper时遇到了问题。包装器用于V1,但我找不到V2。 代码编译但只有流零
    发表于 06-04 09:45

    在哪里可以找到ST-LINK/V2主机接口详细信息

    ,但我查看了STM8FLASH的代码(在GitHib上),它表明我需要了解更多关于ST-LINK / V2的主机端接口。非常感谢任何帮助。背景:STM8正在运行自定义代码。我想使用调试模块来停止cpu
    发表于 07-02 11:18

    MobileNet v2的学习笔记

    Classification基础实验系列四——MobileNet v2论文笔记与复现
    发表于 04-03 11:34

    ST-link/V2 SWIM接口是如何定义的呢

    ST-link/V2是什么?有哪些属性?ST-link/V2 SWIM接口是如何定义的呢?
    发表于 02-10 06:07

    串口框架V1和V2版本对比差异及其串口V2的一些设计实现思想

    流程能够统一。V2版本的串口框架(以及驱动)主要改动点:取消了硬件工作模式的判断,硬件工作模式由驱动层支持,使得框架层与 硬件工作模式 无关;统一操作接口,应用层不再关心 硬件工作模式,统一使用 阻塞
    发表于 08-02 15:52

    OnSpeed Gen2 v2的资料分享

    描述OnSpeed Gen2 v2
    发表于 09-09 07:18

    如何将st-link固件从V2更改为v2-1?

    大家好我有 St-link v2,但在 stm32mp157-ev1 中调试 Cortex M 时遇到问题。我认为是因为keil文档中的HW Vesion和FW的差异以及我的结果。在附件图片
    发表于 02-07 08:30

    正点原子STM32系列——精英V2开发板评测报告聚合帖

    华仔stm32 【正点原子STM32精英V2开发板体验】初识精英V2开发板【正点原子STM32精英V2开发板体验】资料获取方式 【正点原子STM32精英V2开发板体验《精英
    发表于 05-10 11:17