
[Llama Source Code] Rotary Position Embedding (RoPE): Source Code Walkthrough

Computing the rotation matrix

rotary_emb corresponds to the LlamaRotaryEmbedding layer. It has an __init__ initialization method and a forward method, and it is responsible for generating the cos and sin values used in the rotation matrix.

Code

import torch


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len is None:  # guard added here: default to the sequence length of x
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
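The if branch in forward recomputes the cache when seq_len exceeds max_seq_len_cached. The sketch below uses a hypothetical helper build_cache that repeats the cache-building lines of __init__ (without device handling). It illustrates why returning cos_cached[:, :, :seq_len] is safe either way: the angle at position m depends only on m and inv_freq, so a longer cache contains the shorter one as its prefix.

```python
import torch


def build_cache(dim, seq_len, base=10000):
    # Hypothetical helper mirroring LlamaRotaryEmbedding.__init__:
    # returns cos/sin tables of shape [1, 1, seq_len, dim]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]


cos_short, sin_short = build_cache(dim=8, seq_len=16)
cos_long, sin_long = build_cache(dim=8, seq_len=32)  # as if forward saw seq_len=32

# The first 16 rows of the longer cache match the shorter cache
print(torch.allclose(cos_long[:, :, :16, :], cos_short))  # True
print(torch.allclose(sin_long[:, :, :16, :], sin_short))  # True
print(cos_long.shape)  # torch.Size([1, 1, 32, 8])
```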

Key code in __init__

Compute $\theta_i$ from the formula
Source: inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
Formula ($d$ = dim, base = 10000): $\theta_i = 10000^{-2(i-1)/d},\ i = 1, 2, \dots, d/2$
Example:
Assume base=10000, dim=8, device="cpu":
dim, base, device = 8, 10000, "cpu"
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
inv_freq

inv_freq is a tensor of size torch.Size([dim//2]):

tensor([1.0000, 0.1000, 0.0100, 0.0010])
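To connect the code with the formula: torch.arange(0, dim, 2) yields the exponents 2(i-1) for i = 1, ..., d/2, so inv_freq should equal $10000^{-2(i-1)/d}$ evaluated element by element. A quick check with scalar Python arithmetic (the list comprehension is only illustrative):

```python
import torch

dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# theta_i = base^(-2(i-1)/d) for i = 1..d/2, evaluated with plain Python floats
theta = [base ** (-2 * (i - 1) / dim) for i in range(1, dim // 2 + 1)]

print(theta)  # roughly [1.0, 0.1, 0.01, 0.001]
print(torch.allclose(inv_freq, torch.tensor(theta)))  # True
```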
Generate the IDs of all positions
Source: t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
Example: assume max_position_embeddings=10:
max_position_embeddings = 10
t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
t

Output:

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
Compute $m\theta_i$ (each position m multiplied by each frequency $\theta_i$)
Source: freqs = torch.einsum("i,j->ij", t, self.inv_freq)
Example:
freqs = torch.einsum("i,j->ij", t, inv_freq)
freqs

Output:

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
        [4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03],
        [5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03],
        [6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03],
        [7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03],
        [8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03],
        [9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03]])
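The einsum pattern "i,j->ij" is simply the outer product of t and inv_freq, so freqs[m, i] = m * theta_i. The sketch below checks this against torch.outer and plain broadcasting, using the same sizes as the example above:

```python
import torch

dim, base, max_position_embeddings = 8, 10000, 10
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, inv_freq)  # as in the source
outer = torch.outer(t, inv_freq)              # the same thing as an outer product
bcast = t[:, None] * inv_freq[None, :]        # the same thing via broadcasting

print(freqs.shape)                   # torch.Size([10, 4])
print(torch.allclose(freqs, outer))  # True
print(torch.allclose(freqs, bcast))  # True
print(freqs[3, 1].item())            # position 3 times theta = 0.1, about 0.3
```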
Concatenate $m\theta_i$ with itself
Source: emb = torch.cat((freqs, freqs), dim=-1)
(in forward, the same line is followed by .to(x.device))
Example:
emb = torch.cat((freqs, freqs), dim=-1)
emb

Output:
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03, 1.0000e+00, 1.0000e-01,
         1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03, 2.0000e+00, 2.0000e-01,
         2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03, 3.0000e+00, 3.0000e-01,
         3.0000e-02, 3.0000e-03],
        [4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03, 4.0000e+00, 4.0000e-01,
         4.0000e-02, 4.0000e-03],
        [5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03, 5.0000e+00, 5.0000e-01,
         5.0000e-02, 5.0000e-03],
        [6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03, 6.0000e+00, 6.0000e-01,
         6.0000e-02, 6.0000e-03],
        [7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03, 7.0000e+00, 7.0000e-01,
         7.0000e-02, 7.0000e-03],
        [8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03, 8.0000e+00, 8.0000e-01,
         8.0000e-02, 8.0000e-03],
        [9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03, 9.0000e+00, 9.0000e-01,
         9.0000e-02, 9.0000e-03]])
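The comment "Different from paper" in the source refers to this layout. A minimal sketch comparing the two orderings: the Llama code repeats the whole block of angles, while the RoFormer paper pairs adjacent dimensions, which corresponds to repeating each angle in place (torch.repeat_interleave here). Both contain the same values; only the column order differs.

```python
import torch

dim, base, seq_len = 8, 10000, 10
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq_len, dtype=inv_freq.dtype), inv_freq)

# Llama layout: [m*theta_1..m*theta_{d/2}, m*theta_1..m*theta_{d/2}]
emb_half = torch.cat((freqs, freqs), dim=-1)
# Interleaved layout (adjacent pairs, as in the paper): [m*theta_1, m*theta_1, m*theta_2, m*theta_2, ...]
emb_interleaved = torch.repeat_interleave(freqs, 2, dim=-1)

half = dim // 2
# The two halves of the Llama layout are identical
print(torch.equal(emb_half[:, :half], emb_half[:, half:]))  # True
# Column j of the Llama layout equals column 2j of the interleaved layout
print(torch.equal(emb_half[:, :half], emb_interleaved[:, 0::2]))  # True
```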
Compute sin and cos
Source:
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
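The [None, None, :, :] indexing adds a batch axis and a head axis, so the cached tables of shape [1, 1, seq_len, dim] broadcast against q of shape [bs, num_attention_heads, seq_len, head_size]. The sketch below uses arbitrary example sizes (bs=2, 4 heads) and a throwaway module named holder (hypothetical, for illustration) to show that persistent=False keeps the buffer out of state_dict:

```python
import torch

dim, base, seq_len = 8, 10000, 10
bs, num_heads = 2, 4  # arbitrary example sizes
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq_len, dtype=inv_freq.dtype), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

cos_cached = emb.cos()[None, None, :, :]  # [1, 1, seq_len, dim]
sin_cached = emb.sin()[None, None, :, :]

q = torch.randn(bs, num_heads, seq_len, dim)  # [bs, num_heads, seq_len, head_size]
print(cos_cached.shape)        # torch.Size([1, 1, 10, 8])
print((q * cos_cached).shape)  # torch.Size([2, 4, 10, 8]) via broadcasting

# Position 0 has angle 0, so cos = 1 and sin = 0 in every dimension
print(torch.all(cos_cached[0, 0, 0] == 1).item(), torch.all(sin_cached[0, 0, 0] == 0).item())  # True True

# Non-persistent buffers are not saved in state_dict
holder = torch.nn.Module()
holder.register_buffer("cos_cached", cos_cached, persistent=False)
print("cos_cached" in holder.state_dict())  # False
```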

rotate_half

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

Split the original vector in the middle into two halves, x1 and x2, then concatenate them as [-x2, x1]:

[q1,q2,q3,q4,q5,q6,q7,q8,q9,q10] -> [-q6,-q7,-q8,-q9,-q10,q1,q2,q3,q4,q5]
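The mapping above can be reproduced directly by running rotate_half (copied from the source) on a 10-dimensional vector whose entries stand for q1..q10. Applying it twice negates the input, just as two 90-degree rotations of a 2D vector do:

```python
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input (copied from the source above)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


q = torch.arange(1.0, 11.0)  # stands for [q1, ..., q10]
print(rotate_half(q).tolist())
# [-6.0, -7.0, -8.0, -9.0, -10.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# Two applications give -q, like rotating by 90 degrees twice
print(torch.equal(rotate_half(rotate_half(q)), -q))  # True
```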

apply_rotary_pos_emb

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
Build the index tensor
gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
Use the indices to look up the $\cos m\theta_i$ and $\sin m\theta_i$ values for each position
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
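Since cos has shape [1, 1, max_len, dim], the gather path amounts to picking one row of the cos table per position id. The sketch below, with made-up position_ids for two sequences of length 3, compares the gather path from apply_rotary_pos_emb with an equivalent plain indexing lookup:

```python
import torch

dim, base, max_len = 8, 10000, 16
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(max_len, dtype=inv_freq.dtype), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()[None, None, :, :]  # [1, 1, max_len, dim]

# Two sequences of length 3 with different (example) position ids
position_ids = torch.tensor([[0, 1, 2], [5, 6, 7]])  # [bs, seq_len]

# The gather path from apply_rotary_pos_emb
gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])  # [bs, 1, seq_len, dim]
cos_g = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)

# Equivalent lookup: select rows of the [max_len, dim] table, then add the head axis
cos_idx = cos[0, 0][position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

print(gather_indices.shape)         # torch.Size([2, 1, 3, 8])
print(cos_g.shape)                  # torch.Size([2, 1, 3, 8])
print(torch.equal(cos_g, cos_idx))  # True
```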
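Compute the cos term first, then the $\sin m\theta_i$ times rotate_half term, and add them, i.e. $q_{\text{embed}} = q \odot \cos m\theta + \text{rotate\_half}(q) \odot \sin m\theta$:
q_embed = (q * cos) + (rotate_half(q) * sin)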
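Note

For a head size of 10, the indices of rotate_half(q), cos, and sin line up as follows (θj here stands for the angle mθj at a given position m):

rotate_half(q) = [-q6,-q7,-q8,-q9,-q10,q1,q2,q3,q4,q5]
cos = [cosθ1,cosθ2,cosθ3,cosθ4,cosθ5,cosθ1,cosθ2,cosθ3,cosθ4,cosθ5]
sin = [sinθ1,sinθ2,sinθ3,sinθ4,sinθ5,sinθ1,sinθ2,sinθ3,sinθ4,sinθ5]

So in q_embed, dimension j is combined with dimension j+5 under the same angle θj: for example, the first output is q1·cosθ1 - q6·sinθ1 and the sixth is q6·cosθ1 + q1·sinθ1. The rotation pairs are therefore (q_j, q_{j+d/2}) rather than the adjacent pairs (q_{2j-1}, q_{2j}) used in the paper, which is the "different permutation" mentioned in the source comment.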
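The pairing described in the note can be checked numerically. This sketch assumes a head size d = 10 as in the note and an arbitrary position m = 3, and compares q * cos + rotate_half(q) * sin against an explicit 2D rotation of each pair (q_j, q_{j+d/2}) by the angle m * theta_j:

```python
import torch


def rotate_half(x):
    # Same as the source: [x1, x2] -> [-x2, x1]
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


dim, base, m = 10, 10000, 3  # head size 10 as in the note, example position m = 3
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # theta_1..theta_5
angles = m * inv_freq                      # m * theta_j
emb = torch.cat((angles, angles), dim=-1)  # one row of the cache, at position m

torch.manual_seed(0)
q = torch.randn(dim)
q_embed = q * emb.cos() + rotate_half(q) * emb.sin()  # formula from apply_rotary_pos_emb

# Explicit 2D rotation of each pair (q_j, q_{j+d/2}) by the angle m * theta_j
half = dim // 2
x, y = q[:half], q[half:]
c, s = angles.cos(), angles.sin()
rotated = torch.cat((x * c - y * s, x * s + y * c), dim=-1)

print(torch.allclose(q_embed, rotated))          # True
print(torch.allclose(q_embed.norm(), q.norm()))  # True: a rotation preserves length
```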


Updated: 2024-07-03