0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

如何计算transformer模型的参数量

jf_pmFSk4VX 来源:GiantPandaCV 2023-07-10 09:13 次阅读

1. 前言

最近,OpenAI推出的ChatGPT展现出了卓越的性能,引发了大规模语言模型(Large Language Model,LLM)的研究热潮。大规模语言模型的“大”体现在两个方面:模型参数规模大,训练数据规模大。以GPT3为例,GPT3的参数量为1750亿,训练数据量达到了570GB。进而,训练大规模语言模型面临两个主要挑战:显存效率和计算效率。

现在业界的大语言模型都是基于transformer模型的,模型结构主要有两大类:encoder-decoder(代表模型是T5)和decoder-only,具体的,decoder-only结构又可以分为Causal LM(代表模型是GPT系列)和PrefixLM(代表模型是GLM)。归因于GPT系列取得的巨大成功,大多数的主流大语言模型都采用Causal LM结构。因此,针对decoder-only框架,为了更好地理解训练训练大语言模型的显存效率和计算效率,本文分析采用decoder-only框架transformer模型的模型参数量、计算量、中间激活值、KV cache。

853e25d8-1d86-11ee-962d-dac502259ad0.jpg

为了方便分析,先定义好一些数学符号。记transformer模型的层数为8568a3ee-1d86-11ee-962d-dac502259ad0.png ,隐藏层维度为85808590-1d86-11ee-962d-dac502259ad0.png ,注意力头数为8597e2c6-1d86-11ee-962d-dac502259ad0.png 。词表大小为85af3034-1d86-11ee-962d-dac502259ad0.png ,训练数据的批次大小为85c04e64-1d86-11ee-962d-dac502259ad0.png ,序列长度为85cf59e0-1d86-11ee-962d-dac502259ad0.png 。

2. 模型参数量

transformer模型由8568a3ee-1d86-11ee-962d-dac502259ad0.png个相同的层组成,每个层分为两部分:self-attention块和MLP块。

self-attention块的模型参数有85f2dcbc-1d86-11ee-962d-dac502259ad0.png 的权重矩阵860b6cb4-1d86-11ee-962d-dac502259ad0.png和偏置,输出权重矩阵 861f975c-1d86-11ee-962d-dac502259ad0.png 和偏置,4个权重矩阵的形状为86324bae-1d86-11ee-962d-dac502259ad0.png ,4个偏置的形状为8645eff6-1d86-11ee-962d-dac502259ad0.png 。self- attention块的参数量为8657ee68-1d86-11ee-962d-dac502259ad0.png 。

MLP块由2个线性层组成,一般地,第一个线性层是先将维度从85808590-1d86-11ee-962d-dac502259ad0.png 映射到867afe44-1d86-11ee-962d-dac502259ad0.png,第二个线性层再将维度从867afe44-1d86-11ee-962d-dac502259ad0.png映射到85808590-1d86-11ee-962d-dac502259ad0.png。第一个线性层的权重矩阵86ac99f4-1d86-11ee-962d-dac502259ad0.png 的形状为86c0c348-1d86-11ee-962d-dac502259ad0.png ,偏置的形状为86d6bf18-1d86-11ee-962d-dac502259ad0.png 。第二个线性层权重矩阵86e7af30-1d86-11ee-962d-dac502259ad0.png 的形状为86fae0c8-1d86-11ee-962d-dac502259ad0.png ,偏置形状为8645eff6-1d86-11ee-962d-dac502259ad0.png 。MLP块的参数量为872bf654-1d86-11ee-962d-dac502259ad0.png 。

self-attention块和MLP块各有一个layer normalization,包含了2个可训练模型参数:缩放参数873cc984-1d86-11ee-962d-dac502259ad0.png 和平移参数8753cc42-1d86-11ee-962d-dac502259ad0.png ,形状都是8645eff6-1d86-11ee-962d-dac502259ad0.png 。2个layernormalization的参数量为 867afe44-1d86-11ee-962d-dac502259ad0.png 。

87817ce6-1d86-11ee-962d-dac502259ad0.jpg

总的,每个transformer层的参数量879c4148-1d86-11ee-962d-dac502259ad0.png 。

