0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

理解KV cache的作用及优化方法

jf_pmFSk4VX 来源:知乎 2023-12-04 15:24 次阅读

作者丨紫气东来

在 Transformer 的 Encoder-base 的模型(如 BERT系列)中,推理和训练过程保持了高度的统一性(差异仅仅在于是否存在反向过程)。而在 Decoder-base 的生成式模型(如 GPT系列)中,推理和训练存在相当大的差异性,主要体现在推理过程具有以下3点特征:

自回归

两阶段(第一阶段输入 prompt,第二阶段输入上一个生成的token)

KV cache

以上三点实际上也是相辅相成、不可分割的,其中自回归的生成模式是根本原因,两阶段是外在的体现形式,KV cache是优化手段。

下面将通过梳理整个推理过程,来理解 KV cache 的作用及优化方法。

一、KV cache 的由来与基本矛盾

885a422a-9125-11ee-939d-92fbcf53809c.png

第一阶段(prompt 输入):

88749c6a-9125-11ee-939d-92fbcf53809c.png

8884c4a0-9125-11ee-939d-92fbcf53809c.png

888bc2aa-9125-11ee-939d-92fbcf53809c.png

889cd7f2-9125-11ee-939d-92fbcf53809c.jpg

KV cache 作用过程

第二阶段(token by token):

88b78a48-9125-11ee-939d-92fbcf53809c.png

88bbef52-9125-11ee-939d-92fbcf53809c.png

88c97186-9125-11ee-939d-92fbcf53809c.png

KV cache的显存占用分析

88d47d88-9125-11ee-939d-92fbcf53809c.png

88e1e108-9125-11ee-939d-92fbcf53809c.png

batch size s+n KV cache(GB) KV cache/weight
4 4096 81 0.23
16 4096 324 0.93
64 4096 1297 3.71

可见随着 batch size 和 长度的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。

而 LLM 的窗口长度也在不断增大,因此就出现一组主要矛盾,即:对不断增长的 LLM 的窗口长度的需要与有限的 GPU 显存之间的矛盾。因此优化 KV cache 就显得非常必要。

二、KV cache 优化的典型方法

2.1 共用 KV cache:MQA,GQA

MQA (Multi Query Attention,多查询注意力) 是多头注意力的一种变体。其主要区别在于,在 MQA 中不同的注意力头共享一个K和V的集合,每个头只单独保留了一份查询参数。因此K和V的矩阵仅有一份,这大幅度减少了显存占用,使其更高效。由于MQA改变了注意力机制的结构,因此模型通常需要从训练开始就支持 MQA 。也可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约 5% 的原始训练数据量 就可以达到不错的效果。包括 Falcon、SantaCoder、StarCoder 等在内很多模型都采用了 MQA 机制。

# Multi Head Attention
self.Wqkv = nn.Linear(     # Multi-Head Attention 的创建方法
    self.d_model,
    3 * self.d_model,     # Q、K和V 3 个矩阵, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(3, dim=2)      # 每个 tensor 都是 (1, 512, 768)

# Multi Query Attention
self.Wqkv = nn.Linear(       # Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,    # 只创建Q的头向量,所以是 1* d_model, 而K和V不再具备单独的头向量, 所以是 2 * self.head_dim
    device=device,
)
query, key, value = qkv.split(
    [self.d_model, self.head_dim, self.head_dim],    # query -> (1, 512, 768), key   -> (1, 512, 96), value -> (1, 512, 96)
    dim=2
)

88ec3ba8-9125-11ee-939d-92fbcf53809c.jpg

MHA v.s. GQA v.s. MQA

GQA(Grouped Query Attention,分组查询注意力)是一种介于多头注意力和 MQA 之间的折中方案。它将查询头(Query Heads)分组,并在每组中共享一个键头(Key Head)和一个值头(Value Head)。表达能力与推理速度:GQA既保留了多头注意力的一定表达能力,又通过减少内存访问压力来加速推理速度。

88f5d97e-9125-11ee-939d-92fbcf53809c.jpg

MHA, GQA, MQA 性能比较

2.2 窗口优化

890f5b60-9125-11ee-939d-92fbcf53809c.png

891f68b6-9125-11ee-939d-92fbcf53809c.jpg

3)箭型 attention 窗口,在LM-Infinit中就已经被提出了,其基本原理和StreamingLLM是一致的。

89312f42-9125-11ee-939d-92fbcf53809c.jpg

2.3 量化与稀疏

