当前位置:AIGC资讯 > AIGC > 正文

【HuggingFace Transformers】LlamaModel源码解析

LlamaModel源码解析

1. LlamaModel 介绍 2. LlamaModel类 源码解析 3. 4维因果注意力掩码生成

1. LlamaModel 介绍

LlamaModel 是一个基于 Transformer 架构的解码器模型,用于自然语言处理任务。它是 Meta 的 LLaMA (Large Language Model Meta AI) 系列的一部分,设计用于生成任务和自回归文本生成。它通过解码器层、位置编码和归一化层来处理输入序列,并提供了对缓存和注意力机制的支持。它在大规模自然语言生成任务中表现出色,并能够处理复杂的序列依赖关系。其结构如下:

2. LlamaModel类 源码解析

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:36
"""PyTorch LLaMA model."""

import torch

from typing import List, Optional, Tuple, Union
from torch import nn
from transformers import LlamaPreTrainedModel, LlamaConfig, Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, LLAMA_START_DOCSTRING, LLAMA_INPUTS_DOCSTRING
from transformers.utils import logging, add_start_docstrings, add_start_docstrings_to_model_forward

logger = logging.get_logger(__name__)


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id  # 设置 padding token 的索引
        self.vocab_size = config.vocab_size  # 设置词汇表的大小

        # 1. 定义嵌入层:将输入的 token 转换为隐状态向量。它的大小为 vocab_size x hidden_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 2. 定义解码层:使用 nn.ModuleList 定义了一系列的 LlamaDecoderLayer,解码层的数量由 config.num_hidden_layers 决定。
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 3. 定义规范化层:使用 LlamaRMSNorm 进行层归一化处理
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 4. 定义旋转嵌入:使用 LlamaRotaryEmbedding 实现旋转嵌入,用于改进注意力机制中的位置编码
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        # 梯度检查点:用于在训练过程中节省内存的功能,默认为 False。
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()  # 初始化权重并进行最终处理

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # 输入的 token ID
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码
        position_ids: Optional[torch.LongTensor] = None,  # 位置 ID
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,  # 缓存的 key-value 对
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 输入嵌入
        use_cache: Optional[bool] = None,  # 是否使用缓存
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态
        return_dict: Optional[bool] = None,  # 是否返回字典类型的输出
        cache_position: Optional[torch.LongTensor] = None,  # 缓存位置
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # -----------------------------1. 初始化一系列输入变量,用于 decoder_layer 的前向传播计算-------------------------------
        # 初始化 output_attentions / output_hidden_states / use_cache / return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 输入验证:确保 input_ids 和 inputs_embeds 不能同时被指定,但必须指定其中之一。
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        # 如果训练过程中使用梯度检查点且使用缓存,则发出警告,并禁用缓存
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # 嵌入计算:根据 input_ids 计算 inputs_embeds,如果已经提供 inputs_embeds,则使用该值。
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # 如果使用旧的缓存格式,将其转换为新的格式
        return_legacy_cache = False
        if (
            use_cache and not isinstance(past_key_values, Cache) and not self.training
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
            )

        # 如果没有指定缓存位置,则根据已处理的 token 数量设置缓存位置
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 如果没有指定位置 ID,则使用缓存位置作为位置 ID
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 更新因果掩码,用于确保解码器只看见当前时间步之前的 token
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        # 位置编码:生成位置编码,用于结合输入嵌入进行旋转嵌入
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # -----------------------------2. 初始化一系列输出变量,用于保存 decoder_layer 前向传播计算的输出结果-------------------------------
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # -----------------------------3. 依次通过每一层,执行 decoder_layer 前向传播计算,同时更新相应的变量值-------------------------------
        # 遍历每一层解码器层,并将输入和注意力掩码传递给每一层
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 如果使用梯度检查点,则使用特殊方法处理解码器层
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                # 否则,直接调用解码器层的前向传播方法
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            # 更新隐藏状态
            hidden_states = layer_outputs[0]

            # 如果使用缓存,则更新缓存
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            # 如果输出注意力权重,则将其添加到所有注意力权重的列表中
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # -----------------------------4. 获取相应的输出变量,根据条件进行处理后返回结果-------------------------------
        # 最后一层的隐藏状态经过归一化处理
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        # 如果输出隐藏状态,则将最终隐藏状态添加到元组中
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 根据是否使用缓存,返回下一个缓存
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        # 输出处理
        # 如果不返回字典,则返回元组类型的输出
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        # 返回 BaseModelOutputWithPast 类型的字典,包括最后的隐藏状态、缓存、隐藏状态和注意力权重
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # 更新因果掩码方法,确保模型只能看到当前时间步之前的 token
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # TODO: 自 torch==2.2.0 以来,传递给模型的 `attention_mask` 在生成过程中是 2D 的,并且长度是动态的,即使使用了静态 KV 缓存。
        # 这会导致 torch.compile 在每个解码步骤中重新捕获 cudagraphs,因为形状是动态的,这是非常慢的。
        # 一种解决方法是使用 `@torch.compiler.disable`,但这会阻止使用 `fullgraph=True`。
        # 更多背景信息可以参考 https://github.com/huggingface/transformers/pull/29114

        # --------------------------flash_attention_2 注意力和 sdpa 注意力的配置判断-------------------------------
        # 如果配置的注意力实现是 "flash_attention_2"
        if self.config._attn_implementation == "flash_attention_2":
            # 如果提供了 attention_mask 且其中包含 0.0,则直接返回 attention_mask;否则返回 None,表示不需要额外处理
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # 对于 SDPA(Scaled Dot-Product Attention),我们将依赖它的 `is_causal` 参数而不是 `attn_mask` 参数,以便分派到 Flash Attention 2 实现。这种特性与静态缓存不兼容,因为 SDPA 无法推断出注意力掩码。
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0  # 获取已经看到的 token 数量
        using_static_cache = isinstance(past_key_values, StaticCache)  # 检查是否使用静态缓存

        # 当 output_attentions 为 True 时,SDPA 实现的前向传播方法会调用 eager(迫切)实现的前向传播方法
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            # 检查是否可以忽略 SDPA 的因果掩码
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        # -------------------------初始化一系列输入变量,用于 4d_causal_attention_mask 的计算--------------------------
        dtype, device = input_tensor.dtype, input_tensor.device  # 获取输入张量的 dtype 和设备信息
        min_dtype = torch.finfo(dtype).min  # 获取 dtype 的最小值,用于填充掩码
        sequence_length = input_tensor.shape[1]  # 获取序列长度

        # 如果使用静态缓存,target_length为缓存中已看到的最大长度
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            # 否则target_length为注意力掩码的最后一个维度长度,或者已看到的 token 数量加上当前序列长度再加 1
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        # 提供的 `attention_mask` 是 2D 的,我们在这里生成一个因果掩码(4D 的)。
        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        # --------------------输出结果 causal_mask 的进一步操作(可选)----------------------
        # 如果配置的注意力实现是 "sdpa",且 `attention_mask` 存在,并且设备类型为 CUDA 且不输出注意力权重
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # 在因果掩码中完全掩盖的行中,使所有 token 可见,例如使用左填充时的相关第一行。这是为了适应 F.scaled_dot_product_attention 的内存高效路径。详情请参考:https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        # ---------------------返回最终的因果掩码------------------------------------------
        return causal_mask

