当前位置:AIGC资讯 > AIGC > 正文

Llama改进之——分组查询注意力

引言

今天介绍LLAMA2模型引入的关于注意力的改进——分组查询注意力(Grouped-query attention,GQA)1。

Transformer中的多头注意力在解码阶段来说是一个性能瓶颈。多查询注意力2通过共享单个key和value头,同时不减少query头来提升性能。多查询注意力可能导致质量下降和训练不稳定,因此常用的是分组查询注意力。

然后我们结合上篇文章3探讨的旋转位置编码,将选择位置编码应用到分组查询注意力上。

多头注意力

我们先回顾以下原始多头注意力的实现。

import torch
from torch import nn, Tensor

import math
from dataclasses import dataclass


@dataclass
class ModelArgs:
    hidden_size: int = 512
    num_heads: int = 8
    attention_dropout: float = 0.1


class MultiHeadAttention(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        self.hidden_size = args.hidden_size
        self.num_heads = args.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = args.attention_dropout

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor = None):

        batch_size, seq_len, _ = hidden_states.shape

        query_states, key_states, value_states = (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )

        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32 see https://github.com/huggingface/transformers/pull/17437
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output


别忘了测试一下:

    args = ModelArgs()
    attention = MultiHeadAttention(args)
    inputs = torch.randn(32, 8, args.hidden_size)
    print(attention(inputs).shape)
torch.Size([32, 8, 512])

原始多头注意力就不再赘述了,之前的文章有过详细介绍。

分组查询注意力

分组查询注意力使用折中数量的key-value头(超过一个,但少于多头注意力全部的头数量)来提升性能。

多头注意力、分组查询注意力以及多查询注意力之间的区别如下:

该图来自参考1中的论文。

如上图所示,分组查询注意力是针对多头注意力的一种改进,每组Query头(这里两个Query一组)共享同一个Key和Value头,使得推理更加高效。

实际上在实现的时候,会将共享的Key和Value头进行广播(复制)成与Query头相同的数量:

这样,我们就可以像普通多头注意力一样去计算了。

我们增加num_key_value_heads表示key、value头数;num_heads还是表示query头数。

@dataclass
class ModelArgs:
    hidden_size: int = 512
    num_heads: int = 8
    num_key_value_heads: int = 4
    attention_dropout: float = 0.1

分组查询注意力和多查询注意力可以合并在一起实现:

class GroupedQueryAttention(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.hidden_size = args.hidden_size
        self.num_heads = args.num_heads
        # 每个头的维度计算和之前一样
        self.head_dim = self.hidden_size // self.num_heads
        # 保存key/value头数
        self.num_key_value_heads = args.num_key_value_heads
        # 每组内要复制的次数,若为1,即退化为多头注意力;若为num_heads,则为多查询注意力
        self.num_key_value_groups = self.num_heads // args.num_key_value_heads
        self.attention_dropout = args.attention_dropout

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        # 注意Key和Value的映射这里节省了参数,加速了推理效率。
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        # 最后的输出映射和之前一样
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor = None):

        batch_size, seq_len, _ = hidden_states.shape

        query_states, key_states, value_states = (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )

        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # 转换为对应的形状
        key_states = key_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        
		# 重复num_key_value_groups次,使得和query头数一致
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
		# 后面和普通多头注意力一样计算
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32 see https://github.com/huggingface/transformers/pull/17437
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output

其中num_key_value_groups为每组内要复制的次数,若为1,即退化为多头注意力;若为num_heads,则为多查询注意力。

复制时调用repeat_kv方法,如其名所示,只针对key和value:

def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
    """
    The hidden states go from (batch, num_key_value_heads, seq_len, head_dim) to (batch, num_attention_heads, seq_len, head_dim)
    n_rep is the number of repeat times.
    """
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # do nothing
        return hidden_states
    # add a new dimension and repeat n_rep times
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, seq_len, head_dim
    )
    # reshape to (batch, num_attention_heads, seq_len, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

有了分组查询注意力,下面我们来看如何应用上篇文章3介绍的旋转位置编码到query和key上。

应用旋转位置编码

注意,实现的时候要考虑维度,因此代码和上篇文章的旋转位置编码3有所不同。