除此之外,词嵌入矩阵的参数量也较多,词向量维度通常等于隐藏层维度85808590-1d86-11ee-962d-dac502259ad0.png,词嵌入矩阵的参数量为 87c03350-1d86-11ee-962d-dac502259ad0.png。最后的输出层的权重矩阵通常与词嵌入矩阵是参数共享的。

关于位置编码,如果采用可训练式的位置编码,会有一些可训练模型参数,数量比较少。如果采用相对位置编码,例如RoPE和ALiBi,则不包含可训练的模型参数。我们忽略这部分参数。

综上,8568a3ee-1d86-11ee-962d-dac502259ad0.png层transformer模型的可训练模型参数量为87d9da76-1d86-11ee-962d-dac502259ad0.png。当隐藏维度 85808590-1d86-11ee-962d-dac502259ad0.png 较大时,可以忽略一次项, 模型参数量近似为8803fd6a-1d86-11ee-962d-dac502259ad0.png 。

接下来,我们估计不同版本LLaMA模型的参数量。

实际参数量 隐藏维度h 层数l 12lh^2
6.7B 4096 32 6,442,450,944
13.0B 5120 40 12,582,912,000
32.5B 6656 60 31,897,681,920
65.2B 8192 80 64,424,509,440

2.1 训练过程中的显存占用分析

在训练神经网络的过程中,占用显存的大头主要分为四部分:模型参数、前向计算过程中产生的中间激活、后向传递计算得到的梯度、优化器状态。这里着重分析参数、梯度和优化器状态的显存占用,中间激活的显存占用后面会详细介绍。训练大模型时通常会采用AdamW优化器,并用混合精度训练来加速训练,基于这个前提分析显存占用。

在一次训练迭代中,每个可训练模型参数都会对应1个梯度,并对应2个优化器状态(Adam优化器梯度的一阶动量和二阶动量)。设模型参数量为881c3236-1d86-11ee-962d-dac502259ad0.png ,那么梯度的元素数量为881c3236-1d86-11ee-962d-dac502259ad0.png ,AdamW优化器的元素数量为8841de5a-1d86-11ee-962d-dac502259ad0.png。float16数据类型的元素占2个bytes,float32数据类型的元素占4个bytes。在混合精度训练中,会使用float16的模型参数进行前向传递和后向传递,计算得到float16的梯度;在优化器更新模型参数时,会使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。因此,对于每个可训练模型参数,占用了88581f9e-1d86-11ee-962d-dac502259ad0.png。使用AdamW优化器和混合精度训练来训练参数量为 881c3236-1d86-11ee-962d-dac502259ad0.png的大模型, 模型参数、梯度和优化器状态占用的显存大小为887f5154-1d86-11ee-962d-dac502259ad0.png 。

8892c3c4-1d86-11ee-962d-dac502259ad0.jpg

2.2 推理过程中的显存占用分析

在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。少了梯度、优化器状态、中间激活,模型推理阶段占用的显存要远小于训练阶段。模型推理阶段,占用显存的大头主要是模型参数,如果使用float16来进行推理,推理阶段模型参数占用的显存大概是88b124fe-1d86-11ee-962d-dac502259ad0.png 。如果使用KVcache来加速推理过程, KV cache也需要占用显存,KVcache占用的显存下文会详细介绍。此外,输入数据也需要放到GPU上,还有一些中间结果(推理过程中的中间结果用完会尽快释放掉),不过这部分占用的显存是很小的,可以忽略。

3. 计算量FLOPs估计

FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小。

如何计算矩阵乘法的FLOPs呢?

对于88c2f724-1d86-11ee-962d-dac502259ad0.png ,计算 88d9729c-1d86-11ee-962d-dac502259ad0.png 需要进行 88f02cc6-1d86-11ee-962d-dac502259ad0.png 次乘法运算和 88f02cc6-1d86-11ee-962d-dac502259ad0.png 次加法运算,共计 8913769a-1d86-11ee-962d-dac502259ad0.png 次浮点数运算,需要 8913769a-1d86-11ee-962d-dac502259ad0.png 的FLOPs。对于 893c3c60-1d86-11ee-962d-dac502259ad0.png ,计算 88d9729c-1d86-11ee-962d-dac502259ad0.png 需要的浮点数运算次数为 8962e2fc-1d86-11ee-962d-dac502259ad0.png 。

在一次训练迭代中,假设输入数据的形状为897c4b84-1d86-11ee-962d-dac502259ad0.png 。我们 先分析self-attention块的计算,计算公式如下:

