0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

LLaMA微调显存需求减半,清华提出4比特优化器

深度学习自然语言处理 来源:机器之心 2023-09-11 16:08 次阅读

大模型的训练和微调对显存要求很高,优化器状态是显存主要开销之一。近日,清华大学朱军、陈键飞团队提出了用于神经网络训练的 4 比特优化器,节省了模型训练的内存开销,同时能达到与全精度优化器相当的准确率。

4 比特优化器在众多预训练和微调任务上进行了实验,在保持准确率无损的情况下可将微调 LLaMA-7B 的显存开销降低多达 57%。

论文:https://arxiv.org/abs/2309.01507

代码:https://github.com/thu-ml/low-bit-optimizers

模型训练的内存瓶颈

从 GPT-3,Gopher 到 LLaMA,大模型有更好的性能已成为业界的共识。但相比之下,单个 GPU 的显存大小却增长缓慢,这让显存成为了大模型训练的主要瓶颈,如何在有限的 GPU 内存下训练大模型成为了一个重要的难题。

为此,我们首先需要明确消耗显存的来源有哪些。事实上来源有三类,分别是:

1. 「数据显存」,包括输入的数据和神经网络每层输出的激活值,它的大小直接受到 batch size 以及图像分辨率 / 上下文长度的影响;

2. 「模型显存」,包括模型参数,梯度,以及优化器状态(optimizer states),它的大小与模型参数数量呈正比;

3. 「临时显存」,包括 GPU kernel 计算时用到的临时内存和其他缓存等。随着模型规模的增大,模型显存的占比逐渐增大,成为主要瓶颈。

优化器状态的大小由使用哪种优化器决定。当前,训练 Transformer 往往使用 AdamW 优化器,它们在训练过程中需要存储并更新两个优化器状态,即一阶和二阶矩(first and second moments)。如果模型参数量为 N,那么 AdamW 中优化器状态的数量为 2N,这显然是一笔极大的显存开销。

以 LLaMA-7B 为例,该模型含的参数数量大约 7B,如果使用全精度(32 比特)的 AdamW 优化器对它进行微调,那么优化器状态所占用的显存大小约为 52.2GB。此外,虽然朴素的 SGD 优化器不需要额外状态,节省了优化器状态所占用的内存,但是模型的性能难以保证。因此,本文主要关注如何减少模型内存中的优化器状态,同时保证优化器的性能不受损。

节省优化器内存的方法

目前在训练算法方面,节省优化器显存开销的方法主要有三类:

1. 通过低秩分解(Factorization)的思路对优化器状态进行低秩近似(low-rank approximation);

2. 通过只训练一小部分参数来避免保存大多数的优化器状态,例如 LoRA

3. 基于压缩 (compression)的方法,使用低精度数值格式来表示优化器状态。

特别的,Dettmers et al. (ICLR 2022)针对 SGD with momentum 和 AdamW 提出了相应的 8 比特优化器,通过使用分块量化(block-wise quantization)和动态指数数值格式(dynamic exponential numerical format)的技术,在语言建模、图像分类、自监督学习、机器翻译等任务上达到了与原有的全精度优化器相匹配的效果。

本文在基础上,将优化器状态的数值精度进一步降低至 4 比特,提出了针对不同优化器状态的量化方法,最终提出了 4 比特 AdamW 优化器。同时,本文探索了将 压缩和低秩分解方法结合的可能性,提出了 4 比特 Factor 优化器,这种混合式的优化器同时享有好的性能和更好的内存高效性。本文在众多经典的任务上对 4 比特优化器进行了评估,包括自然语言理解、图像分类、机器翻译和大模型的指令微调。

在所有的任务上,4 比特优化器达到了与全精度优化器可比的效果,同时能够占用更少的内存。

问题设置

基于压缩的内存高效优化器的框架

首先,我们需要了解如何将压缩操作引入到通常使用的优化器中,这由算法 1 给出。其中,A 是一个基于梯度的优化器(例如 SGD 或 AdamW)。该优化器输入现有的参数 w,梯度 g 和优化器状态 s,输出新的参数和优化器状态。在算法 1 中,全精度的 s_t 是暂时存在的,而低精度的 (s_t ) ̅ 会持久地保存在 GPU 内存中。这种方式能够节省显存的重要原因是:神经网络的参数往往由每层的参数向量拼接而成。因此,优化器更新也是逐层 / 张量进行,进而在算法 1 下,最多只有一个参数的优化器状态以全精度的形式留在内存中,其他层对应的优化器状态都处于被压缩的状态。