该类方法是基于压缩的思想,通过量化与稀疏压缩 KV cache 的 显存消耗。

当前主流推理框架都在逐步支持 KV cache 量化,一个典型的案例是lmdeploy,下图展示了其在TurboMind框架下 KV INT8 的支持情况。

893c6b6e-9125-11ee-939d-92fbcf53809c.jpg

lmdeploy 的推理特性

稀疏的方法也比较简单,其做法无外乎以下几种方式:

894638b0-9125-11ee-939d-92fbcf53809c.jpg

这里最值得一提的是H2O。简单来说就是通过动态的评价方式来判断需要保留和废弃的KV值,其评估的算法如下所示:

895912a0-9125-11ee-939d-92fbcf53809c.jpg

结果显示,在 KV cache 稀疏到只有原来的 20% 时仍然可以保持很高的精度。

89688564-9125-11ee-939d-92fbcf53809c.jpg

2.4 存储与计算优化

该方法的典型代表即vLLM的 PagedAttention,简单来说就是允许在非连续的内存空间中存储连续的 K 和 V。详情可参考笔者之前的文章,在此不予赘述

FlashDecoding 是在 FlashAttention 的基础上针对 inference 的优化主要分为三步:

长文本下将KV分成更小且方便并行的chunk

对每个chunk的KV,Q和他们进行之前一样的FlashAttention获取这个chunk的结果

对每个chunk的结果进行reduce

8977e086-9125-11ee-939d-92fbcf53809c.gif

三、StreamingLLM:简洁高效的“无限长度”

StreamingLLM 的基本思想同样是来源于上述的窗口思想,其最大的创新在于提出了识别并保存模型固有的「注意力池」(attention sinks)锚定其推理的初始 token。下面将详细讨论其工作的原理。

3.1 精度是如何保证的?

核心的发现:Lost in the Middle。

多个研究都发现,self-attention 的注意力比较集中于头部和尾部,对文本中段的注意力相对较弱,如下图所示:

89ac0e4c-9125-11ee-939d-92fbcf53809c.jpg

绘制出 self-attention 的热力图也能看到这一点,由此当文本长度超过额定长度时,头部的 token 就会被遗弃掉,这就会在 softmax 阶段产生很大的问题。

89b64c0e-9125-11ee-939d-92fbcf53809c.jpg

89ce455c-9125-11ee-939d-92fbcf53809c.png

89d52ad4-9125-11ee-939d-92fbcf53809c.png

3.2 “无限长度”是如何做到的?

该问实际上可以换种表述为:如何在文本长度不断增加的情况下,保证GPU显存不会溢出。由于该方案主要应用于多轮对话的场景,那么有必要回顾一下当前多轮对话生成的主流做法,概括起来就以下几点:

将用户输入与模型输出拼接,中间做必要分割;

多个轮次之间倒序排列,并拼接;

如果前边所有轮次长度之和超过最大长度,则截断到最大长度;

