0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

FlashAttention2详解(性能比FlashAttention提升200%)

jf_pmFSk4VX 来源:GiantPandaCV 2023-11-24 16:21 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

摘要

在过去几年中,如何扩展Transformer使之能够处理更长的序列一直是一个重要问题,因为这能提高Transformer语言建模性能和高分辨率图像理解能力,以及解锁代码、音频和视频生成等新应用。然而增加序列长度,注意力层是主要瓶颈,因为它的运行时间和内存会随序列长度的增加呈二次(平方)增加。FlashAttention利用GPU非匀称的存储器层次结构,实现了显著的内存节省(从平方增加转为线性增加)和计算加速(提速2-4倍),而且计算结果保持一致。但是,FlashAttention仍然不如优化的矩阵乘法(GEMM)操作快,只达到理论最大FLOPs/s的25-40%。作者观察到,这种低效是由于GPU对不同thread blocks和warps工作分配不是最优的,造成了利用率低和不必要的共享内存读写。因此,本文提出了FlashAttention-2以解决这些问题。

简介

如何扩展Transformer使之能够处理更长的序列一直是一个挑战,**因为其核心注意力层的运行时间和内存占用量随输入序列长度成二次增加。**我们希望能够打破2k序列长度限制,从而能够训练书籍、高分辨率图像和长视频。此外,写作等应用也需要模型能够处理长序列。过去一年中,业界推出了一些远超之前长度的语言模型:GPT-4为32k,MosaicML的MPT为65k,以及Anthropic的Claude为100k。

虽然相比标准Attention,FlashAttention快了2~4倍,节约了10~20倍内存,但是离设备理论最大throughput和flops还差了很多。本文提出了FlashAttention-2,它具有更好的并行性和工作分区。实验结果显示,FlashAttention-2在正向传递中实现了约2倍的速度提升,达到了理论最大吞吐量的73%,在反向传递中达到了理论最大吞吐量的63%。在每个A100 GPU上的训练速度可达到225 TFLOPs/s。

本文主要贡献和创新点为:

1. 减少了non-matmul FLOPs的数量(消除了原先频繁rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的执行时间较长,这是因为GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽可能多地执行matmul FLOPs非常重要。

2. 提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。即使对于单个head,也在不同的thread block之间进行并行计算。

3. 在一个attention计算块内,将工作分配在一个thread block的不同warp上,以减少通信和共享内存读/写。

动机

为了解决这个问题,研究者们也提出了很多近似的attention算法,然而目前使用最多的还是标准attention。FlashAttention利用tiling、recomputation等技术显著提升了计算速度(提升了2~4倍),并且将内存占用从平方代价将为线性代价(节约了10~20倍内存)。虽然FlashAttention效果很好,但是仍然不如其他基本操作(如矩阵乘法)高效。例如,其前向推理仅达到GPU(A100)理论最大FLOPs/s的30-50%(下图);反向传播更具挑战性,在A100上仅达到最大吞吐量的25-35%。相比之下,优化后的GEMM(矩阵乘法)可以达到最大吞吐量的80-90%。通过观察分析,这种低效是由于GPU对不同thread blocks和warps工作分配不是最优的,造成了利用率低和不必要的共享内存读写。

959fa84a-76f8-11ee-939d-92fbcf53809c.jpg

Attention forward speed on A100 GPU. (Source: Figure 5 of the paper.)

背景知识

下面介绍一些关于GPU的性能和计算特点,有关Attention和FlashAttention的详细内容请参考第一篇文章

FlashAttention图解(如何加速Attention)

GPU

GPU performance characteristics.GPU主要计算单元(如浮点运算单元)和内存层次结构。大多数现代GPU包含专用的低精度矩阵乘法单元(如Nvidia GPU的Tensor Core用于FP16/BF16矩阵乘法)。内存层次结构分为高带宽内存(High Bandwidth Memory, HBM)和片上SRAM(也称为shared memory)。以A100 GPU为例,它具有40-80GB的HBM,带宽为1.5-2.0TB/s,每个108个streaming multiprocessors共享的SRAM为192KB,带宽约为19TB/s。

这里忽略了L2缓存,因为不能直接被由程序员控制。

95acec58-76f8-11ee-939d-92fbcf53809c.jpg

CUDA的软件和硬件架构

从Hardware角度来看:

Streaming Processor(SP):是最基本的处理单元,从fermi架构开始被叫做CUDA core。

Streaming MultiProcessor(SM):一个SM由多个CUDA core(SP)组成,每个SM在不同GPU架构上有不同数量的CUDA core,例如Pascal架构中一个SM有128个CUDA core。

SM还包括特殊运算单元(SFU),共享内存(shared memory),寄存器文件(Register File)和调度器(Warp Scheduler)等。register和shared memory是稀缺资源,这些有限的资源就使每个SM中active warps有非常严格的限制,也就限制了并行能力。

从Software(编程)角度来看:

95b87a46-76f8-11ee-939d-92fbcf53809c.jpg

CUDA软件示例

thread是最基本的执行单元(the basic unit of execution)。

warp是SM中最小的调度单位(the smallest scheduling unit on an SM),一个SM可以同时处理多个warp

thread block是GPU执行的最小单位(the smallest unit of execution on the GPU)。

一个warp中的threads必然在同一个block中,如果block所含thread数量不是warp大小的整数倍,那么多出的那个warp中会剩余一些inactive的thread。也就是说,即使warp的thread数量不足,硬件也会为warp凑足thread,只不过这些thread是inactive状态,但也会消耗SM资源。

thread:一个CUDA并行程序由多个thread来执行

warp:一个warp通常包含32个thread。每个warp中的thread可以同时执行相同的指令,从而实现SIMT(单指令多线程)并行。

thread block:一个thread block可以包含多个warp,同一个block中的thread可以同步,也可以通过shared memory进行通信。

grid:在GPU编程中,grid是一个由多个thread block组成的二维或三维数组。grid的大小取决于计算任务的规模和thread block的大小,通常根据计算任务的特点和GPU性能来进行调整。

Hardware和Software的联系:

SM采用的是Single-Instruction Multiple-Thread(SIMT,单指令多线程)架构,warp是最基本的执行单元,一个warp包含32个并行thread,这些thread以不同数据资源执行相同的指令。

当一个kernel被执行时,grid中的thread block被分配到SM上,大量的thread可能被分到不同的SM上,但是一个线程块的thread只能在一个SM上调度,SM一般可以调度多个block。每个thread拥有自己的程序计数器和状态寄存器,并且可以使用不同的数据来执行指令,从而实现并行计算,这就是所谓的Single Instruction Multiple Thread。

一个CUDA core可以执行一个thread,一个SM中的CUDA core会被分成几个warp,由warp scheduler负责调度。GPU规定warp中所有thread在同一周期执行相同的指令,尽管这些thread执行同一程序地址,但可能产生不同的行为,比如分支结构。一个SM同时并发的warp是有限的,由于资源限制,SM要为每个block分配共享内存,也要为每个warp中的thread分配独立的寄存器,所以SM的配置会影响其所支持的block和warp并发数量。

GPU执行模型小结:

GPU有大量的threads用于执行操作(an operation,也称为a kernel)。这些thread组成了thread block,接着这些blocks被调度在SMs上运行。在每个thread block中,threads被组成了warps(32个threads为一组)。一个warp内的threads可以通过快速shuffle指令进行通信或者合作执行矩阵乘法。在每个thread block内部,warps可以通过读取/写入共享内存进行通信。每个kernel从HBM加载数据到寄存器和SRAM中,进行计算,最后将结果写回HBM中。

FlashAttention

FlashAttention应用了tiling技术来减少内存访问,具体来说:

1. 从HBM中加载输入数据(K,Q,V)的一部分到SRAM中

2. 计算这部分数据的Attention结果

3. 更新输出到HBM,但是无需存储中间数据S和P

下图展示了一个示例:首先将K和V分成两部分(K1和K2,V1和V2,具体如何划分根据数据大小和GPU特性调整),根据K1和Q可以计算得到S1和A1,然后结合V1得到O1。接着计算第二部分,根据K2和Q可以计算得到S2和A2,然后结合V2得到O2。最后O2和O1一起得到Attention结果。

95d09586-76f8-11ee-939d-92fbcf53809c.jpg

值得注意的是,输入数据K、Q、V是存储在HBM上的,中间结果S、A都不需要存储到HBM上。通过这种方式,FlashAttention可以将内存开销降低到线性级别,并实现了2-4倍的加速,同时避免了对中间结果的频繁读写,从而提高了计算效率。

FlashAttention-2

经过铺垫,正式进入正文。我们先讲述FlashAttention-2对FlashAttention的改进,从而减少了非矩阵乘法运算(non-matmul)的FLOPs。然后说明如何将任务分配给不同的thread block进行并行计算,充分利用GPU资源。最后描述了如何在一个thread block内部分配任务给不同的warps,以减少访问共享内存次数。这些优化方案使得FlashAttention-2的性能提升了2-3倍。

Algorithm

FlashAttention在FlashAttention算法基础上进行了调整,减少了非矩阵乘法运算(non-matmul)的FLOPs。这是因为现代GPU有针对matmul(GEMM)专用的计算单元(如Nvidia GPU上的Tensor Cores),效率很高。以A100 GPU为例,其FP16/BF16矩阵乘法的最大理论吞吐量为312 TFLOPs/s,但FP32非矩阵乘法仅有19.5 TFLOPs/s,即每个no-matmul FLOP比mat-mul FLOP昂贵16倍。为了确保高吞吐量(例如超过最大理论TFLOPs/s的50%),我们希望尽可能将时间花在matmul FLOPs上。

Forward pass

通常实现Softmax算子为了数值稳定性(因为指数增长太快,数值会过大甚至溢出),会减去最大值:

95d47dae-76f8-11ee-939d-92fbcf53809c.png

这样带来的代价就是要对95df671e-76f8-11ee-939d-92fbcf53809c.png遍历3次。

为了减少non-matmul FLOPs,本文在FlashAttention基础上做了两点改进:

95ed1a30-76f8-11ee-939d-92fbcf53809c.png

95f7f234-76f8-11ee-939d-92fbcf53809c.png

960c8b2c-76f8-11ee-939d-92fbcf53809c.png

简单示例的FlashAttention完整计算步骤(红色部分表示V1和V2区别):

9615dea2-76f8-11ee-939d-92fbcf53809c.jpg

FlashAttention-2的完整计算步骤(红色部分表示V1和V2区别):

962181e4-76f8-11ee-939d-92fbcf53809c.png

962a8410-76f8-11ee-939d-92fbcf53809c.jpg

有了上面分析和之前对FlashAttention的讲解,再看下面伪代码就没什么问题了。

962e598c-76f8-11ee-939d-92fbcf53809c.jpg

Causal masking是attention的一个常见操作,特别是在自回归语言建模中,需要对注意力矩阵S应用因果掩码(即任何S ,其中 > 的条目都设置为−∞)。

1. 由于FlashAttention和FlashAttention-2已经通过块操作来实现,对于所有列索引都大于行索引的块(大约占总块数的一半),我们可以跳过该块的计算。这比没有应用因果掩码的注意力计算速度提高了1.7-1.8倍。

2. 不需要对那些行索引严格小于列索引的块应用因果掩码。这意味着对于每一行,我们只需要对1个块应用因果掩码。

Parallelism

FlashAttention在batch和heads两个维度上进行了并行化:使用一个thread block来处理一个attention head,总共需要thread block的数量等于batch size × number of heads。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为几乎可以有效利用GPU上所有计算资源。

但是在处理长序列输入时,由于内存限制,通常会减小batch size和head数量,这样并行化成都就降低了。因此,FlashAttention-2还在序列长度这一维度上进行并行化,显著提升了计算速度。此外,当batch size和head数量较小时,在序列长度上增加并行性有助于提高GPU占用率。

96415c80-76f8-11ee-939d-92fbcf53809c.png

Work Partitioning Between Warps

上一节讨论了如何分配thread block,然而在每个thread block内部,我们也需要决定如何在不同的warp之间分配工作。我们通常在每个thread block中使用4或8个warp,如下图所示。

964a3526-76f8-11ee-939d-92fbcf53809c.jpg

Work partitioning between different warps in the forward pass

964e7a6e-76f8-11ee-939d-92fbcf53809c.png

论文中原话是”However, this is inefficient since all warps need to write their intermediate results out toshared memory, synchronize, then add up the intermediate results.”,说的是shared memory而非HBM,但是结合下图黄色框部分推断,我认为是HBM。

966cc258-76f8-11ee-939d-92fbcf53809c.jpg

96714378-76f8-11ee-939d-92fbcf53809c.png

967ddbe2-76f8-11ee-939d-92fbcf53809c.jpg

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 存储器
    +关注

    关注

    39

    文章

    7715

    浏览量

    170877
  • gpu
    gpu
    +关注

    关注

    28

    文章

    5100

    浏览量

    134479
  • 矩阵
    +关注

    关注

    1

    文章

    441

    浏览量

    35820

原文标题:FlashAttention2详解(性能比FlashAttention提升200%)

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    SiLM2024CA-DG 200V半桥驱动器在工业应用中的性能解析

    支持CMOS/LSTTL标准 3.保护功能 欠压锁定保护 瞬态负压耐受 dV/dt噪声免疫 交叉导通防止逻辑 技术优势详解:1. 可靠性设计 该驱动器采用专利HVIC技术,确保在200V高压环境下
    发表于 11-22 10:50

    谷歌云发布最强自研TPU,性能比前代提升4倍

    电子发烧友网报道(文/李弯弯)近日,谷歌云在官方博客上正式宣布,公司成功推出第七代TPU(张量处理器)“Ironwood”,该芯片预计在未来几周内正式上市。   “Ironwood”由谷歌自主精心设计,能够轻松处理从大型模型训练到实时聊天机器人运行以及AI智能体操作等各类复杂任务。   谷歌在新闻稿中着重强调,“Ironwood”是专为应对最严苛的工作负载而打造的。无论是大规模模型训练、复杂的强化学习(RL),还是高容量、低延迟的AI推理和模型服务,
    的头像 发表于 11-13 07:49 8158次阅读
    谷歌云发布最强自研TPU,<b class='flag-5'>性能比</b>前代<b class='flag-5'>提升</b>4倍

    低 ESR 设计降损耗:车规铝电解电容提升电机驱动 EMC 性能

    低ESR车规铝电解电容通过优化材料、结构与工艺,显著降低电机驱动系统的电磁干扰(EMI)和能量损耗,提升电磁兼容性(EMC)性能,成为高压环境下的关键解决方案。 以下从技术原理、性能优势、应用场
    的头像 发表于 10-20 16:54 523次阅读
    低 ESR 设计降损耗:车规铝电解电容<b class='flag-5'>提升</b>电机驱动 EMC <b class='flag-5'>性能</b>

    小白学大模型:大模型加速的秘密 FlashAttention 1/2/3

    在Transformer架构中,注意力机制的计算复杂度与序列长度(即文本长度)呈平方关系()。这意味着,当模型需要处理更长的文本时(比如从几千个词到几万个词),计算时间和所需的内存会急剧增加。最开始的标准注意力机制存在两个主要问题:内存占用高:模型需要生成一个巨大的注意力矩阵(N×N)。这个矩阵需要被保存在高带宽内存(HBM)中。对于长序列,这很快就会超出G
    的头像 发表于 09-10 09:28 4341次阅读
    小白学大模型:大模型加速的秘密 <b class='flag-5'>FlashAttention</b> 1/<b class='flag-5'>2</b>/3

    EV12AS200A的采样延迟微调如何提升相位精度?

    提前或延后,步进就是 24 fs。3. 相位精度提升的数学关系• 对于 1.5 GSPS、3.3 GHz 满功率带宽,24 fs 对应相位误差 ≈ 2π × 3.3 GHz × 24 fs ≈ 0.5
    发表于 08-04 08:46

    什么是共模抑制

    共模抑制详解在探头的数据手册上,共模抑制性能参数是核心指标之一。共模抑制又名CMRR,通常用分贝(dB)来表示,其计算公式为:其中其中
    的头像 发表于 06-23 09:45 986次阅读
    什么是共模抑制<b class='flag-5'>比</b>?

    进迭时空第三代高性能核X200研发进展

    继X60和X100之后,进迭时空正在基于开源香山昆明湖架构研发第三代高性能处理器核X200。与进迭时空的第二代高性能核X100相,X200
    的头像 发表于 06-06 16:56 1144次阅读
    进迭时空第三代高<b class='flag-5'>性能</b>核X<b class='flag-5'>200</b>研发进展

    快手上线鸿蒙应用高性能解决方案:数据反序列化性能提升90%

    近日,快手在Gitee平台上线了鸿蒙应用性能优化解决方案“QuickTransformer”,该方案针对鸿蒙应用开发中广泛使用的三方库“class-transformer”进行了深度优化,有效提升
    发表于 05-15 10:01

    能效和算力提升的衡量方法

    /h·W表示。 影响因素及优化方向‌ 技术升级‌:采用变频技术、高效电机等可提升能效,例如变频空调通过动态调节功率减少能耗。 环境因素‌:温度、湿度等外部条件会影响实际能效表现,需结合具体场景评估。 系统优化‌:通过维护保养(如清洁滤网)和合理选
    的头像 发表于 04-28 07:47 2797次阅读
    能效<b class='flag-5'>比</b>和算力<b class='flag-5'>提升</b>的衡量方法

    直线电机与旋转电机性能比

    直线电机与旋转电机作为现代工业驱动系统的两大核心组件,各自拥有独特的性能特点和适用场景。本文将从速度、加速度、精度、动态响应、结构及应用领域等多个维度,对直线电机与旋转电机进行全面而深入的性能比
    的头像 发表于 03-16 16:55 1448次阅读

    烧结银的导电性能比其他导电胶优势有哪些???

    烧结银的导电性能比其他导电胶优势有哪些???
    的头像 发表于 02-27 21:41 557次阅读

    ADS1212、ADS1231和ADS1230这3种AD芯片性能比

    AD芯片性能比较大家好: 小弟正在做一个静态应变处理电路,使用120Ω的应变片,采样频率为1Hz,分辨率为1微应变!之前试过用分立元件搭建,可温漂效果总是不理想,最后想用集成PGA+低通滤波
    发表于 01-21 06:47

    FDD网络性能提升的方法

    提升FDD(Frequency Division Duplex,频分双工)网络性能的方法可以从多个方面入手,以下是一些具体的策略: 一、硬件升级与优化 升级硬件设备 : 更换为性能更强的基带处理器或
    的头像 发表于 01-07 17:16 1242次阅读

    台积电2纳米制程技术细节公布:性能功耗双提升

    在近日于旧金山举行的IEEE国际电子器件会议(IEDM)上,全球领先的晶圆代工企业台积电揭晓了其备受期待的2纳米(N2)制程技术的详细规格。 据台积电介绍,相较于前代制程技术,N2制程在性能
    的头像 发表于 12-19 10:28 1192次阅读

    台积电2nm制成细节公布:性能提升15%,功耗降低35%

    的显著进步。 台积电在会上重点介绍了其2纳米“纳米片(nanosheets)”技术。据介绍,相较于前代制程,N2制程在性能提升了15%,功耗降低了高达30%,能效显著
    的头像 发表于 12-18 16:15 1211次阅读