当前位置:AIGC资讯 > AIGC > 正文

Llama开源代码详细解读(2)

FlashAttention

if is_flash_attn_available(): # 检查flashattention的可用性
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

FlashAttention是Tranformer模型中用于改进注意力机制的技术,主要目的是减少计算复杂度和内存占用。

flash_attn_func用于标准的flashattention计算。 flash_attn_varlen_func用于处理变长序列(长度未能确定)的flashattention计算。 index_first_axis用于处理第一个索引轴。 pad_input将数据进行填充处理,从而确定长度。 unpad_input将填充后的输入还原为原始形态。

Logging模块

logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"

创建了名为logger的日志记录器对象,__name__用于保存模块的名称,确保每个模块都有自己的日志记录器。
_CONFIG_FOR_DOC前面带有下划线,因此可以看出其代表一个模块的内部变量。

get_unpad_data模块

def _get_unpad_data(padding_mask):
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

该模块的作用是padding_mask提取非填充的数据,分为以下几步:

seqlens_in_batch计算量每个张量的有效长度,sum()函数计算每个张量的有效长度。dim等于-1意味着以最按照最后一个维度进行求和,如果是二维,就可以理解为对跨列操作,即计算了每一行非填充元素的个数。 indices获取了非填充元素的索引。flatten()函数将张量展开成一维,as_tuple为flase意味着返回不是元组形式而是二维矩阵形式,由于返回的是二维矩阵,因此我们需要flatten()再次展平成一维。
——为什么不返回元组呢?
如果返回元组,那么返回的格式是包含一个一维张量的元组,然后还需要从元组中取出这个一维张量,类似:
torch.nonzero(padding_mask.flatten(), as_tuple=True)[0]

这样比较麻烦,不如直接返回二维数组再展平。
3. max_seqlen_in_batch获取了在seqlens_in_batch中的最大值并返回(即长度最长的那一个),然后 item()函数的作用是将一个元素的张量转换为python对应的标量,即一个数。
4. cu_seqlens计算累计长度并进行填充。cumsum()函数用于计算指定维度的累计和,(1,0)意味着只在左边添加一个元素,右边不添加。F.pad()是为张量进行填充的函数。这对于处理变长序列非常有用,因为即获得了每个序列的开始索引,容易确定起始和结束位置。
5. 最终返回的包括:非零元素的索引,左边填充过了的累计长度,最长序列的长度。
从而,达到了提取非填充数据的目的。

_make_causal_mask模块

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

该模块用于生成因果掩码,通常用于双向自注意力机制。具体来说,该模块保证在计算注意力时,只能看到当前时间步之前的信息,而看不到未来的,来保持因果关系。

输入参数

input_ids_shape: torch.Size:输入张量的形状,通常为(batch_size, target_length)。
dtype: torch.dtype:用于生成掩码的张量类型。
device: torch.device:指定设备是GPU还是CPU。
past_key_values_length: int = 0:过去的键值对长度,用于增量计算。

步骤

获取批大小(bsz)和目标长度(tgt_len)。 创建初始掩码:形状为(tgt_len,tgt_len),所有的dtpye设置为最小,通常为负无穷。torch.full()函数创建一个均为min的矩阵 设置掩码条件:mask_cond生成一个mask最后一个维度大小-1长度的序列,并放置在指定设备 mask.masked_fill:将下三角矩阵设置为0。这里用到了pytorch的广播机制,将一个行为1和列为1的向量扩充进行比较,从而将下三角都变为了0。 将掩码转换为指定的数据类型。 如果past_key_values_length>0,那么就在最后一个维度拼接上一个(tgt_len, past_key_values_length)的张量,这是为了在处理增量计算时,能够考虑过去的键值对。
其中,zeros()函数创建了(tgt_len,past_key_value_length)的全零矩阵,用cat()在mask前添加了一个全0块。 最终将遮罩的维度扩展为四维形状(bsz, 1, tgt_len, tgt_len + past_key_values_length),并返回。

总结

### FlashAttention 技术概述
FlashAttention 作为一种在 Transformer 模型中用于改进注意力机制的技术,其核心优势在于显著降低计算复杂性和内存占用,从而提高处理大规模数据或多任务时的效率和性能。该技术通过特定的 `flash_attn_func` 和 `flash_attn_varlen_func` 函数实现标准与变长序列的注意力计算,同时还配合了 `index_first_axis`、`pad_input` 和 `unpad_input` 等工具函数,以处理数据预处理和还原的需求。
### Logging 模块与配置
文章中提到,通过创建名为 `logger` 的日志记录器对象,并使用 `__name__` 来确保每个模块都能拥有独立的日志记录器,便于追踪和调试。同时,内部变量 `_CONFIG_FOR_DOC` 被引入以指定文档或配置的相关信息。
### 数据处理模块:get_unpad_data
该模块关键在于处理和提取非填充数据,主要涉及以下步骤:
1. **计算有效长度**:利用 `sum(dim=-1)` 方法计算每个批量内元素的非填充长度。
2. **获取非填充索引**:通过 `torch.nonzero()` 和 `flatten()` 确定所有非填充元素的索引。
3. **找到最长序列**:`max().item()` 用于确定批次中最长的序列长度。
4. **计算累计长度**:`torch.cumsum()` 联合 `F.pad()` 为处理变长序列提供每个子序列的开始与结束位置。
上述过程最终返回非零元素的索引、进行了填充的累计长度以及最长序列的长度,为进一步的数据处理与注意力计算打下基础。
### 掩码生成模块:_make_causal_mask
此模块旨在生成因果掩码,用于双向自注意力机制中,确保模型在注意力计算时仅能访问到当前时间步及其之前的信息,防止了未来信息的泄露。具体步骤包括:
- **初始化掩码**:创建一个全为最小值(通常为负无穷)的掩码矩阵,形状与目标序列长度一致。
- **设置掩码下三角**:利用条件索引和 `torch.masked_fill_` 将掩码矩阵的下三角设置为0,以允许对当前及之前信息的可见性。
- **考虑历史数据**:如存在过去的键值对(`past_key_values_length > 0`),则在掩码的最后一维前拼接一个与过去键值对长度相等的全零矩阵。
- **维度扩展**:最终,根据批次大小和目标长度扩展掩码的维度,以满足注意力机制的输入需求。
通过上述步骤,该模块灵活地构建了适用于自注意力机制的掩码,支持了模型在数据处理中的因果关系需求。

更新时间 2024-09-13