上述过程可以用代码描述如下:

  history = ["
[|Human|]{}
[|AI|]{}".format(x[0], x[1]) for x in history]
  history.append("
[|Human|]{}
[|AI|]".format(text))
  history_text = ""
  flag = False
  for x in history[::-1]:
    if tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        inputs = tokenizer(prompt + history_text, return_tensors="pt")
        input_ids = inputs["input_ids"][:, -max_length:].to(device)
        torch.cuda.empty_cache()
        return input_ids, text
    else:
        return None

实际上这就是典型的滑动窗口的做法,滑窗�的存在保证了 GPU 的显存不会溢出,但是由于上节的讨论,会存在精度损失。

89f51d1c-9125-11ee-939d-92fbcf53809c.jpg

8a000696-9125-11ee-939d-92fbcf53809c.png

审核编辑:黄飞

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • gpu
    gpu
    +关注

    关注

    27

    文章

    4426

    浏览量

    126742
  • GPT
    GPT
    +关注

    关注

    0

    文章

    302

    浏览量

    14876
  • LLM
    LLM
    +关注

    关注

    0

    文章

    203

    浏览量

    234

原文标题:漫谈 KV Cache 优化方法,深度理解 StreamingLLM

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    内存分配及Cache优化

    C6000的芯片支持库CSL中的CACHE-setL2Mode函数,将L2设置为198KB的SRAM和64KB的Cache模式。并根据H.264算法本身的结构,采取以下方法对存储器进行优化
    发表于 08-10 14:54

    如何理解C6678中关于cache的描述?

    在TMS320C6678中,有这样对cache的描述:“L1D memory cannot be cached within L1D cache, L1P cache, or L2 cache
    发表于 06-21 16:07

    请教关于EDMA和cache优化的疑惑

    hi,everyone:经常看到网上说,EDMA算法优化,在片上L2SRAM 中开辟内存,将片外数据从DDR或SDRAM 利用EDMA搬运到L2SRAM中。但是, 我有两点疑惑:1.我觉得这种方法
    发表于 07-27 09:38

    使用CACHE_disableCaching函数禁止cache没起作用

    CACHE_getMemRegionInfo (129, &pcx, &pfx); 读取pcx的值 仍然是1,所以没起作用。怀疑是当前模式是user mode,而修改MAR寄存器需要
    发表于 12-28 11:12

    Cache为什么还要分I-Cache,D-Cache,L2 Cache作用是什么?

    Cache为什么还要分I-Cache,D-Cache,L2 Cache作用是什么?
    发表于 10-25 06:38

    Cache中Tag电路的设计

    摘要:在SoC系统中,片上缓存(Cache)的采用是解决片上处理器和片外存储器之间速度差异的重要方法Cache中用来存储标记位并判断Cache是否命中的Tag电路的设计将会影响到整个
    发表于 05-08 09:26 11次下载

    降低Cache失效率的方法[1]

    降低Cache失效率的方法[1]  学习目标:     理解失效的三种类型(3C);
    发表于 04-13 16:32 4074次阅读

    降低Cache失效率的方法[2]

    降低Cache失效率的方法[2] 表4.7列出了在这两种极端情况之间的各种块大小和各种 Cache 容量的平均访存时间。速度最快的情况: Cache 容量为1KB、4KB、1
    发表于 04-13 16:33 4723次阅读

    一种基于贝叶斯网络的随机测试方法Cache一致性验证中的设计与实现

    基于贝叶斯网络的随机测试生成方法,解决Cache -致性协议状态空间爆炸的问题。首先分析了Cache -致性协议及基于贝叶斯网络推理的CDG方法,并将CDG
    发表于 11-17 17:24 2次下载
    一种基于贝叶斯网络的随机测试<b class='flag-5'>方法</b>在<b class='flag-5'>Cache</b>一致性验证中的设计与实现

    一种有效的Cache优化替换策略

    该问题,一种有效的解决方法优化Cache替换策略,减少Cache中脏块被替换出的数量。现有研究主要通过在插入和访问命中时给脏块设定较高的保护优先级来达到给脏块额外保护的目的,但是在降
    发表于 11-27 15:16 1次下载
    一种有效的<b class='flag-5'>Cache</b><b class='flag-5'>优化</b>替换策略

    Page Cache是什么 一文带你深入理解Linux的Page Cache

    是什么? 为了理解 Page Cache,我们不妨先看一下 Linux 的文件 I/O 系统,如下图所示: Figure1. Linux 文件 I/O 系统 上图中,红色部分为 Page Cache。可见 Page
    的头像 发表于 10-20 14:12 5427次阅读
    Page <b class='flag-5'>Cache</b>是什么 一文带你深入<b class='flag-5'>理解</b>Linux的Page <b class='flag-5'>Cache</b>

    从三个方面阐述Cache

    (directmapped),全相连(fullyassociative),组相连(setassociative)。 为了便于理解,现在假设一个例子,比如咱们的内存只有16bytes,而cache只有4bytes
    的头像 发表于 11-21 11:09 2189次阅读

    什么是 Cache? Cache读写原理

    由于写入数据和读取指令分别通过 D-Cache 和 I-Cache,所以需要同步 D-Cache 和 I-Cache,即复制后需要先将 D-Cach
    发表于 12-06 09:55 1270次阅读

    深入理解Cache工作原理

    按照数据关系划分:Inclusive/exclusive Cache: 下级Cache包含上级的数据叫inclusive Cache。不包含叫exclusive Cache。举个例子,
    的头像 发表于 05-30 16:02 471次阅读
    深入<b class='flag-5'>理解</b><b class='flag-5'>Cache</b>工作原理

    Cache分类与替换算法

    根据不同的分类标准可以按以下3种方法Cache进行分类。 •1)数据cache和指令cache •● 指令cache:指令预取时使用的
    的头像 发表于 10-31 11:26 469次阅读
    <b class='flag-5'>Cache</b>分类与替换算法