1-bit大模型还能再突破!新一代BitNet架构启用4位激活值



  新智元报道  

编辑:alan
【新智元导读】近日,BitNet系列的原班人马推出了新一代架构:BitNet a4.8,为1 bit大模型启用了4位激活值,支持3 bit KV cache,效率再突破。

量化到1 bit的LLM还能再突破?

这次,他们对激活值下手了!

近日,BitNet系列的原班人马推出了新一代架构:BitNet a4.8,为1 bit大模型启用了4位激活值:

论文地址:https://arxiv.org/pdf/2411.04965

众所周知,激活值量化通常是比较难办的。

本次的BitNet a4.8采用混合量化和稀疏化策略,来减轻异常通道引入的量化误差。

简单来说就是,对注意力层和FFN层的输入采用4位量化,同时用8位整数稀疏化中间状态。

大量实验表明,BitNet a4.8在相同的训练成本下,实现了与前代BitNet b1.58相当的性能,同时因为可以吃到4位(INT4/FP4)内核的计算红利,实现了更快的推理速度。

BitNet a4.8仅激活55%的参数,并支持3 bit KV cache,进一步提升了大规模LLM部署和推理的效率。

BitNet a4.8

模型架构

模型的整体架构如图1所示,BitNet a4.8采用了与BitNet b1.58相同的布局。

作者使用BitLinear替换注意力(MHA)和前馈网络(FFN)中的线性投影,以从头开始学习1.58 bit权重。对于激活值,采用混合量化和稀疏化策略来减轻异常值维度引入的误差。

图2说明了模型大小为7B的BitNet b1.58中,每个模块输入的分布。

注意力层和FFN层的输入通常类似高斯分布,而在FFN下采样之前的激活值和注意力中的输出投影中,发现了很多异常值通道和大量接近零的条目(全精度LLM也有类似观察结果)。

如图3所示,直接将低位量化应用于这些中间状态会引入很大的量化误差。

因此,作者使用Q-Sparse的稀疏化方法,将这些中间状态保持在8位(同时消除了计算瓶颈)。

对于自注意层的输出投影,使用sparsify-then-quantize函数:

两个Q分别表示权重W和激活X的量化函数,M是掩码,根据激活X的绝对值取topK,⊙是元素乘法。

具体来说,权重量化和激活值量化函数可以表述为:

对于FFN,这里采用squared ReLU和门控线性单元(GLU)来进一步提高激活的稀疏性:

根据初步实验的结果,使用squared ReLU时,下采样输入的稀疏性超过了80%,且对性能的影响最小。

此外,作者还观察到gate + squared ReLU的输出也表现出高激活稀疏性(7B模型为67.5%)。通过首先计算gate projection,然后仅在非零通道上执行up projection,可以进一步减少推理的计算量。

相比之下,attention和FFN的输入中包含的异常值特征要少得多,可以使用absmean函数将激活值量化为4位整数:

模型训练

初始化

BitNet a4.8使用BitNet b1.58的权重开始训练,分为W1.58A8与W1.58A4两阶段。

第一阶段使用8位激活和GLU + squared ReLU训练模型;第二阶段采用上面介绍过的混合量化和稀疏化。

BitNet a4.8只需少量训练,即可快速适应4bit位宽和稀疏激活,同时性能损失可以忽略不计。

梯度近似

作者使用直通估计器(STE)对BitNet a4.8进行梯度逼近,使用混合精度训练来更新参数。

这里直接绕过了不可微函数,包括反向传播过程中的量化函数和topK稀疏函数。对于混合精度训练,保持全精度latent weight来累积参数更新。

模型量化

浮点量化提供了比基于整数的量化更宽的动态范围,这对于处理激活值的长尾分布至关重要。

研究人员将FFN下采样层的输入保留为8位整数,其他激活值使用MinMax量化器量化为FP4:

公式中E和M分别表示指数和尾数部分的位宽。这里采用E2M1格式,因为它的动态范围更大。

实验

本文将BitNet a4.8、BitNet b1.58,以及各种参数量大小的FP16精度LLaMA进行了比较。

其中的1.58 bit模型,遵循BitNet b1.58的训练方案,采用了两阶段权重衰减和学习率调度。

所有模型都使用RedPajama数据集中的100B token进行训练,以确保公平比较。

对于BitNet a4.8,作者首先使用95B token来训练8位激活值的模型。然后重用优化器状态,并使用5B token进行混合量化和稀疏化的训练。实验将topK设置为50%(attention的输出投影位置)。

作者使用lm-evaluation-harness工具包,评估模型在一系列语言任务上的zero-shot准确性,包括ARC-Easy(ARCe)、ARCChallenge(ARCc)、Hellaswag(HS)、Winogrande(WGe)和PIQA(PQ)。另外还测试了在C4数据集(测试集)上的困惑度。

主要结果

表1总结了BitNet a4.8、BitNet b1.58和FP16 LLaMA的详细测试结果。

全精度(FP16)LLaMA和BitNet b1.58之间的性能差距,随着模型大小的增长而缩小。对于7B模型,BitNet b1.58在语言模型困惑度和任务的平均准确性方面与LLaMA相当。

此外,相比于BitNet b1.58,BitNet a4.8的平均精度几乎没有损失。

表2展示了各种大小的BitNet a4.8、BitNet b1.58 和 FP16 LLaMA中每个模块的详细稀疏性(使用C4验证集上的非嵌入参数计算)。

值得注意的是,BitNet a4.8的稀疏性明显高于BitNet b1.58和LLaMA。

比如在7B模型中,BitNet a4.8的整体稀疏性达到了44.5%,只有3.4B的活跃参数。down projection层的输入显示出特别高的稀疏性,且中间状态分布以零为中心。

此外,gate projection的输出非常稀疏,导致了up projection的高稀疏性(因为只需要在从Gate中选择非零通道来执行投影)。

具体来说,对于7B BitNet a4.8,Gate和up projection的稀疏率分别为67.5%和12.0%。

表3显示了BitNet a4.8在3B和7B模型大小下,low-bit attention的详细情况。模型使用4位KV或QKV头,精度损失可忽略不计,同时KV cache可以量化为3位整数。

low-bit attention对于高效的长序列建模至关重要,它减少了KV cache的内存占用和IO,并加速了注意力计算。

在本文的实验中,作者采用RoPE后量化。使用absmax函数将QKV头直接量化为无符号整数,无需任何校准数据集。

对于3 bit KV量化,研究人员将bos token的头保留为4 bit,因为它包含更多的异常值特征。

消融实验

图4显示了700M BitNet a4.8的训练损耗曲线,比较了使用完整的INT4/FP4量化,以及本文的混合量化和稀疏化。

完整的INT4量化会导致发散,而混合架构在训练困惑度方面明显优于完整的FP4架构。

使用RedPajama数据集中25B token,来进行模型的第一阶段训练,采用absmean和MinMax量化器分别进行完整的INT4和FP4量化。

对于完整的INT4量化,由于其输入具有更大的异常值,这里设置β = 2*mean(|X|)。

接下来为1.3B BitNet a4.8的down projection层输入,设置不同的量化或激活函数。

所有模型都使用RedPajama数据集中的50B token进行第一阶段训练。为了确保公平比较,其他激活值都保留在8位。

图5显示了这些模型的训练损失曲线。Squared ReLU的训练困惑度比Swish略好,同时实现了更高的稀疏性。

此外,对down projection的输入应用FP4量化会导致性能显著下降,而将INT4激活与STE一起使用会导致发散。

(文:新智元)

欢迎分享

发表评论