当前位置:AIGC资讯 > AIGC > 正文

从AI推理性能优化角度看LLaMA的模型结构和源码

本篇文章讲讲LLaMA的结构,已经有很多文章已经对LLaMA在一些结构上任务表现上做了一些解析,本文主要从优化的角度、实现kernel的角度解析一下LLaMA,读者事先对transformer的结构有基本认识最好。本文首发于我的公众号“AI不止算法”,文章链接在此

LLaMA简单介绍

几个月前,FB开源了LLAMA,LLAMA1包括三个参数量的模型7B、13B、65B, 证明了完全可以通过公开数据集来训练最先进的模型,而无需使用专有和不可获取的数据集,同时LLaMA-13B 在大多数benchmark优于 GPT-3,尽管大小只有后者的1/10。在更大规模上,LLaMA-65B 参数模型也与可以与Chinchilla或PaLM-540B相竞争,这是之前bloom、OPT等没有做到的。本文不谈LLaMA的预训练数据多么多么怎么样,也不谈LLaMA在各个任务上的表现如何,重点从性能优化的角度谈谈LLaMA的模型结构。

模型结构

LLaMA主体结构依然是transformer组成,和其它LLM不同的是:

使用RMSNorm(即Root Mean square Layer Normalization)对每个Transformer子层的input进行Pre Norm 使用激活函数SwiGLU 使用RoPE进行相对位置编码 使用了AdamW优化器,并使用cosine learning rate schedule (AdamW和Adam的区别我不是特别清楚,先放着不讲)

RMSNorm为layerNorm的变体,在分子分母都省去了Mean,同时少了beta参数,虽然不用再计算variance了,但我觉得Welford依然是Normlization类算子性能的最优解

    # RMSNorm
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps # ε
        self.weight = nn.Parameter(torch.ones(dim)) 
    def _norm(self, x):
        # RMSNorm
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

激活函数SwiGLU整合了Swish和GLU这两个函数,网上的文章对这一块讲的似懂非懂,不如直接看CUDA源码,我去翻了一下SwiGLU的实现,得出SwiGLU可以理解为SiLU和mul的fused kernel,前者为x * sigmoid(x),本质上来看依然是一个element wise kernel


对于RoPE,这是一个新鲜的玩意,我们要做的就是实现这样一个rotary_embedding kernel , 它作用与QK矩阵上,在QK的batch GEMM之前,采用绝对位置编码来达到相对位置编码的效果,绝对位置编码的优点是计算简单高效,缺点是一般效果不如相对位置编码。相对位置编码的优点是效果较好,缺点是计算效率不如绝对位置编码。在相对位置编码中,注意力权重的结果仅仅和参与注意力计算的token向量的相对位置有关,不和绝对位置直接关联。这符合NLP领域在序列长度方向上具有平移不变性的特点,所以相对位置编码一般效果会优于绝对位置编码。

RoPE公式推导我个人有点看不下去,直接看公式吧,将旋转位置编码过程由GEMM简化成两次向量的哈达玛积求和,这也是一个element wise kernel,要把x给索引好,送给cos和sin相乘


python源代码,还是比较straightforward

LLaMA Attention

和普遍的attention结构没有太大区别,除了把上面的那些新增结构RMS norm,RoPE给添加到各个transformer layer开头和QK之后。想谈论的是Tensor Parallel 版本的attention,这里对qkv的weight采用了列切分,output linear采用了行切分,这循序了NV megatron的张量并行切分思想,有助于最小化多卡通讯开销。

LLaMA MLP

同理对于MLP,也采用了linear的列切分行切分版本,同时把SwiGLU给加了进去

LLaMA TransformerLayer

对于每个layer,把attention和MLP叠起来就完事

Llama generate

transformerlayer出来后的经过LMhead(其实就是个linear)+ softmax得到probs,然后就开始sample,可以topP,可以贪心,可以beam search,主要就看怎么设计了,在这份代码里,采用了topP或贪心,最后再detokenize,吐出token到构造的buffer tokens = torch.full((bsz,total_len), self.tokenizer.pad_id).cuda().long()

class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        

        for k, t in enumerate(prompt_tokens):

            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        # start generate
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                # sample by top P
                next_token = sample_top_p(probs, top_p)
            else:
                # greedy search
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos
        # detokenize
        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded

# sample the one which is the cum prob < p

def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

    # extract a sample

    next_token = torch.multinomial(probs_sort, num_samples=1)

    # find next_token's id

        next_token = torch.gather(probs_idx, -1, next_token)

    return next_token

LLaMA 2

LLaMA2在1的基础上又做了一些改进,在模型结构上引入了GQA来降低KV cache的显存占用,以此来增大batch size,获得更高的吞吐量,后面单独开篇文章讲讲MQA和GQA

另外

1、attention mask的构造上面也有一些要点:

_make_causal_mask用于构造下三角这种mask结构以实现语言模型的单向注意力。

_expand_mask用于将mask信息展开成和attention矩阵相同的张量结构。

2、对优化器AdamW的具体实现不是很了解,后续补补课再来聊聊

3、LLM的inference本身并不像general的inference engine或者framework那么有太大的复杂度,主要还是实现那几个kernel,整体我个人感觉在性能优化的角度,还是不会带来太大的额外工作量,多数kernel都可以reuse已有实现

更新时间 2024-03-04