89962dd8-1d86-11ee-962d-dac502259ad0.png89a87cb8-1d86-11ee-962d-dac502259ad0.png

1. 计算89bccee8-1d86-11ee-962d-dac502259ad0.png :矩阵乘法的输入和输出形状为89d2b0f0-1d86-11ee-962d-dac502259ad0.png。计算量为89e69084-1d86-11ee-962d-dac502259ad0.png

2.89fdc10a-1d86-11ee-962d-dac502259ad0.png 矩阵乘法的输入和输出形状为

8a0cb43a-1d86-11ee-962d-dac502259ad0.png。计算量为 8a280b4a-1d86-11ee-962d-dac502259ad0.png 。

3. 计算在85af3034-1d86-11ee-962d-dac502259ad0.png 上的加权 8a4c4500-1d86-11ee-962d-dac502259ad0.png ,矩阵乘法的输入和输出形状为8a619a9a-1d86-11ee-962d-dac502259ad0.png。计算量为 8a280b4a-1d86-11ee-962d-dac502259ad0.png 。

4. attention后的线性映射,矩阵乘法的输入和输出形状为89d2b0f0-1d86-11ee-962d-dac502259ad0.png。计算量为 8a93189a-1d86-11ee-962d-dac502259ad0.png 。

接下来分析MLP块的计算,计算公式如下

8aaa25ee-1d86-11ee-962d-dac502259ad0.png

1. 第一个线性层,矩阵乘法的输入和输出形状为8ac3882c-1d86-11ee-962d-dac502259ad0.png。计算量为 8adb754a-1d86-11ee-962d-dac502259ad0.png 。

2. 第二个线性层,矩阵乘法的输入和输出形状为8af27f60-1d86-11ee-962d-dac502259ad0.png。计算量为 8adb754a-1d86-11ee-962d-dac502259ad0.png 。

将上述计算量相加,得到每个transformer层的计算量大约为8b1dfc44-1d86-11ee-962d-dac502259ad0.png 。

此外,另一个计算量的大头是logits的计算,将隐藏向量映射为词表大小。矩阵乘法的输入和输出形状为8b35ff06-1d86-11ee-962d-dac502259ad0.png,计算量为 8b4c2218-1d86-11ee-962d-dac502259ad0.png 。

因此,对于一个8568a3ee-1d86-11ee-962d-dac502259ad0.png 层的transformer模型,输入数据形状为897c4b84-1d86-11ee-962d-dac502259ad0.png 的情况下,一次训练迭代的计算量为8b7f76ea-1d86-11ee-962d-dac502259ad0.png

3.1 计算量与参数量的关联

当隐藏维度85808590-1d86-11ee-962d-dac502259ad0.png 比较大,且远大于序列长度85cf59e0-1d86-11ee-962d-dac502259ad0.png 时,我们可以忽略一次项,计算量可以近似为8bb25fce-1d86-11ee-962d-dac502259ad0.png 。前面提到当模型参数量为8803fd6a-1d86-11ee-962d-dac502259ad0.png ,输入的tokens数为8bd8c614-1d86-11ee-962d-dac502259ad0.png ,存在等式8bef6874-1d86-11ee-962d-dac502259ad0.png。我们可以近似认为: 在一次前向传递中,对于每个token,每个模型参数,需要进行2次浮点数运算,即一次乘法法运算和一次加法运算。

一次训练迭代包含了前向传递和后向传递,后向传递的计算量是前向传递的2倍。因此,前向传递 + 后向传递的系数8c064c1a-1d86-11ee-962d-dac502259ad0.png。一次训练迭代中,对于每个token,每个模型参数,需要进行8c185e50-1d86-11ee-962d-dac502259ad0.png 次浮点数运算。

接下来,我们可以估计训练GPT3-175B所需要的计算量。对于GPT3,每个token,每个参数进行了6次浮点数运算,再乘以参数量和总tokens数就得到了总的计算量。GPT3的模型参数量为8c29c1b8-1d86-11ee-962d-dac502259ad0.png ,训练数据量为 8c3c5f3a-1d86-11ee-962d-dac502259ad0.png tokens。

8c4efc26-1d86-11ee-962d-dac502259ad0.png

8c661cb2-1d86-11ee-962d-dac502259ad0.jpg