93e8cc98-5079-11ee-a25d-92fbcf53809c.jpg

主要的压缩方法:量化(quantization)

量化是用低精度数值来表示高精度数据的技术,本文将量化的操作解耦为两部分:归一化(normalization)和映射(mapping),从而能够更加轻量级的设计并实验新的量化方法。归一化和映射两个操作依次以逐元素的形式施加在全精度数据上。归一化负责将张量中的每个元素投射到单位区间,其中张量归一化(per-tensor normalization)和分块归一化(block-wise normalization)分别如下定义:

940d3f56-5079-11ee-a25d-92fbcf53809c.jpg

不同归一化方法的粒度不同,处理异常值的能力会有所区别,同时带来的额外内存开销也不同。而映射(mapping)操作负责将归一化的数值映射到低精度能够表示的整数。正式地讲,给定位宽 b(即量化后每个数值使用 b 比特来表示)和预先定义的函数 T

9420d0de-5079-11ee-a25d-92fbcf53809c.jpg

映射操作被定义为:

9433bdac-5079-11ee-a25d-92fbcf53809c.jpg

因此,如何设计恰当的 T 对于减小量化误差有很重要的作用。本文主要考虑线性映射(linear)和动态指数映射(dynamic exponent)。最后,去量化的过程就是按顺序施加映射(mapping)和归一化(normalization)的逆算子。

一阶矩的压缩方法

以下主要针对 AdamW 的优化器状态(一阶矩和二阶矩)提出不同的量化方法。对于一阶矩,本文的量化方法主要基于 Dettmers et al. (ICLR 2022)的方法,使用分块归一化(块大小为 2048)和动态指数映射。

在初步的实验中,我们直接将位宽从 8 比特降低至 4 比特,发现一阶矩对于量化十分鲁棒,在很多任务上已经达到匹配的效果,但也在一部分任务上出现性能上的损失。为了进一步提高性能,我们仔细研究了一阶矩的模式,发现在单个张量中存在很多异常值。

此前的工作对于参数和激活值的异常值的模式已有一定的研究,参数的分布较为平滑,而激活值则具有按照 channel 分布的特点。本文发现,优化器状态中异常值的分布较为复杂,其中有些张量的异常值分布在固定的行,而另外一些张量的异常值分布在固定的列。

9441f174-5079-11ee-a25d-92fbcf53809c.jpg

对于异常值按列分布的张量,以行为优先的分块归一化可能会遇到困难。因此,本文提出采用更小的块,块大小为 128,这能够在减小量化误差的同时使额外的内存开销保持在可控的范围内。下图展示了不同块大小的量化误差。

9450b38a-5079-11ee-a25d-92fbcf53809c.jpg

二阶矩的压缩方法

与一阶矩相比,二阶矩的量化更加困难并且会带来训练的不稳定性。本文确定了零点问题是量化二阶矩的主要瓶颈,此外针对病态的异常值分布提出了改进的归一化方法:rank-1 normalization。本文也尝试了对二阶矩的分解方法(factorization)。

零点问题

在参数、激活值、梯度的量化中,零点往往是不可缺少的,并且在也是量化后频率最高的点。但是,在 Adam 的迭代公式中,更新的大小正比于二阶矩的 -1/2 次方,因此在零附近的范围内改变会极大影响更新的大小,进而造成不稳定。

9476a54a-5079-11ee-a25d-92fbcf53809c.jpg

下图以直方图的形式展示了量化前后 Adam 二阶矩 -1/2 次方的分布, 即 h (v)=1/(√v+10^(-6) )。如果将零点包括在内(图 b),那么大多数值都被推到了 10^6, 从而导致极大的近似误差。一个简单的办法是在动态指数映射中将零点移除,在这样做之后(图 c),对二阶矩的近似变得更加精确。在实际情况中,为了有效利用低精度数值的表达能力,我们提出采用移除零点的线性映射,在实验中取得了很好的效果。

94942002-5079-11ee-a25d-92fbcf53809c.jpg

Rank-1 归一化

基于一阶矩和二阶矩复杂的异常值分布,并受 SM3 优化器所启发,本文提出了一种新的归一化方法,命名为 rank-1 归一化。对一个非负的矩阵张量 x∈R^(n×m), 它的一维统计量定义为:

