Transformer升级之路:多头潜在注意力机制(MLA)究竟好在哪里?

©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 科学空间
研究方向 | NLP、神经网络


自从 DeepSeek 爆火后,它所提的 Attention 变体 MLA(Multi-head Latent Attention)也愈发受到关注。
MLA 通过巧妙的设计实现了 MHA 与 MQA 的自由切换,使得模型可以根据训练和推理的不同特性(Compute-Bound or Memory-Bound)选择最佳的形式,尽可能地达到效率最大化。 
诚然,MLA 很有效,但也有观点认为它不够优雅,所以寻找 MLA 替代品的努力一直存在,包括我们也有在尝试。
然而,经过一段时间的实验,我们发现很多 KV Cache 相同甚至更大的 Attention 变体,最终效果都不如 MLA。这不得不让我们开始反思:MLA 的出色表现背后的关键原因究竟是什么? 
接下来,本文将详细介绍笔者围绕这一问题的思考过程以及相关实验结果。


观察
MLA 提出自 DeepSeek-V2 [1],本文假设读者已经熟悉 MLA,至少了解之前的文章缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA所介绍的内容,因此 MLA 自身的细节将不会过多展开。 
MLA 的主要特点如下:
1. MLA 在训练阶段是一个 qk_head_dims=(128+64)、v_head_dims=128 的 MHA;
2. MLA 在解码阶段是一个 qk_head_dims=(512+64)、v_head_dims=512、KV-Shared 的 MQA; 
3. MLA 的 [qc, qr]、[kc, kr] 拼接,可以理解为一种Partial RoPE。

猜测

MHA、GQA 常用的 head_dims 是 128,而对于 MLA 来说,不管是从训练看的 128+64,还是从推理看的 512+64,都要大于 128,再结合《突破瓶颈,打造更强大的 Transformer》[2] 的经验,我们有: 
猜测 1:增大 head_dims 是 MLA 好的关键之一。 
另外,KV-Shared 这个特性,可以在同等 KV Cache 大小下,增大 GQA 的 head_dims 或者 num_groups,所以有: 
猜测 2:KV-Shared 是 MLA 好的关键之一。 
最后,此前有一些理论和实验显示 Partial RoPE 可能会对效果有正面帮助(参考Transformer升级之路:RoPE的底数选择原则),所以有:
猜测 3:Partial RoPE 是 MLA 好的关键之一。


实验
现在我们通过实验逐一检验以上猜测。
设置
所有实验公共部分的超参数如下:
1. 类似 LLAMA3 的 Dense 模型;
 
2. hidden_size=2048,num_layers=12,num_heads=16; 

3. 优化器是Muon,Attention 部分 per head 更新;

4. 训练长度为 4096,总 tokens 数为 16B,总训练步数为 16k;

5. 所有实验都是只改变 Attention,所以参数量不会严格对齐。

Part I
MLA 的 KV Cache 大小是 512+64,约等于 GQA2-128(第一个数字是 num_groups,第二个数字是 head_dims),所以对比的 baseline 为 GQA2-128 和 GQA1-256
为了验证 Partial RoPE,我们增加了 GQA1-256-PR,具体做法是将 Q、K 的 256 dims 分成 192+64 两部分,在 64 上加 RoPE,192 不加。 
结果如下:
即:
初步验证了增大 head_dims 和 Partial RoPE 的作用。这样看来,MLA 的设计中,RoPE 和 NoPE 拼接这部分看似无奈的设计,极有可能是它效果优异的关键原因!原论文声称 MLA 甚至优于 MHA,大概率也是因为所对比的 MHA 的 head_dims 只有 128。
Part II
为了进一步验证增大 head_dims 的作用,我们另外跑了 MHA、GQA2-192、MLA-256 三个实验,MHA 是 head_dims=128 的常规 MHA,GQA2-192 是直接增大 GQA2 的 head_dims 到 192,MLA-256 是将 MLA 的 128+64 提升到 192+64,对照如下:
可以看到,MHA 总参数量更多,KV Cache 更是 7 倍于 MLA,但 Loss 才堪堪追平 MLA,这跟 DeepSeek-V2 里边的结论接近。
此外,GQA2-192 优于 GQA2-128,但不如 GQA1-256;MLA 的 head_dims 升到 (192+64) 后,相比 (128+64) 也还能进一步提升效果。这些现象都表明,增加 head_dims 远比增加 num_groups 更有效。
Part III
接下来我们验证 KV-Shared,即 K、V 共享全部或大部分 dims。这里我们主要考虑的替代品是 head_dims 不超过 256 的 GQA,并且控制 KV Cache 的总大小跟 MLA 接近,所以当 KV-Shared 时,我们可以至多可以考虑 GQA2-256。
由于 KV-Shared 跟 RoPE 不完全兼容,参考 MLA 的做法,我们将 256 分成 192+64 两部分,其中:
1. 192 部分不加 RoPE,在 K、V 间共享; 