3.2 训练时间估计

模型参数量和训练总tokens数决定了训练transformer模型需要的计算量。给定硬件GPU类型的情况下,可以估计所需要的训练时间。给定计算量,训练时间(也就是GPU算完这么多flops的计算时间)不仅跟GPU类型有关,还与GPU利用率有关。计算端到端训练的GPU利用率时,不仅要考虑前向传递和后向传递的计算时间,还要**考虑CPU加载数据、优化器更新、多卡通信和记录日志的时间。一般来讲,GPU利用率一般在8c8a6fd6-1d86-11ee-962d-dac502259ad0.png之间

上文讲到一次前向传递中,对于每个token,每个模型参数,进行2次浮点数计算。使用激活重计算技术来减少中间激活显存(下文会详细介绍)需要进行一次额外的前向传递,因此前向传递+ 后向传递 + 激活重计算的系数=1+2+1=4。使用激活重计算的一次训练迭代中,对于每个token,每个模型参数,需要进行8c9e60b8-1d86-11ee-962d-dac502259ad0.png 次浮点数运算。在给定训练tokens数、硬件环境配置的情况下,训练transformer模型的计算时间为

8cb12194-1d86-11ee-962d-dac502259ad0.png

8cc7905a-1d86-11ee-962d-dac502259ad0.jpg

以GPT3-175B为例,在1024张40GB显存的A100上,在300Btokens的数据上训练175B参数量的GPT3。40GB显存A100的峰值性能为312TFLOPS,设GPU利用率为0.45,则所需要的训练时间为34天,这与[7]中的训练时间是对得上的

8cee6784-1d86-11ee-962d-dac502259ad0.png

以LLaMA-65B为例,在2048张80GB显存的A100上,在1.4TBtokens的数据上训练了65B参数量的模型。80GB显存A100的峰值性能为624TFLOPS,设GPU利用率为0.3,则所需要的训练时间为21天,这与[4]中的实际训练时间是对得上的

8d05f390-1d86-11ee-962d-dac502259ad0.png

4. 中间激活值分析

除了模型参数、梯度、优化器状态外,占用显存的大头就是前向传递过程中计算得到的中间激活值了,需要保存中间激活以便在后向传递计算梯度时使用。这里的激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。这里的激活不包含模型参数和优化器状态,但包含了dropout操作需要用到的mask矩阵。

在分析中间激活的显存占用时,只考虑激活占用显存的大头,忽略掉一些小的buffers。比如,对于layernormalization,计算梯度时需要用到层的输入、输入的均值8d33dd50-1d86-11ee-962d-dac502259ad0.png 和方差8d45df46-1d86-11ee-962d-dac502259ad0.png 。输入包含了8d5b7ad6-1d86-11ee-962d-dac502259ad0.png 个元素,而输入的均值和方差分别包含了8bd8c614-1d86-11ee-962d-dac502259ad0.png 个元素。由于85808590-1d86-11ee-962d-dac502259ad0.png 通常是比较大的(千数量级),有 8d91f19c-1d86-11ee-962d-dac502259ad0.png 。因此,对于layernormalization,中间激活近似估计为 8d5b7ad6-1d86-11ee-962d-dac502259ad0.png ,而不是8db7052c-1d86-11ee-962d-dac502259ad0.png 。

大模型在训练过程中通常采用混合精度训练,中间激活值一般是float16或者bfloat16数据类型的。在分析中间激活的显存占用时,假设中间激活值是以float16或bfloat16数据格式来保存的,每个元素占了2个bytes。唯一例外的是,dropout操作的mask矩阵,每个元素只占1个bytes。在下面的分析中,单位是bytes,而不是元素个数。

每个transformer层包含了一个self-attention块和MLP块,并分别对应了一个layer normalization连接。

先分析self-attention块的中间激活。self-attention块的计算公式如下:

89962dd8-1d86-11ee-962d-dac502259ad0.png

89a87cb8-1d86-11ee-962d-dac502259ad0.png

1. 对于89bccee8-1d86-11ee-962d-dac502259ad0.png ,需要保存它们共同的输入8df04ff8-1d86-11ee-962d-dac502259ad0.png ,这就是中间激活。输入8df04ff8-1d86-11ee-962d-dac502259ad0.png 的形状为8e13c38e-1d86-11ee-962d-dac502259ad0.png ,元素个数为8d5b7ad6-1d86-11ee-962d-dac502259ad0.png ,占用显存大小为8e323ab2-1d86-11ee-962d-dac502259ad0.png 。