94ac37f0-5079-11ee-a25d-92fbcf53809c.jpg

进而 rank-1 归一化可以被定义为:

94f20c3a-5079-11ee-a25d-92fbcf53809c.jpg

rank-1 归一化以更细粒度的方式利用了张量的一维信息,能够更聪明且有效地处理按行分布或按列分布的异常值。此外,rank-1 归一化能够简单的推广到高维张量中,并且随着张量规模的增大,它所产生的额外内存开销要小于分块归一化。

此外,本文发现 Adafactor 优化器中对于二阶矩的低秩分解方法能够有效的避免零点问题,因此也对低秩分解和量化方法的结合进行了探索。下图展示了针对二阶矩的一系列消融实验,证实了零点问题是量化二阶矩的瓶颈,同时也验证了 rank-1 归一化,低秩分解方法的有效性。

94fd3f60-5079-11ee-a25d-92fbcf53809c.jpg

实验结果

研究根据所观察的现象和使用的方式,最终提出两种低精度优化器:4 比特 AdamW 和 4 比特 Factor,并与其他优化器进行对比,包括 8 比特 AdamW,Adafactor, SM3。研究选择在广泛的任务上进行评估,包括自然语言理解、图像分类、机器翻译和大模型的指令微调。下表展示了各优化器在不同任务上的表现。

951070a8-5079-11ee-a25d-92fbcf53809c.jpg  95311772-5079-11ee-a25d-92fbcf53809c.jpg

可以看到,在所有的微调任务上,包括 NLU,QA,NLG,4 比特优化器可以匹配甚至超过 32 比特 AdamW,同时在所有的预训练任务上,CLS,MT,4 比特优化器达到与全精度可比的水平。从指令微调的任务中可以看到,4 比特 AdamW 并不会破坏预训练模型的能力,同时能较好地使它们获得遵守指令的能力。

之后,我们测试了 4 比特优化器的内存和计算效率,结果如下表所示。相比 8 比特优化器,本文提出的 4 比特优化器能够节省更多内存,在 LLaMA-7B 微调的实验中最高节省 57.7%。此外,我们提供了 4 比特 AdamW 的融合算子版本,它能够在节省内存的同时不影响计算效率。对于 LLaMA-7B 的指令微调任务,由于缓存压力减小,4 比特 AdamW 也为训练带来了加速效果。详细的实验设置和结果可参考论文链接。

954124c8-5079-11ee-a25d-92fbcf53809c.jpg

替换一行代码即可在 PyTorch 中使用

importlpmm

optimizer=lpmm.optim.AdamW(model.parameters(),lr=1e-3,betas=(0.9,0.999))

我们提供了开箱即用的 4 比特优化器,只需要将原有的优化器替换为 4 比特优化器即可,目前支持 Adam 和 SGD 的低精度版本。同时,我们也提供了修改量化参数的接口,以支持定制化的使用场景。


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 图像分类
    +关注

    关注

    0

    文章

    87

    浏览量

    11838
  • 机器翻译
    +关注

    关注

    0

    文章

    138

    浏览量

    14794
  • 大模型
    +关注

    关注

    2

    文章

    1532

    浏览量

    1130

