原理 Vanilla Transformer 与 LLaMa 的区别 Embedding RMS Norm Rotary Positional Encodding SwiGLU Function KV-Cache Grouped Multi-Query Attention Multi Query Attention Grouped Multi-Query Attention 源码
原理
Vanilla Transformer 与 LLaMa 的区别
Vanilla Transformer 与 LLaMa
的对比:LLaMa与普通的Transformer架构不同的地方,包括采用了前置了层归一化(Pre-normalization)并使用RMSNorm 归一化函数(Normalizing Function)、使用了旋转位置嵌入(RoPE)、激活函数由ReLU更换为SwiGLU,并且将self-attention改进为使用KV-Cache的Grouped Query,整体Transformer架构与GPT-2 类似。
LLaMa -> Alpaca -> Vicuna
的演进:
LLaMa:Meta开源的Pre-trained Model,模型参数从7B、13B、32B、65B不等
,LLaMa-7B在大多数基准测试上超过了Text-davinci-003(即GPT3-173B),相比于ChatGPT或者GPT4来说,LLaMa可能效果上还有差距,目前hugging face已集成了LLaMa的代码实现和开源模型。学术界和工业界都可以在此基础上进行学习和研究。
Alpaca:斯坦福在LLaMa-7B的基础上监督微调
出来的模型,斯坦福是用OpenAI的Text-davinci-003(即GPT3-173B)的API
配合self-instruct
技术,使用175个提示语种子自动生成了52K条提示-回复的指示数据集,在LLaMa-7B上微调得到的模型,在8张80G的A100上训练了3小时。
Vicuna:在LLaMa-13B的基础上使用监督微调
得到的模型,数据集来自于ShareGPT 产生的用户对话数据
,共70K条。使用Pytorch FSDP在8张A100上训练了一天。相较于Alpaca,Vicuna在训练中将序列长度由512扩展到了2048
,并且通过梯度检测和flash attention来解决内存问题;调整训练损失考虑多轮对话,并仅根据模型的输出进行微调。通过GPT4来打分评测,Vicuna可以达到ChatGPT 90%的效果。
LLaMa2:采用了Llama 1的大部分预训练设置和模型架构。LLaMa2和LLaMa1的最大差别是增加了文本长度
,并在训练34B、70B
的模型中应用了GQA
。
Embedding
Embedding的过程:word -> token_id -> embedding_vector
,其中第一步转化使用tokenizer的词表进行,第二步转化使用 learnable 的 Embedding layer。
RMS Norm
对比 Batch Norm 和 Layer Norm
:都是减去均值Mean,除以方差Var,最终将归一化为正态分布N(0,1)
。只不过两者是在不同的维度(batch还是feature)求均值和方差,(其中,减均值:re-centering
将均值mean变换为0,除方差:re-scaling
将方差varance变换为1)。
RMS Norm(Root Mean Layer Norm)
:RMS Norm认为,Layer Norm成功的原因是re-scaling
,因为方差Var计算的过程中使用了均值Mean,因此RMS Norm不再使用均值Mean,而是构造了一个特殊的统计量RMS
代替方差Var。为什么使用RMS Norm?(1)RMS Norm计算量更小。(2)RMS的效果和Layer Norm一样好。
针对输入向量 a 的RMS Norm 函数计算公式如下:
此外,RMSNorm 还可以引入可学习的缩放因子gi 和偏移参数bi,从而得到
RMSNorm 在HuggingFace Transformer 库中代码实现如下所示:
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps # eps 防止取倒数之后分母为0
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# weight 是末尾乘的可训练参数, 即g_i
return (self.weight * hidden_states).to(input_dtype)
为了使得模型训练过程更加稳定,GPT-2 相较于GPT 就提出了将Layer Norm前置,将第一个层归一化移动到多头自注意力层之前,第二个层归一化也移动到了全连接层之前,同时残差连接的位置也调整到了多头自注意力层与全连接层之后。层归一化中也采用了RMSNorm 归一化函数。
Rotary Positional Encodding
普通绝对Positional Encodding的使用过程:word -> token_id -> embedding_vector + position_encodding -> Encoder_Input
,其中第一步转化使用tokenizer的词表进行,第二步转化使用 learnable 的 Embedding layer。将得到的embedding_vector 和 position_encodding 进行element-wise的相加,然后才做为input送入LLM的encoder。
对比Absolute PE 和 Relative PE
:
Absolute PE 绝对位置编码
:每次单独1个token的PE,每个token的PE之间没有关系,是一组固定的vector,反映每个token在sequence中的绝对位置。
Relative PE 相对位置编码
:每次处理2个token的PE,只在计算attention时使用(在query@key
时加在key上),反映2个token的相关度。
旋转位置编码(RoPE)
:RoPE 借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。其目标是通过下述 f 运算,来给q,k 添加绝对位置信息m和n,得到˜qm 和˜kn,然后进行q@k:
实际上,我们借助了复数的思想
,寻找了一个 g 运算
来合并 f 运算
和q@k
这两个操作,这样只需要token q
和k
以及两者的在seqence中的绝对位置m
和n
即可:
可以看到与普通的相对位置编码不同,旋转相对位置编码用于在计算attention_score=q@k
之后,对attention_score强调每个token之间的相对位置:
为什么叫旋转位置编码?因为使用欧拉公式
构造旋转矩阵
,将q@k的计算结果旋转到空间中对应的位置,实现对计算结果添加位置信息。
上面是2维的例子,只有2个token xm
和xn
,LLaMa中是n维的,n个token也是一样操作:
由于上述旋转矩阵Rn 具有稀疏性,有大量元素是0,因此可以使用逐位相乘⊗
操作进一步加快计算速度。
RoPE 在HuggingFace Transformer 库中代码实现如下所示:
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation
# in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`.
# Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation
# in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),
persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),
persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
SwiGLU Function
SwiGLU 激活函数是Shazeer 在文献中提出,并在PaLM等模中进行了广泛应用,并且取得了不错的效果,相较于ReLU 函数在大部分评测中都有不少提升。在LLaMA 中全连接层使用带有SwiGLU 激活函数的FFN(Position-wise Feed-Forward Network)的计算公式如下:
其中,σ(x) 是Sigmoid 函数。下图给出了Swish 激活函数在参数β 不同取值下的形状。可以看到当β 趋近于0 时,Swish 函数趋近于线性函数y = x,当β 趋近于无穷大时,Swish 函数趋近于ReLU 函数,β 取值为1 时,Swish 函数是光滑且非单调。
HuggingFace 的Transformer 库中
S
w
i
s
h
β
=
1
Swish_{\beta=1}
Swishβ=1函数使用 SILU 函数 代替。
KV-Cache
首先来了解一下LLama的训练(下词预测任务):seq2seq的生成,但迭代T次,seq_len
逐渐增加。
下句预测时的Self-Attention:
timpstep=1时seq_len=1
,给[SOS]时,预测Love;timpstep=2时
seq_len=2
,给[SOS] 和 Love时,预测thattimpstep=4时
seq_len=4
,给[SOS] 和 Love 和 can 和 quickly时,预测seize…每个timestep我们只关注生成的最后一个token
,但因为LLaMa是一个seq2seq的model,每次必须重新计算和生成前面的token,因此我们希望能将之前timestep计算生成过的token给缓存起来,下个timestep不用再次计算,这样的背景下,KV-Cache就产生了。
再来分析一下,每次个timestep的self-attention中我们到底需要哪些:因为我们只关注最后一个token的attention_output
,如下图timestep=4,我们只需要attention_output的第4个token。
因此我们只需要Q的最后一个token和K的所有token相乘,得到最后一个token的attention_score
,然后用V的所有token再与attention_score
点积(相乘求和),得到最后一个token的attention_output
:
由上分析可知,每个timestep,我们的Q只需要新增的那个token即可,而K和V要缓存之前timestep的token,保证token是全的。每次计算出来的attention_output就是那个新增的token的attention。 这样就可以节省大量计算开销。
Grouped Multi-Query Attention
回顾原始的多头注意力Multi-Head Attention:时间开销的瓶颈在于矩阵的运算matrix computation
。
当我们使用KV-Cache后:时间开销的瓶颈在于内存的访问memory access
。
Multi Query Attention
多查询注意力(Multi Query Attention,MQA
) 是多头注意力的一种变体。其主要区别在于,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数。 具体操作上,去除 K和V 的head维度,只为Q保留head维度
。因此这就是被叫做Multi Query Attention的原因。
因此K和V的矩阵仅有一份(不分head),这大幅度减少了显存占用,使其更高效。由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。
研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约 5% 的原始训练数据量就可以达到不错的效果。包括Falcon、SantaCoder、StarCoder等在内很多模型都采用了多查询注意力机制。
以LLM Foundry 为例,多查询注意力实现代码如下,与LLM Foundry 中实现的多头自注意力代码相对比,其区别仅在于建立Wqkv 层上:
class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
device: Optional[str] = None,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.Wqkv = nn.Linear( # Multi-Query Attention 创建
d_model,
d_model + 2 * self.head_dim, # 只创建查询的头向量,所以只有1 个d_model
device=device, # 而键和值则共享各自的一个head_dim 的向量
)
self.attn_fn = scaled_multihead_dot_product_attention
self.out_proj = nn.Linear(
self.d_model,
self.d_model,
device=device
)
self.out_proj._is_residual = True # type: ignore
def forward(
self,
x,
):
qkv = self.Wqkv(x) # (1, 512, 960)
query, key, value = qkv.split( # query -> (1, 512, 768)
[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
dim=2 # value -> (1, 512, 96)
)
context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
self.n_heads,
multiquery=True,
)
return self.out_proj(context), attn_weights, past_key_value
Grouped Multi-Query Attention
就是在 Multi-Query Attention的基础上,对input进行分组,每组都有自己的K,V,以及多头Q。
源码
[LLMs 实践] 01 llama、alpaca、vicuna 整体介绍及 llama 推理过程