2. 对于89fdc10a-1d86-11ee-962d-dac502259ad0.png 矩阵乘法,需要保存中间激活8e548176-1d86-11ee-962d-dac502259ad0.png ,两个张量的形状都是8e13c38e-1d86-11ee-962d-dac502259ad0.png ,占用显存大小合计为8e754262-1d86-11ee-962d-dac502259ad0.png 。

3. 对于8e8d8cc8-1d86-11ee-962d-dac502259ad0.png函数,需要保存函数的输入 89fdc10a-1d86-11ee-962d-dac502259ad0.png ,占用显存大小为8eaf0420-1d86-11ee-962d-dac502259ad0.png ,这里的8597e2c6-1d86-11ee-962d-dac502259ad0.png 表示注意力头数。

8ed2bb04-1d86-11ee-962d-dac502259ad0.png

8ee428ee-1d86-11ee-962d-dac502259ad0.png 的形状为: 8ef662e8-1d86-11ee-962d-dac502259ad0.png

8f0ab716-1d86-11ee-962d-dac502259ad0.png 的形状为:8f2396a0-1d86-11ee-962d-dac502259ad0.png

89fdc10a-1d86-11ee-962d-dac502259ad0.png 的形状为:8f445c96-1d86-11ee-962d-dac502259ad0.png,元素个数为 8f5a3d36-1d86-11ee-962d-dac502259ad0.png ,占用显存大小为8eaf0420-1d86-11ee-962d-dac502259ad0.png 。

4. 计算完8e8d8cc8-1d86-11ee-962d-dac502259ad0.png函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与89fdc10a-1d86-11ee-962d-dac502259ad0.png 相同,占用显存大小为8f5a3d36-1d86-11ee-962d-dac502259ad0.png 。

5. 计算在85af3034-1d86-11ee-962d-dac502259ad0.png 上的attention,即 8a4c4500-1d86-11ee-962d-dac502259ad0.png ,需要保存8fed2d6c-1d86-11ee-962d-dac502259ad0.png ,大小为8eaf0420-1d86-11ee-962d-dac502259ad0.png ;以及85af3034-1d86-11ee-962d-dac502259ad0.png ,大小为90275a96-1d86-11ee-962d-dac502259ad0.png 。二者占用显存大小合计为90482d70-1d86-11ee-962d-dac502259ad0.png 。

6. 计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为90275a96-1d86-11ee-962d-dac502259ad0.png ;dropout需要保存mask矩阵,大小为8d5b7ad6-1d86-11ee-962d-dac502259ad0.png 。二者占用显存大小合计为90a0671a-1d86-11ee-962d-dac502259ad0.png 。

因此,将上述中间激活相加得到,self-attention块的中间激活占用显存大小为90b20fce-1d86-11ee-962d-dac502259ad0.png 。

接下来看MLP块的中间激活。MLP块的计算公式如下

8aaa25ee-1d86-11ee-962d-dac502259ad0.png

1. 第一个线性层需要保存其输入,占用显存大小为90275a96-1d86-11ee-962d-dac502259ad0.png 。

2. 激活函数需要保存其输入,占用显存大小为90dd5c56-1d86-11ee-962d-dac502259ad0.png 。

3. 第二个线性层需要保存其输入,占用显存大小为90dd5c56-1d86-11ee-962d-dac502259ad0.png 。

4. 最后有一个dropout操作,需要保存mask矩阵,占用显存大小为8d5b7ad6-1d86-11ee-962d-dac502259ad0.png 。

对于MLP块,需要保存的中间激活值为910fbe12-1d86-11ee-962d-dac502259ad0.png 。

另外,self-attention块和MLP块分别对应了一个layer normalization。每个layer norm需要保存其输入,大小为90275a96-1d86-11ee-962d-dac502259ad0.png 。2个layer norm需要保存的中间激活为912fd2e2-1d86-11ee-962d-dac502259ad0.png 。