原文标题:LLaMA微调显存需求减半,清华提出4比特优化器

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    【飞腾派4G版免费试用】仙女姐姐的嵌入式实验室之五~LLaMA.cpp及3B“小模型”OpenBuddy-StableLM-3B

    预训练语言模型。该模型最大的特点就是基于以较小的参数规模取得了优秀的性能,根据官网提供的信息,LLaMA的模型包含4个版本,最小的只有70亿参数,最大的650亿参数,但是其性能相比较之前的OPT
    发表于 12-22 10:18

    比特币的减半将如何影响整个加密行业

    历史不会重复,但有它的节奏。比特币在2012年和2016年曾经发生过两次减半事件。在过去的两次减半事件中,我们看到了某种模式的减半效应。当然,历史可以分析,但未来很难预测。未来的事情只
    发表于 07-16 10:52 755次阅读
    <b class='flag-5'>比特</b>币的<b class='flag-5'>减半</b>将如何影响整个加密行业

    比特币​2020年减半会发生什么

    在每一轮产量减半的前后比特币都会迎来一波牛市,这是比特币自身的驱动力,同时这也是比特币自身的生命周期。
    发表于 07-19 11:55 833次阅读

    比特减半是否会仍然迎来大牛市

    比特币诞生之初,中本聪制定了一条规则,规定矿工们每个区块可获得50个比特币奖励,大约在21万个区块交易完成后,这笔奖金就会减半。 现在10年过去了,自2008年比特币诞生以来,已
    发表于 07-29 11:11 801次阅读

    什么是挖矿奖励减半影响奖励减半的是什么

    挖矿奖励是比特币及其他仿比特币包括LTC、BCH等加密数字货币的唯一发行机制。中本聪在设计比特币的时候将每21万个区块(4年时间)设为一个梯度,对挖矿奖励进行减半处理。
    发表于 08-07 11:06 3745次阅读
    什么是挖矿奖励<b class='flag-5'>减半</b>影响奖励<b class='flag-5'>减半</b>的是什么

    比特币现金区块奖励减半将会带来怎样的结果

    从较早的减半计划可以看出,比特币现金网络经历了“裸挖”的时期,即比特币矿工转向比特币现金网络,可以更轻松地获取奖励。
    发表于 10-31 15:43 1104次阅读

    比特减半前价格上涨的具体原因分析

    加密货币专家中一种流行说法是比特币价格在减半前一年左右开始上涨。Blockchain研究负责人兼Mosaic联合创始人Garrick Hileman在2018年对《福布斯》解释说:“在过去两次减半发生的前几个月,我们看到
    发表于 11-20 11:41 1596次阅读

    如何验证比特币的减半效应

    比特减半是否带来牛市,币圈内人都渴望至极,但几乎没有算的上严格的逻辑。币圈外的人,都嘲笑我们是傻X。而今年,币圈内人,随着2019年,比特币年中小勃起一下就痿了,也开始各种自我怀疑。 在计算机世界里,摩尔定律,是否也是
    发表于 12-11 09:47 933次阅读

    比特减半的历史影响全面分析

    从历史上讲,减半比特币本身已成为价格行动的一个非常积极的事件。对于不熟悉的人来说,比特减半是指网络的发行率(或通胀率)每四年降低50%。
    发表于 02-11 17:20 6846次阅读
    <b class='flag-5'>比特</b>币<b class='flag-5'>减半</b>的历史影响全面分析

    比特币的区块挖矿奖励减半后会产生哪些影响

    比特币的总供应量为 2100 万枚。比特币的区块挖矿奖励每 21 万个区块奖励减半,按照比特币大约 10 分钟出块的时间间隔来算,减半的时间
    发表于 02-29 11:40 1766次阅读

    比特减半可以带动牛市

    本次减半依旧如过往那样带来一波牛市,当然这波牛市不会把数字货币的整体市值带到和股市的体量,也不会让比特币普及到大众都会参与。
    发表于 02-25 10:41 926次阅读

    减半比特币网络转账费用会有所改变吗

    关于比特减半事件的猜测和言论没有达到目标-人们已经知道它已经有十多年了,并不希望有任何惊喜。
    发表于 02-27 10:22 1196次阅读
    <b class='flag-5'>减半</b>后<b class='flag-5'>比特</b>币网络转账费用会有所改变吗

    iPhone都能微调大模型了嘛

    的是,与原驼一起提出的新方法 QLoRA 把微调大模型的 显存需求从>780GB降低到 。 开源社区直接开始狂欢,相关论文成为24小时内关注度最高的AI论文。   以Meta的美洲驼
    的头像 发表于 06-02 15:26 461次阅读
    iPhone都能<b class='flag-5'>微调</b>大模型了嘛

    8G显存一键训练,解锁Llama2隐藏能力!XTuner带你玩转大模型

    针对 GPU 计算特点,在显存允许的情况下,XTuner 支持将多条短数据拼接至模型最大输入长度,以此最大化 GPU 计算核心的利用率,可以显著提升训练速度。例如,在使用 oasst1 数据集微调 Llama2-7B 时,数据拼
    的头像 发表于 09-04 16:12 1516次阅读
    8G<b class='flag-5'>显存</b>一键训练,解锁<b class='flag-5'>Llama</b>2隐藏能力!XTuner带你玩转大模型

    怎样使用QLoRA对Llama 2进行微调呢?

    使用QLoRA对Llama 2进行微调是我们常用的一个方法,但是在微调时会遇到各种各样的问题
    的头像 发表于 09-22 14:27 1097次阅读
    怎样使用QLoRA对<b class='flag-5'>Llama</b> 2进行<b class='flag-5'>微调</b>呢?