2. 64 部分加 RoPE,只用于 K;

3. V 另外再投影 64 dims,concat 到共享的 192 dims 上去。
这样一来,K、V 的 head_dims 都是 256,KV Cache 总大小是 (192+64+64)*2=640,略大于 MLA 的 512+64=576,这个版本我们简记为“GQA2-(192+64)-S1”,其实“S1”是“Shared-1”的缩写。
Part IV
另外一种 KV-Shared 的方案是:
1. 192 部分不加 RoPE,在 K、V 间共享;

2. 64 部分加 RoPE,同样在 K、V 间共享;

3. 做 Attention,由于 V 带 RoPE,此时是绝对位置编码效果;

4. 为了保证相对位置编码,将输出分成 192+64 两部分,64 部分再加一次逆向 RoPE。
这种做法是 K、V 完全共享,KV Cache 大小是 (192+64)*2=512,略小于 MLA。这个版本我们称为“GQA2-(192+64)-S2”,“S2”是“Shared-2”的缩写,背后的原理是笔者新提出的 VO-RoPE,参考Transformer升级之路:第二类旋转位置编码 [3]
Part V
另外,根据同样思路补了几个 GQA4 和 GQA1 的实验。所有实验结果汇总如下:
这里“GQA1-(512+64)-S3”是按照 MLA 的推理形式实现的 MQA,形式介乎 S1 与 S2 之间,它的主要特点是 head_dims 大。
结果解读:
1. KV-Shared 的 GQA 自带 Partial RoPE;

2. KV-Shared 的 GQA2-256,也能超过 MLA;

3. VO-RoPE 的引入,似乎有利于效果(S1 ≲ S2); 

4. 同等 KV Cache 下,head_dims 越大越好;

5. GQA2-(192+64)-S2 略微超过 GQA1-256-PR; 

6. GQA4-(128+64)-S2 的 KV Cache 最大,但效果不是最优,再次表明 head_dims 更关键。
关于 KV-Shared,还有两点观察:
1. 训练过程中,GQA1-256-PR 前期是明显领先 GQA2-(192+64)-S2,但后期被追平甚至略微反先,猜测 GQA1-256-PR 可能有后劲不足的嫌疑;

2. 如果没有 KV-Shared,GQA 顶多是 GQA1-256,也就是说 head_dims 顶天了 256,但有 KV-Shared 的话,GQA 可以做到 GQA1-512-S,单纯从 head_dims 看,KV-Shared 天花板更高。

Part VI
由于没有严格对齐参数量,可能读者会有“到底是增加参数量还是增加 head_dims 更本质”的疑虑,所以这里补充几个对齐参数量的实验。
这里考虑的对齐参数量的方式有三种:
1. double-heads:以“GQA2-128 vs GQA1-256”为例,将 GQA2-128 的 num_heads 翻倍,可以让 GQA2-128 的参数量跟 GQA1-256 相同; 

2. 缩减 MLP:缩小 MLP(SwiGLU)的 intermediate_size,也可以使得 GQA1-256 的参数量跟 GQA2-128 大致相同;

3. Q&O LoRA:GQA 的主要参数量来自 Query 和 Output 的投影矩阵,对这两个矩阵改用 LoRA,也可以降低 GQA1-256 的参数量。
实验结果如下:

结果主要分三块:
1. heads 翻倍相比 head_dims 翻倍,loss 稳定差 0.003 左右; 

2. 缩小 MLP 比 head_dims 减半,loss 稳定优 0.004 左右;

3. Q&O LoRA 性能损失最小,可以实现 head_dims 翻倍但参数量不增,且 loss 明显降。
结论:如果从增加参数量角度看,增大 head_dims 可能是效果增益较大的方向,配合 Q&O LoRA 可以实现参数量几乎不增,但收益仍相当。


小结
初步结论是:
1. 增大 head_dims 收益最大;

2. Partial RoPE 对 Loss 也有一定帮助;

3. KV-Shared 应该也有一定作用。
这样看来,此前我们一直在 head_dims=128 下找 MLA 的替代品,感觉是起点就先天不足了,难怪一直比不上 MLA。要想追平 MLA,head_dims 应该要 192 起步了,并辅以 Partial RoPE。至于 KV-Shared,也可能有用,但应该还需要更大规模的验证。

意义
其实这里边的意义,就看我们换掉 MLA 的决心有多强。 
假设 GQA2-(192+64)-S2 可以替代 MLA,但 MLA 也可以升到 256,目前看来 GQA2-(192+64)-S2 比不上 MLA-256 。那么换掉 MLA 的唯二好处是:
1. 结构更简单,可以方便加 QK-Norm;

2. 解码阶段的 head_dims 由 512+64 变成了 256,同时 num_groups 变为 2,可以 TP。


(文:PaperWeekly)

发表评论

×

下载每时AI手机APP

 

和大家一起交流AI最新资讯!

立即前往