【文献阅读】MQA和GQA

MQA(Multi Query Attention)和GQA(Grouped Query Attention)是在Attention上加速大模型计算的tricks,可以缩短模型训练周期和加快推理速度。

论文地址:

参考:
https://mp.weixin.qq.com/s/_4OxoRLxhOcjGf0Q4Tvp2Q

模型优化

先来看一下MHA(Multi Head Attention)、MQA(Multi Query Attention)和GQA(Grouped Query Attention)的区别:

首先是原始的MHA(Multi Head Attention),QKV三部分有相同数量的头,且一一对应。每次做Attention,每个头的QKV只需要做好自己部分运算就可以了,输出时各个头concat起来就行了。

MQA(Multi Query Attention)是让Q保持原来的头数,但K和V变成只有一个,即所有头的Q共享一组K和V。但这种方法稍微会带来一些性能降低。 当然相对它能带来的收益,性能的些微降低是可以接受的。实验发现一般能提高 30%-40% 的吞吐。

GQA(Grouped Query Attention),是MHA和MQA的折中方案,既想获得MQA的MQA带来的加速效果,又不想损失太多性能。 具体思想是,Q仍然保持原来的头数,KV的头数减少为g个。多个头的Q共享一组KV,如图中就是两个Q共享一组KV。

LLAMA2中给出了效果对比,如下:

但是从上面的图来看,MQA和GQA的运算量和MHA其实是一样的,那么为什么会带来加速效果呢?这就要说到冯诺依曼架构和Memory Wall。

冯诺依曼架构和Memory Wall

预备知识

目前大模型基本上用的都是Transformer结构,Transformer包括Encoder和Decoder两个部分。

先来看Encoder部分,Encoder可以理解为AE模型,每个timestep的token可以看到所有timestep的token,因此是可以并行得到每个timestep的输出的,即一次inference得到所有timestep的结果。

而Decoder部分相当与是AR模型,每个timestep的输入是上一个timestep的输出,所以没法并行输出所有timestep的结果,只能一个一个的向后生成。

当然上述问题也有优化手段,如使用KV Cache。

Decoder的每次前向过程中,当前timestep之前的KV值都是计算过的,只是之前每次前向完成后计算结果都被丢掉了,只保留了最后的输出结果。于是一个很自然的想法就是Cache,每次前向完,将KV都保留下来,用于之后计算。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
#q、k、v 当前 timestep 的 query,key,value
# K_prev,V_prev 之前所有 timestep 的 key 和 value
for _ in range(time_step):
...
K = torch.cat([K_prev, k], dim=-2) #[b, h, n, d]
V = torch.cat([V_prev, v], dim=-2) #[b, h, n, d]

logits = torch.einsum("bhd,bhnd->bhn", q, K)
weights = torch.softmax(logits/math.sqrt(d), dim=-1)
outs = torch.einsum("bhn,bhnd->bhd", weights, V)
...

K_prev, V_prev = K, V

但在大模型时代,上述办法存在缺陷。比如LLaMa 7B模型,hidden_size是4096,那么每个timestep需缓存参数量为\(4096 \times 2 \times 32=262144\),假设半精度保存就是512KB,1024长度那就要512MB。而现在英伟达最好的卡H100的SRAM缓存大概是 50MB,而A100则是40MB,KV计算值完全存不进缓存。

而且目前SRAM太贵了,我们没有办法直接做大SRAM内存呢,所以这条路现在是不太行的。于是退一步,放不进缓存可以放DRAM上去,而DRAM内存也就是我们常说的GPU显存。 但DRAM读取到计算芯片和SRAM读取到计算芯片的速度,差了一个量级,这会让计算芯片一直在等待数据读取。

现在我们遇到了当今芯片领域,冯诺依曼架构下最大的一个问题,也就是:Memory Wall(内存墙)。

加速原理

冯诺依曼架构包含四个部分:输入,输出,计算单元,加上存储单元。

现在随着摩尔定律的见顶,虽然计算和内存的发展速度在变缓,但这并不是最大的问题,最大的问题是存储单元与计算单元间的交互。

冯诺依曼架构需要先从内存中调取数据,送入计算单元进行处理,但现在计算单元的速度是显著提升的,而从内存中读取数据的速度却没跟上,所以计算和内存这里就形成了一个瓶颈。因为短板效应,内存读取速度限制了整体速度。计算单元能很快将数据处理完,但新数据却还没到,于是就只能等待,造成利用率不高。这就是内存墙。

因为内存墙问题,现在的大模型训练,一张A100卡计算单元的利用率到四五十就不错了,用上各种技巧优化到60%已经很高了。而对于H100卡问题会更严重,因为它的计算速度相对A100提高了6倍,而内存读取带宽只增加了1.6倍,所以也要大量优化来提高利用率。

内存墙怎么越过呢?

硬件层面上,比如现在已在使用的HBM(高速带宽内存)提高读取速度,或者更彻底些,抛弃冯诺依曼架构,改变计算单元从内存读数据的方式,不再以计算单元为中心,而以存储为中心,做成计算和存储一体的“存内计算”。

软件层面上的话,最近的很多优化,比如Flash Attention,Paged Attention都可以算。Flash Attention就是减少了计算Softmax时从DRAM内存读取数据次数,从而提高了效率。

同样,MQA和GQA也是一个软件层面上翻墙的一个方法。MQA和GQA形式在推理加速方面,主要是通过两方面来完成:

  1. 降低了从内存中读取的数据量,所以也就减少了计算单元等待时间,提高了计算利用率。
  2. KV cache变小了,也就是显存中需要保存的tensor变小了,空出来空间就可以加大batch size,从而又能提高利用率。

如果要用MQA和GQA,可以是从头训练的时候就加上,也可以像GQA论文里面一样,用已有的开源模型,挑一些头取个mean用来初始化MQA或GQA继续训练一段时间。