Computing the rotation matrix
rotary_emb corresponds to the LlamaRotaryEmbedding layer, whose built-in __init__ initializer and forward method are responsible for generating the $\cos$ and $\sin$ components of the rotation matrix.
Code
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
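A minimal usage sketch (dim=8 and the dummy input shapes are assumed values for illustration; the shapes follow the comment in forward):
import torch

rotary_emb = LlamaRotaryEmbedding(dim=8)      # dim = head_size
x = torch.randn(1, 2, 10, 8)                  # [bs, num_heads, seq_len, head_size]
cos, sin = rotary_emb(x, seq_len=10)
cos.shape, sin.shape    # (torch.Size([1, 1, 10, 8]), torch.Size([1, 1, 10, 8]))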
Key code in __init__
Computing $\theta_i$ from the formula
Source: inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
Formula: $\theta_i = 10000^{-2i/d}$
Example: assume base=10000, dim=8, device='cpu'
dim, base, device = 8, 10000, 'cpu'
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
inv_freq
inv_freq is a tensor of size torch.Size([dim//2]):
tensor([1.0000, 0.1000, 0.0100, 0.0010])
Generating the IDs of all positions
Source: t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
Example: assume max_position_embeddings=10, then
max_position_embeddings = 10
t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
t
Output:
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
Computing $m\theta$
Source: freqs = torch.einsum("i,j->ij", t, self.inv_freq)
Example:
freqs = torch.einsum("i,j->ij", t, inv_freq)
freqs
Output:
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
[2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
[3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
[4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03],
[5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03],
[6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03],
[7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03],
[8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03],
[9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03]])
Concatenating $m\theta$ with itself
Source: emb = torch.cat((freqs, freqs), dim=-1)
Example:
emb = torch.cat((freqs, freqs), dim=-1)
emb
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03, 1.0000e+00, 1.0000e-01,
1.0000e-02, 1.0000e-03],
[2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03, 2.0000e+00, 2.0000e-01,
2.0000e-02, 2.0000e-03],
[3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03, 3.0000e+00, 3.0000e-01,
3.0000e-02, 3.0000e-03],
[4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03, 4.0000e+00, 4.0000e-01,
4.0000e-02, 4.0000e-03],
[5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03, 5.0000e+00, 5.0000e-01,
5.0000e-02, 5.0000e-03],
[6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03, 6.0000e+00, 6.0000e-01,
6.0000e-02, 6.0000e-03],
[7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03, 7.0000e+00, 7.0000e-01,
7.0000e-02, 7.0000e-03],
[8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03, 8.0000e+00, 8.0000e-01,
8.0000e-02, 8.0000e-03],
[9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03, 9.0000e+00, 9.0000e-01,
9.0000e-02, 9.0000e-03]])
Computing sin and cos
Source:
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
Split the input vector down the middle into x1 and x2, then concatenate them as [-x2, x1]:
[q1,q2,q3,q4,q5,q6,q7,q8,q9,q10] -> [-q6,-q7,-q8,-q9,-q10,q1,q2,q3,q4,q5]
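Example (a quick sanity check of the mapping above on a length-10 vector):
x = torch.arange(1., 11.)    # [q1, ..., q10] = [1., ..., 10.]
rotate_half(x)
Output:
tensor([ -6.,  -7.,  -8.,  -9., -10.,   1.,   2.,   3.,   4.,   5.])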
apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
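The function is dissected step by step below; first, a minimal end-to-end sketch (the dummy inputs are assumed values, reusing the rotary_emb instance from the earlier usage example):
q = torch.randn(1, 2, 10, 8)                  # [bs, num_heads, seq_len, head_size]
k = torch.randn(1, 2, 10, 8)
position_ids = torch.arange(10)[None, :]      # [bs, seq_len]
cos, sin = rotary_emb(q, seq_len=10)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
q_embed.shape                                 # torch.Size([1, 2, 10, 8])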
Generating the index tensor
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
Looking up the $\cos m\theta$ and $\sin m\theta$ values at each position via the index
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
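Example (shape check, continuing the sketch above): the gather selects, for each position id, the matching row of the cached table along the sequence axis:
gather_indices = position_ids[:, None, :, None]                           # [1, 1, 10, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])  # [1, 1, 10, 8]
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
cos.shape                                                                 # torch.Size([1, 1, 10, 8])
With position_ids = [0, ..., 9] this lookup is the identity; during incremental decoding, position_ids contains only the offsets of the tokens being processed, so the gather picks out just those rows.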
Compute the $\cos$ term and the $\sin\theta \cdot$ rotate_half term, then add them:
q_embed = (q * cos) + (rotate_half(q) * sin)
Note
The indices of rotate_half(q), cos, and sin line up as follows:
rotate_half(q) = [-q6, -q7, -q8, -q9, -q10, q1, q2, q3, q4, q5]
cos = [cosθ1, cosθ2, cosθ3, cosθ4, cosθ5, cosθ1, cosθ2, cosθ3, cosθ4, cosθ5]
sin = [sinθ1, sinθ2, sinθ3, sinθ4, sinθ5, sinθ1, sinθ2, sinθ3, sinθ4, sinθ5]
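Putting it together (a sketch of why the element-wise form is a rotation, using the notation above): for position m and each pair index $1 \le i \le d/2$, q_embed = (q * cos) + (rotate_half(q) * sin) computes

$q'_i = q_i \cos m\theta_i - q_{i+d/2} \sin m\theta_i$
$q'_{i+d/2} = q_{i+d/2} \cos m\theta_i + q_i \sin m\theta_i$

which is exactly the 2D rotation of the pair $(q_i, q_{i+d/2})$ by the angle $m\theta_i$, i.e. the RoPE rotation matrix applied pairwise, just with a different pairing than the original paper (first-half/second-half instead of adjacent dims), as the comment in the source notes.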