GQA(Grouped-Query Attention)是近年来大语言模型推理优化中的一项关键技术,其设计初衷是在不显著牺牲生成质量的前提下,大幅降低推理阶段的内存占用和计算延迟。要深入理解 GQA,需从标准注意力机制的瓶颈出发,逐步剖析其结构动机、数学形式、工程实现细节及实际影响。
背景:MHA 与 MQA 的权衡困境
在原始 Transformer 中,多头注意力(Multi-Head Attention, MHA)为每个头维护独立的 Query(Q)、Key(K)和 Value(V)投影矩阵。假设有 个注意力头,则 Q、K、V 各自被划分为 组,每组独立计算注意力。这种设计赋予模型强大的表达能力,使其能并行关注不同语义子空间(如语法、指代、主题等)。然而,在自回归生成任务中,所有历史 token 的 K 和 V 必须缓存(即 KV Cache),以避免重复计算。KV Cache 的大小正比于 ,当模型规模增大(如 Llama-2-70B 使用 64 个头),缓存体积迅速成为显存和带宽瓶颈,尤其在长上下文或高并发场景下。
为缓解此问题,多查询注意力(Multi-Query Attention, MQA)提出:仅使用一个共享的 K/V 头,所有 Q 头共用同一组 K/V。此时 KV Cache 大小降至原来的 ,推理速度显著提升。但 MQA 的代价是表达能力严重受限——所有注意力头被迫基于相同的上下文表示进行加权,难以捕捉多样化的依赖关系,导致生成质量下降,尤其在需要精细上下文建模的任务(如代码生成、多跳问答)中表现明显退化。
GQA 的核心思想:分组共享
GQA 在 MHA 与 MQA 之间引入一个连续可调的中间态:将 个 Query 头划分为 个组( 整除 ),每组分配一个独立的 Key/Value 头。因此,总共有 个 K/V 头,而 保持不变。例如,Llama-3-8B 使用 32 个 Q 头、8 个 K/V 头,即 ,每 4 个 Q 头共享一组 K/V。
这种设计保留了 Q 空间的多样性(仍可学习不同视角的查询),同时将 KV Cache 大小从 倍压缩至 倍。当 时,GQA 退化为 MHA;当 时,退化为 MQA。因此,GQA 是一个广义框架,通过调节组数 实现效率与效果的灵活平衡。
数学形式与前向计算
设输入序列为 ,其中 为序列长度, 为隐藏维度。GQA 的投影过程如下:
在注意力计算中,对第 个 Q 头(),其对应的 K/V 头索引为 。注意力输出为:
所有头的输出拼接后经线性变换得到最终结果。
工程实现要点
在训练阶段,GQA 与 MHA 几乎无异:只需调整 K/V 投影矩阵的输出维度,并在注意力计算时正确映射 Q 到对应的 K/V 组。主流框架(如 PyTorch)可通过 reshape 和 broadcast 高效实现,无需自定义 CUDA kernel。
在推理阶段,优势显著体现:
- • 内存带宽压力降低,尤其在 A100/H100 等 GPU 上,可提升吞吐 1.5–2 倍;
- • 批处理效率提高,因更少的缓存占用允许更大 batch size。
此外,GQA 对训练稳定性无负面影响,且可与 FlashAttention、PagedAttention 等高效注意力机制无缝结合。
实际效果与选择建议
大量实验表明,当 时(如 32Q/8KV),GQA 在多数任务上与 MHA 性能几乎持平,但在推理速度和显存上优势巨大。因此,现代开源模型(如 Llama-3、Mixtral、Qwen2)普遍采用 GQA 作为默认配置。
选择组数 时,需考虑:
- • 任务复杂度:高精度任务(如数学、代码)建议 不小于 8;
- • 硬件约束:显存紧张时可适当减小 ,但需验证质量损失。
总结
GQA 并非理论突破,而是面向实际部署的精巧工程折中。它通过引入“分组共享”的归纳偏置,在保持模型表达能力的同时,有效缓解了大模型推理的核心瓶颈——KV Cache 膨胀。这一设计体现了当前大模型架构演进的重要趋势:在硬件约束下,通过结构化稀疏和参数复用,实现效率与性能的帕累托最优。理解 GQA,不仅有助于掌握现代 Transformer 的实现细节,也为未来模型压缩与加速提供思路。