当前位置:AIGC资讯 > AIGC > 正文

LLaMa 原理+源码——拆解 (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)

原理 Vanilla Transformer 与 LLaMa 的区别 Embedding RMS Norm Rotary Positional Encodding SwiGLU Function KV-Cache Grouped Multi-Query Attention Multi Query Attention Grouped Multi-Query Attention 源码

原理

Vanilla Transformer 与 LLaMa 的区别

Vanilla Transformer 与 LLaMa 的对比:LLaMa与普通的Transformer架构不同的地方,包括采用了前置了层归一化(Pre-normalization)并使用RMSNorm 归一化函数(Normalizing Function)、使用了旋转位置嵌入(RoPE)、激活函数由ReLU更换为SwiGLU,并且将self-attention改进为使用KV-Cache的Grouped Query,整体Transformer架构与GPT-2 类似。

LLaMa -> Alpaca -> Vicuna 的演进:

LLaMa:Meta开源的Pre-trained Model,模型参数从7B、13B、32B、65B不等,LLaMa-7B在大多数基准测试上超过了Text-davinci-003(即GPT3-173B),相比于ChatGPT或者GPT4来说,LLaMa可能效果上还有差距,目前hugging face已集成了LLaMa的代码实现和开源模型。学术界和工业界都可以在此基础上进行学习和研究。

Alpaca:斯坦福在LLaMa-7B的基础上监督微调出来的模型,斯坦福是用OpenAI的Text-davinci-003(即GPT3-173B)的API配合self-instruct技术,使用175个提示语种子自动生成了52K条提示-回复的指示数据集,在LLaMa-7B上微调得到的模型,在8张80G的A100上训练了3小时。

Vicuna:在LLaMa-13B的基础上使用监督微调得到的模型,数据集来自于ShareGPT 产生的用户对话数据,共70K条。使用Pytorch FSDP在8张A100上训练了一天。相较于Alpaca,Vicuna在训练中将序列长度由512扩展到了2048,并且通过梯度检测和flash attention来解决内存问题;调整训练损失考虑多轮对话,并仅根据模型的输出进行微调。通过GPT4来打分评测,Vicuna可以达到ChatGPT 90%的效果。

LLaMa2:采用了Llama 1的大部分预训练设置和模型架构。LLaMa2和LLaMa1的最大差别是增加了文本长度,并在训练34B、70B的模型中应用了GQA

Embedding

Embedding的过程:word -> token_id -> embedding_vector,其中第一步转化使用tokenizer的词表进行,第二步转化使用 learnable 的 Embedding layer。

RMS Norm

对比 Batch Norm 和 Layer Norm:都是减去均值Mean,除以方差Var,最终将归一化为正态分布N(0,1)。只不过两者是在不同的维度(batch还是feature)求均值和方差,(其中,减均值:re-centering 将均值mean变换为0,除方差:re-scaling将方差varance变换为1)。

RMS Norm(Root Mean Layer Norm):RMS Norm认为,Layer Norm成功的原因是re-scaling,因为方差Var计算的过程中使用了均值Mean,因此RMS Norm不再使用均值Mean,而是构造了一个特殊的统计量RMS代替方差Var。为什么使用RMS Norm?(1)RMS Norm计算量更小。(2)RMS的效果和Layer Norm一样好。

针对输入向量 a 的RMS Norm 函数计算公式如下:

此外,RMSNorm 还可以引入可学习的缩放因子gi 和偏移参数bi,从而得到

RMSNorm 在HuggingFace Transformer 库中代码实现如下所示:

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps # eps 防止取倒数之后分母为0
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # weight 是末尾乘的可训练参数, 即g_i
        return (self.weight * hidden_states).to(input_dtype)

为了使得模型训练过程更加稳定,GPT-2 相较于GPT 就提出了将Layer Norm前置,将第一个层归一化移动到多头自注意力层之前,第二个层归一化也移动到了全连接层之前,同时残差连接的位置也调整到了多头自注意力层与全连接层之后。层归一化中也采用了RMSNorm 归一化函数。

Rotary Positional Encodding

普通绝对Positional Encodding的使用过程:word -> token_id -> embedding_vector + position_encodding -> Encoder_Input,其中第一步转化使用tokenizer的词表进行,第二步转化使用 learnable 的 Embedding layer。将得到的embedding_vector 和 position_encodding 进行element-wise的相加,然后才做为input送入LLM的encoder。


对比Absolute PE 和 Relative PE

Absolute PE 绝对位置编码:每次单独1个token的PE,每个token的PE之间没有关系,是一组固定的vector,反映每个token在sequence中的绝对位置。 Relative PE 相对位置编码:每次处理2个token的PE,只在计算attention时使用(在query@key时加在key上),反映2个token的相关度。

旋转位置编码(RoPE):RoPE 借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。其目标是通过下述 f 运算,来给q,k 添加绝对位置信息m和n,得到˜qm 和˜kn,然后进行q@k:

实际上,我们借助了复数的思想,寻找了一个 g 运算来合并 f 运算q@k这两个操作,这样只需要token qk 以及两者的在seqence中的绝对位置mn即可:


可以看到与普通的相对位置编码不同,旋转相对位置编码用于在计算attention_score=q@k之后,对attention_score强调每个token之间的相对位置:

为什么叫旋转位置编码?因为使用欧拉公式构造旋转矩阵,将q@k的计算结果旋转到空间中对应的位置,实现对计算结果添加位置信息。

上面是2维的例子,只有2个token xmxn,LLaMa中是n维的,n个token也是一样操作:

由于上述旋转矩阵Rn 具有稀疏性,有大量元素是0,因此可以使用逐位相乘⊗ 操作进一步加快计算速度。

RoPE 在HuggingFace Transformer 库中代码实现如下所示:

class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,
        dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
        
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`.
        # Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation
            # in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),
            persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),
            persistent=False)
    
        return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

SwiGLU Function

SwiGLU 激活函数是Shazeer 在文献中提出,并在PaLM等模中进行了广泛应用,并且取得了不错的效果,相较于ReLU 函数在大部分评测中都有不少提升。在LLaMA 中全连接层使用带有SwiGLU 激活函数的FFN(Position-wise Feed-Forward Network)的计算公式如下:

其中,σ(x) 是Sigmoid 函数。下图给出了Swish 激活函数在参数β 不同取值下的形状。可以看到当β 趋近于0 时,Swish 函数趋近于线性函数y = x,当β 趋近于无穷大时,Swish 函数趋近于ReLU 函数,β 取值为1 时,Swish 函数是光滑且非单调。


HuggingFace 的Transformer 库中 S w i s h β = 1 Swish_{\beta=1} Swishβ=1​函数使用 SILU 函数 代替。

KV-Cache

首先来了解一下LLama的训练(下词预测任务):seq2seq的生成,但迭代T次,seq_len逐渐增加。

下句预测时的Self-Attention:

timpstep=1时seq_len=1,给[SOS]时,预测Love;
timpstep=2时seq_len=2,给[SOS] 和 Love时,预测that
timpstep=4时seq_len=4,给[SOS] 和 Love 和 can 和 quickly时,预测seize…

每个timestep我们只关注生成的最后一个token,但因为LLaMa是一个seq2seq的model,每次必须重新计算和生成前面的token,因此我们希望能将之前timestep计算生成过的token给缓存起来,下个timestep不用再次计算,这样的背景下,KV-Cache就产生了。

再来分析一下,每次个timestep的self-attention中我们到底需要哪些:因为我们只关注最后一个token的attention_output,如下图timestep=4,我们只需要attention_output的第4个token。

因此我们只需要Q的最后一个token和K的所有token相乘,得到最后一个token的attention_score,然后用V的所有token再与attention_score点积(相乘求和),得到最后一个token的attention_output

由上分析可知,每个timestep,我们的Q只需要新增的那个token即可,而K和V要缓存之前timestep的token,保证token是全的。每次计算出来的attention_output就是那个新增的token的attention。 这样就可以节省大量计算开销。


Grouped Multi-Query Attention

回顾原始的多头注意力Multi-Head Attention:时间开销的瓶颈在于矩阵的运算matrix computation

当我们使用KV-Cache后:时间开销的瓶颈在于内存的访问memory access

Multi Query Attention

多查询注意力(Multi Query Attention,MQA) 是多头注意力的一种变体。其主要区别在于,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数。 具体操作上,去除 K和V 的head维度,只为Q保留head维度。因此这就是被叫做Multi Query Attention的原因。

因此K和V的矩阵仅有一份(不分head),这大幅度减少了显存占用,使其更高效。由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。

研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约 5% 的原始训练数据量就可以达到不错的效果。包括Falcon、SantaCoder、StarCoder等在内很多模型都采用了多查询注意力机制。

以LLM Foundry 为例,多查询注意力实现代码如下,与LLM Foundry 中实现的多头自注意力代码相对比,其区别仅在于建立Wqkv 层上:

class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
additive bias.
"""
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear( # Multi-Query Attention 创建
            d_model,
            d_model + 2 * self.head_dim, # 只创建查询的头向量,所以只有1 个d_model
            device=device, # 而键和值则共享各自的一个head_dim 的向量
        )
        self.attn_fn = scaled_multihead_dot_product_attention
        self.out_proj = nn.Linear(
            self.d_model,
            self.d_model,
            device=device
        )
        self.out_proj._is_residual = True # type: ignore
    def forward(
        self,
        x,
    ):
        qkv = self.Wqkv(x) # (1, 512, 960)
        query, key, value = qkv.split( # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
            dim=2 # value -> (1, 512, 96)
        )
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            multiquery=True,
    )
        return self.out_proj(context), attn_weights, past_key_value
Grouped Multi-Query Attention

就是在 Multi-Query Attention的基础上,对input进行分组,每组都有自己的K,V,以及多头Q。

源码

[LLMs 实践] 01 llama、alpaca、vicuna 整体介绍及 llama 推理过程

更新时间 2024-01-22