首先,我们实现RotaryEmbedding,它缓存了频率张量inv_freq的计算。

class RotaryEmbedding(nn.Module):
    def __init__(
        self, dim: int, max_position_embeddings: int = 2048, theta: int = 10000
    ):
        super().__init__()
        self.dim = dim  # head dim
        self.max_position_embeddings = max_position_embeddings
        self.theta = theta
        inv_freq = 1.0 / (
            theta
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
	# 不需要计算梯度
    @torch.no_grad()
    def forward(self, position_ids: torch.LongTensor):
        freqs = torch.outer(position_ids, self.inv_freq).float()
        return torch.polar(torch.ones_like(freqs), freqs)

该实现修改自旋转位置编码文章3中的precompute_freqs_cis函数。

然后我们改写apply_rotary_emb函数,主要是确定了输入和输出维度的正确性:

def apply_rotary_emb(q: Tensor, k: Tensor, freq_cis: Tensor):
    """

    Args:
        q (Tensor): (batch_size, num_heads, seq_len, head_dim)
        k (Tensor): (batch_size, num_key_value_heads, seq_len, head_dim)
        freq_cis (Tensor): (seq_len, batch_size)
    """

    # q_ (batch_size, num_heads, seq_len, head_dim // 2, 2)
    q_ = q.float().reshape(*q.shape[:-1], -1, 2)
    # k_ (batch_size, num_key_value_heads, seq_len, head_dim // 2, 2)
    k_ = k.float().reshape(*k.shape[:-1], -1, 2)

    # turn to complex
    # q_ (batch_size, num_heads, seq_len, head_dim // 2)
    q_ = torch.view_as_complex(q_)
    # k_ (batch_size, num_key_value_heads, seq_len, head_dim // 2)
    k_ = torch.view_as_complex(k_)

    # freq_cis (batch_size, 1, seq_len, 1)
    freq_cis = reshape_for_broadcast(freq_cis, q_)

    # 应用旋转操作,然后将结果转回实数
    # view_as_real (batch_size, num_heads, seq_len, head_dim // 2, 2)
    # xq_out (batch_size, num_heads, seq_len, head_dim)
    xq_out = torch.view_as_real(q_ * freq_cis).flatten(-2)
    # view_as_real (batch_size, num_key_value_heads, seq_len, head_dim // 2, 2)
    # xk_out (batch_size, num_key_value_heads, seq_len, head_dim)
    xk_out = torch.view_as_real(k_ * freq_cis).flatten(-2)

    return xq_out.type_as(q), xk_out.type_as(k)

其中需要调用reshape_for_broadcast将频率张量的维度从(seq_len, batch_size)调整到(batch_size, 1, seq_len, 1)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Args:
        freqs_cis (torch.Tensor): (seq_len, batch_size)
        x (torch.Tensor): (batch_size, num_heads, seq_len, head_dim // 2)
    """
    # enumerate(x.shape) = [(0, batch_size), (1, num_heads), (2, seq_len), (3, head_dim // 2)]
    # (batch_size, 1, seq_len, 1)
    shape = [d if i == 0 or i == 2 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

我们把每个维度都写出来就不会出错。

再确保下repeat_kv函数的维度:

def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
    """
    The hidden states go from (batch, num_key_value_heads seq_len, head_dim) to (batch, num_attention_heads, seq_len, head_dim)
    n_rep is the number of repeat times.
    """
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # do nothing
        return hidden_states
    # add a new dimension and repeat n_rep times
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, seq_len, head_dim
    )
    # reshape to (batch, num_attention_heads, seq_len, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

最后将旋转位置编码整合到GroupedQueryAttention中:

class GroupedQueryAttention(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.hidden_size = args.hidden_size
        self.num_heads = args.num_heads
        # 每个头的维度计算和之前一样
        self.head_dim = self.hidden_size // self.num_heads
        # 保存key/value头数
        self.num_key_value_heads = args.num_key_value_heads
        # 每组内要复制的次数,若为1,即退化为多头注意力;若为num_heads,则为多查询注意力
        self.num_key_value_groups = self.num_heads // args.num_key_value_heads
        self.attention_dropout = args.attention_dropout

        self.max_position_embeddings = args.max_position_embeddings
        self.rope_theta = args.theta

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        # 注意Key和Value的映射这里节省了参数,加速了推理效率。
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        # 最后的输出映射和之前一样
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )
		# 定义了RotaryEmbedding实例
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            theta=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor = None,
        position_ids: torch.LongTensor = None,
    ):

        batch_size, seq_len, _ = hidden_states.shape

        query_states, key_states, value_states = (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
        # query_states(batch_size, num_heads, seq_len, head_dim)
        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # 转换为对应的形状
        # key_states (batch_size, num_key_value_heads, seq_len, head_dim)
        key_states = key_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # value_states (batch_size, num_key_value_heads, seq_len, head_dim)
        value_states = value_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # 计算频率张量
        # freq_cis (seq_len, batch_size)
        freq_cis = self.rotary_emb(position_ids)

        # 针对query和key应用旋转位置编码
        # query_states (batch_size, num_heads, seq_len, head_dim)
        # key_states (batch_size, num_key_value_heads, seq_len, head_dim)
        query_states, key_states = apply_rotary_emb(query_states, key_states, freq_cis)

        # 重复num_key_value_groups次,使得和query头数一致
        # key_states (batch_size, num_heads, seq_len, head_dim)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        # value_states (batch_size, num_heads, seq_len, head_dim)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # 后面和普通多头注意力一样计算
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32 see https://github.com/huggingface/transformers/pull/17437
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output

主要修改是在调用repeat_kv之前应用旋转位置编码到(每个Attention的)query和key中:

# 计算频率张量
# freq_cis (seq_len, batch_size)
freq_cis = self.rotary_emb(position_ids)

# 针对query和key应用旋转位置编码
# query_states (batch_size, num_heads, seq_len, head_dim)
# key_states (batch_size, num_key_value_heads, seq_len, head_dim)
query_states, key_states = apply_rotary_emb(query_states, key_states, freq_cis)

这里简单探讨下为什么旋转位置编码只是应用到query和key上,没有应用到value上,考虑Attention的计算公式:
a m , n = exp ⁡ ( q m T k n d ) ∑ j = 1 N exp ⁡ q m T k j d o m = ∑ n = 1 N a m , n v n \begin{aligned} a_{m,n} &= \frac{\exp(\frac{\pmb q^T_m \pmb k_n}{\sqrt d})}{\sum_{j=1}^N \exp \frac{\pmb q^T_m \pmb k_j}{\sqrt d}} \\ \pmb o_m &= \sum_{n=1}^N a_{m,n}\pmb v_n \\ \end{aligned} am,n​ooom​​=∑j=1N​expd ​q​q​​qmT​kkkj​​exp(d ​q​q​​qmT​kkkn​​)​=n=1∑N​am,n​vvvn​​

我们可以看到,实际上只有query和key之间会进行交互(点乘),而value只是用于计算加权和,不参与交互,因此没有必要应用旋转位置编码,但也可以尝试应用到value上。

苏神在博客也说了:“通过在q,k中施行该位置编码,那么效果就等价于相对位置编码,而如果还需要显式的绝对位置信息,则可以同时在v上也施行这种位置编码。总的来说,我们通过绝对位置的操作,可以达到绝对位置的效果,也能达到相对位置的效果。”

最后,进行一个简单的测试:

@dataclass
class ModelArgs:
    hidden_size: int = 512
    num_heads: int = 8
    num_key_value_heads: int = 4
    attention_dropout: float = 0.1
    max_position_embeddings: int = 2048
    theta: int = 10000
    
if __name__ == "__main__":
    args = ModelArgs()
    attention = GroupedQueryAttention(args)

    inputs = torch.randn(32, 16, args.hidden_size)

    seq_len = inputs.size(1)

    position_ids = torch.arange(seq_len, dtype=torch.long)

    print(attention(inputs, position_ids=position_ids).shape)

torch.Size([32, 16, 512])

参考

[论文翻译]GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints ↩︎

Fast Transformer Decoding: One Write-Head is All You Need ↩︎

Llama改进之——RoPE旋转位置编码 ↩︎ ↩︎ ↩︎ ↩︎

总结

更新时间 2024-08-24