3. 4维因果注意力掩码生成

_prepare_4d_causal_attention_mask_with_cache_position 函数用于生成一个4维的因果注意力掩码(causal attention mask),适用于生成任务中的自回归解码器。这个掩码有助于确保模型在生成序列时仅能基于当前和之前的 token,而不查看未来的 token。以下是对该函数的源码解释:

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:36

# 生成4D的因果注意力掩码方法
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to plcae the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    # 1. 检查掩码维度:如果输入的 attention_mask 是4维的,直接使用它作为 causal_mask,因为它已经是所需的形式。
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        # 2. 生成默认4D掩码:创建一个全0的2D掩码,形状为 (sequence_length, target_length),并用 min_dtype 填充。
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        # 如果 sequence_length 不等于1,将掩码的上三角部分设置为 min_dtype,以创建因果掩码(上三角矩阵,确保每个位置只能关注自己及之前的位置)。
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # 3. 调整掩码以考虑缓存位置:根据 cache_position 计算掩码的位置,causal_mask 的值将根据 cache_position 进行调整,以便正确处理缓存。
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        # 4. 扩展掩码以适应批处理:将掩码扩展到4维,并适应批处理大小 (batch_size),最终的形状为 (batch_size, 1, sequence_length, target_length)。
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        # 5. 融合外部注意力掩码:如果提供了外部的 attention_mask,则将其与生成的 causal_mask 结合。通过掩码位置设置正确的填充,以确保只关注有效位置。
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask

总结

更新时间 2024-09-21