综上,每个transformer层需要保存的中间激活占用显存大小为91429bf2-1d86-11ee-962d-dac502259ad0.png 。对于8568a3ee-1d86-11ee-962d-dac502259ad0.png层transformer模型,还有embedding层、最后的输出层。embedding层不需要中间激活。总的而言,当隐藏维度85808590-1d86-11ee-962d-dac502259ad0.png 比较大,层数8568a3ee-1d86-11ee-962d-dac502259ad0.png 较深时,这部分的中间激活是很少的,可以忽略。因此,对于8568a3ee-1d86-11ee-962d-dac502259ad0.png 层transformer模型,中间激活占用的显存大小可以近似为918e2cd4-1d86-11ee-962d-dac502259ad0.png

4.1 对比中间激活与模型参数的显存大小

在一次训练迭代中,模型参数(或梯度)占用的显存大小只与模型参数量和参数数据类型有关,与输入数据的大小是没有关系的。优化器状态占用的显存大小也是一样,与优化器类型有关,与模型参数量有关,但与输入数据的大小无关。而中间激活值与输入数据的大小(批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png 和序列长度85cf59e0-1d86-11ee-962d-dac502259ad0.png )是成正相关的,随着批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png 和序列长度85cf59e0-1d86-11ee-962d-dac502259ad0.png的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足OOM(Out OfMemory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。

以GPT3-175B为例,我们来直观地对比下模型参数与中间激活的显存大小。GPT3的模型配置如下。我们假设采用混合精度训练,模型参数和中间激活都采用float16数据类型,每个元素占2个bytes。

模型名 参数量 层数 隐藏维度 注意力头数
GPT3 175B 96 12288 96

GPT3的模型参数量为175B,占用的显存大小为91e94efc-1d86-11ee-962d-dac502259ad0.png。GPT3模型需要占用350GB的显存。

GPT3的序列长度85cf59e0-1d86-11ee-962d-dac502259ad0.png 为920b2784-1d86-11ee-962d-dac502259ad0.png 。对比不同的批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png 占用的中间激活:

922b29ee-1d86-11ee-962d-dac502259ad0.png 时,中间激活占用显存为92448182-1d86-11ee-962d-dac502259ad0.png,大约是模型参数显存的0.79倍。

925af32c-1d86-11ee-962d-dac502259ad0.png 时,中间激活占用显存为9271fd88-1d86-11ee-962d-dac502259ad0.png,大约是模型参数显存的50倍。

928c62ea-1d86-11ee-962d-dac502259ad0.png 时,中间激活占用显存为

929f7ba0-1d86-11ee-962d-dac502259ad0.png,大约是模型参数显存的101倍。

可以看到随着批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用 激活重计算技术来减少中间激活,理论上可以将中间激活显存从92c0fb7c-1d86-11ee-962d-dac502259ad0.png 减少到92d78c98-1d86-11ee-962d-dac502259ad0.png,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。

5. KV cache

在推断阶段,transformer模型加速推断的一个常用策略就是使用 KV cache。一个典型的大模型生成式推断包含了两个阶段:

1.预填充阶段:输入一个prompt序列,为每个transformer层生成 key cache和value cache(KV cache)。

2.解码阶段:使用并更新KV cache,一个接一个地生成词,当前生成的词依赖于之前已经生成的词。

92ed9d6c-1d86-11ee-962d-dac502259ad0.png 个transformer层的权重矩阵为9304c852-1d86-11ee-962d-dac502259ad0.png。其中,self-attention块的4个权重矩阵 9319b14a-1d86-11ee-962d-dac502259ad0.png,并且MLP块的2个权重矩阵 93302484-1d86-11ee-962d-dac502259ad0.png

预填充阶段

假设第92ed9d6c-1d86-11ee-962d-dac502259ad0.png 个transformer层的输入为93515a78-1d86-11ee-962d-dac502259ad0.png ,self-attention块的key、value、query和output表示为93684648-1d86-11ee-962d-dac502259ad0.png,其中, 93822d4c-1d86-11ee-962d-dac502259ad0.png

key cache和value cache的计算过程为:

9398dd94-1d86-11ee-962d-dac502259ad0.png93af854e-1d86-11ee-962d-dac502259ad0.png

92ed9d6c-1d86-11ee-962d-dac502259ad0.png 个transformer层剩余的计算过程为:

93d233fa-1d86-11ee-962d-dac502259ad0.png93e27e72-1d86-11ee-962d-dac502259ad0.png93f8cbaa-1d86-11ee-962d-dac502259ad0.png

解码阶段

给定当前生成词在第92ed9d6c-1d86-11ee-962d-dac502259ad0.png 个transformer层的向量表示为9418a024-1d86-11ee-962d-dac502259ad0.png。推断计算分两部分:更新KV cache和计算第 92ed9d6c-1d86-11ee-962d-dac502259ad0.png个transformer层的输出。

更新key cache和value cache的计算过程如下:

943ab0a6-1d86-11ee-962d-dac502259ad0.png

94514e74-1d86-11ee-962d-dac502259ad0.png

92ed9d6c-1d86-11ee-962d-dac502259ad0.png 个transformer层剩余的计算过程为:

946e96e6-1d86-11ee-962d-dac502259ad0.png

9480832e-1d86-11ee-962d-dac502259ad0.png9492f5e0-1d86-11ee-962d-dac502259ad0.png

5.1 KV cache的显存占用分析

假设输入序列的长度为85cf59e0-1d86-11ee-962d-dac502259ad0.png ,输出序列的长度为88f02cc6-1d86-11ee-962d-dac502259ad0.png ,以float16来保存KV cache,那么 KVcache的峰值显存占用大小为94c74c82-1d86-11ee-962d-dac502259ad0.png。这里第一个2表示K/V cache,第二个2表示float16占2个bytes。

以GPT3为例,对比KV cache与模型参数占用显存的大小。GPT3模型占用显存大小为350GB。假设批次大小925af32c-1d86-11ee-962d-dac502259ad0.png ,输入序列长度94ee1858-1d86-11ee-962d-dac502259ad0.png ,输出序列长度95050914-1d86-11ee-962d-dac502259ad0.png ,则KV cache占用显存为951b99e0-1d86-11ee-962d-dac502259ad0.png,大约是模型参数显存的0.5倍。

6. 总结

本文首先介绍了如何计算transformer模型的参数量,基于参数量可以进一步估计模型参数、梯度和优化器状态占用的显存大小。接着,本文估计了训练迭代中,在给定训练tokens数的情况下transformer模型的计算量,给予计算量和显卡性能可以进一步估计训练迭代的计算耗时。然后,本文分析了transformer模型前向计算过程中产生的中间激活值的显存大小,中间激活的显存大小与输入数据大小正相关,甚至会远超过模型参数占用的显存。最后,本文介绍了transformer模型推理过程常用的加速策略:使用KVcache。总的来说,分析transformer模型的参数量、计算量、中间激活和KV cache,有助于理解大模型训练和推断过程中的显存效率和计算效率。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 模型
    +关注

    关注

    1

    文章

    2704

    浏览量

    47687
  • Transformer
    +关注

    关注

    0

    文章

    130

    浏览量

    5898
  • ChatGPT
    +关注

    关注

    27

    文章

    1408

    浏览量

    4745

原文标题:分析transformer模型的参数量、计算量、中间激活、KV cache

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    基于卷积的基础模型InternImage网络技术分析

    近年来大规模视觉 Transformer 的蓬勃发展推动了计算机视觉领域的性能边界。视觉 Transformer 模型通过扩大模型
    发表于 11-18 10:49 493次阅读
    基于卷积的基础<b class='flag-5'>模型</b>InternImage网络技术分析

    大语言模型背后的Transformer,与CNN和RNN有何不同

      电子发烧友网报道(文/李弯弯)近年来,随着大语言模型的不断出圈,Transformer这一概念也走进了大众视野。Transformer是一种非常流行的深度学习模型,最早于2017年
    的头像 发表于 12-25 08:36 1576次阅读
    大语言<b class='flag-5'>模型</b>背后的<b class='flag-5'>Transformer</b>,与CNN和RNN有何不同

    你了解在单GPU上就可以运行的Transformer模型

    上一步也跑不了,因为它们的内存需求太大了。例如,完整的GPT-2模型大约包含1.5B参数。最大配置的参数数量超过每层0.5B,而层数有64 层。图2:标准Transformer
    发表于 11-02 15:19

    Google科学家设计简化稀疏架构Switch Transformer,语言模型参数量可扩展至 1.6 万亿

    刚刚,Google Brain 高级研究科学家 Barret Zoph 发帖表示,他们设计了一个名叫「Switch Transformer」的简化稀疏架构,可以将语言模型参数量扩展至 1.6 万亿
    的头像 发表于 01-13 16:50 2697次阅读

    一个GPU训练一个130亿参数模型

    现在的模型动辄数百、数千亿参数,普通人训不动怎么办? 前不久,谷歌发布了参数量为 1.6 万亿的语言模型Swith Transformer
    的头像 发表于 02-11 09:04 2213次阅读
    一个GPU训练一个130亿<b class='flag-5'>参数</b>的<b class='flag-5'>模型</b>

    Transformer模型的多模态学习应用

    随着Transformer在视觉中的崛起,Transformer在多模态中应用也是合情合理的事情,甚至以后可能会有更多的类似的paper。
    的头像 发表于 03-25 09:29 9939次阅读
    <b class='flag-5'>Transformer</b><b class='flag-5'>模型</b>的多模态学习应用

    使用跨界模型Transformer来做物体检测!

    用了Transformer 架构开发的一个目标检测模型。在这篇文章中,我将通过分析DETR架构的内部工作方式来帮助提供一些关于它的直觉。 下面,我将解释一些结构,但是如果你只是想了解如何使用模型,可以直接跳到代码部分
    的头像 发表于 06-10 16:04 1949次阅读
    使用跨界<b class='flag-5'>模型</b><b class='flag-5'>Transformer</b>来做物体检测!

    超大Transformer语言模型的分布式训练框架

    模型的预训练计算。 大模型是大势所趋 近年来,NLP 模型的发展十分迅速,模型的大小每年以1-2个数量
    的头像 发表于 10-11 16:46 2269次阅读
    超大<b class='flag-5'>Transformer</b>语言<b class='flag-5'>模型</b>的分布式训练框架

    Microsoft使用NVIDIA Triton加速AI Transformer模型应用

    Microsoft 的目标是,通过结合使用 Azure 与 NVIDIA GPU 和 Triton 推理软件,率先将一系列强大的 AI Transformer 模型投入生产用途。
    的头像 发表于 04-02 13:04 1490次阅读

    一种显著降低Transformer计算量的轻量化方法

    然而,transformer的原始公式在输入令牌(token)数量方面具有二次计算复杂度。鉴于这个数字通常从图像分类的14^2到图像去噪的128^2 = 16K不等,内存和计算的这一限
    的头像 发表于 01-10 14:12 890次阅读

    在X3派上玩转一亿参数量超大Transformer,DIY专属你的离线语音识别

    Transformer模型在自然语言领域被提出后,目前已经扩展到了计算机视觉、语音等诸多领域。然而,虽然Transformer模型在语音识别
    的头像 发表于 02-21 16:08 535次阅读
    在X3派上玩转一亿<b class='flag-5'>参数量</b>超大<b class='flag-5'>Transformer</b>,DIY专属你的离线语音识别

    transformer模型详解:Transformer 模型的压缩方法

     动机&背景 Transformer 模型在各种自然语言任务中取得了显著的成果,但内存和计算资源的瓶颈阻碍了其实用化部署。低秩近似和结构化剪枝是缓解这一瓶颈的主流方法。然而,作者通过分析发现,结构化
    的头像 发表于 07-17 10:50 1345次阅读
    <b class='flag-5'>transformer</b><b class='flag-5'>模型</b>详解:<b class='flag-5'>Transformer</b> <b class='flag-5'>模型</b>的压缩方法

    盘古大模型参数量有多少

    盘古大模型参数量有多少 盘古大模型(PanGu-α)是由中国科学院计算技术研究所提供的一种语言生成预训练模型。该
    的头像 发表于 08-17 11:28 2053次阅读

    盘古大模型与ChatGPT的模型基础架构

    华为盘古大模型Transformer模型架构为基础,利用深层学习技术进行训练。模型的每个数量达到2.6亿个,是目前世界上最大的汉语预备训练
    的头像 发表于 09-05 09:55 1431次阅读

    基于Transformer模型的压缩方法

    基于Transformer架构的大型模型在人工智能领域中发挥着日益重要的作用,特别是在自然语言处理(NLP)和计算机视觉(CV)领域。
    的头像 发表于 02-22 16:27 277次阅读
    基于<b class='flag-5'>Transformer</b><b class='flag-5'>模型</b>的压缩方法