当前位置:AIGC资讯 > AIGC > 正文

传知代码-Llama 2:开放基础和微调聊天模型以及法律判决数据集分类实战(论文复现)

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

本文概述

本文首先会介绍一下Llama2大模型,然后会使用一个公开的中文法律判决数据集(部分)进行Llama2提示学习的分类实战。

论文主要内容

这篇文章介绍了Llama 2,这是由Meta AI团队开发和发布的一系列预训练和微调的大型语言模型(LLMs),参数规模从7亿到70亿不等。Llama 2-Chat是为对话用例优化的微调模型,它在大多数测试基准上的表现超过了现有的开源聊天模型,并且在人类评估的有用性和安全性方面,可能是封闭源模型的合适替代品。

文章详细描述了Llama 2-Chat的微调和安全性改进方法,以便社区能够在这项工作的基础上进行构建,并为LLMs的负责任发展做出贡献。此外,文章还分享了在开发Llama 2和Llama 2-Chat过程中观察到的新现象,例如工具使用的出现和知识的时空组织。

文章的主要内容包括:

介绍:大型语言模型(LLMs)作为AI助手的潜力和能力。 预训练:Llama 2模型的预训练方法,包括数据、训练细节和评估。 微调:包括监督式微调(SFT)和强化学习与人类反馈(RLHF)的方法。 安全性:在预训练和微调阶段采取的安全性措施,以及通过红队测试和迭代评估提高模型安全性的方法。 讨论:在开发过程中的学习和观察,包括模型的局限性和伦理考虑。 相关工作:对大型语言模型领域的相关研究进行回顾。 结论:Llama 2模型的发布对社区和负责任的AI发展的潜在影响。

文章还包含了一些附录部分,提供了作者贡献、预训练和微调的额外细节、安全性的额外信息、数据注释、数据污染问题和模型卡片等。

论文主要贡献

文章的主要贡献可以总结为以下几点:

Llama 2模型的开发与发布:开发了一系列从7亿到70亿参数规模的预训练和微调大型语言模型(LLMs),特别是Llama 2-Chat,这是一个为对话用例优化的微调模型。

性能优化:Llama 2-Chat在多个基准测试中表现优于现有的开源聊天模型,并且在人类评估的有用性和安全性方面与一些封闭源模型相当。

微调方法的详细描述:文章提供了对Llama 2-Chat微调方法的详细描述,包括监督式微调(SFT)和强化学习与人类反馈(RLHF)。

安全性改进:介绍了为提高Llama 2-Chat模型的安全性所做的努力,包括安全特定的数据注释、微调和红队测试。

社区贡献:通过开放Llama 2模型,鼓励社区在这项工作的基础上进行构建,为负责任的LLMs发展做出贡献。

新观察和发现:分享了在开发Llama 2和Llama 2-Chat过程中的新观察,例如工具使用的出现和知识的时空组织。

负责任的发布策略:讨论了如何安全地发布这些模型,以及如何通过负责任的使用指南和代码示例来促进安全部署。

环境影响考量:评估了预训练过程中的碳足迹,并讨论了通过开放模型来减少其他公司重复预训练成本的潜在环境效益。

这些贡献展示了Meta AI在推动大型语言模型技术发展、提高模型性能和安全性、以及促进AI领域负责任创新方面的努力。

Llama 2-chat的训练流程图:

这个过程从使用公开的在线资源对Llama进行预训练开始。然后通过应用监督微调创建了Llama 2-Chat的初始版本。随后,使用带人反馈的强化学习(RLHF)方法,特别是通过拒绝采样和近端策略优化(PPO),对模型进行迭代优化。在整个RLHF阶段,与模型增强并行的迭代奖励建模数据的积累对于确保奖励模型保持在分布范围内至关重要。

Llama2技术细节

预训练

为了创建新的Llama 2模型家族,文章使用了优化的自回归变压器,但进行了一些更改以提高性能。具体来说,文章执行了更稳健的数据清理,更新了数据混合,在40%以上的tokens上进行了训练,将上下文长度增加了一倍,并使用分组查询注意力(GQA)来提高大型模型的推理可扩展性。表1比较了新的Llama 2模型与Llama 1模型的属性。

预训练数据

相较于llama的预训练1.4T个tokens数据,llama2使用了2Ttokens数据进行训练。

这是预训练数据语种分布。可以看到英文占绝大部分,所以原始的模型对英语的效果最好。

模型架构

Llama 2采用了Llama 1中的大部分预训练设置和模型架构:
文章使用标准transformer架构
使用RMSNorm应用预归一化
使用SwiGLU激活函数
旋转位置嵌入(RoPE)

Llama 2与Llama 1的主要架构差异包括增加的上下文长度和分组查询注意力(GQA)

GQA如下:

中间就是GQA,可以看到,Q进行分组,每组共享相同的V和K,这样能够节省显存而又不怎么影响模型效果。

超参数设置

Llama 2使用AdamW优化器进行训练,其中β1=0.9,β2=0.95,eps=10−5。
使用余弦学习率计划,预热2000步,并将最终学习率衰减到峰值学习率的10%。
使用0.1的权重衰减和1.0的梯度剪裁。
图5(a)显示了具有这些超参数的Llama 2的训练损失。

分词器

Llama 2使用与Llama 1相同的标记器;它采用了字节对编码(BPE)算法,使用了来自PensionePiece的实现。与Llama 1一样,将所有数字拆分为单个数字,并使用字节分解未知的UTF-8字符。总词汇大小为32k个标记

因此,Llama 1和Llama 2的一些区别如下:

Llama2预训练模型的评估结果

微调

指令微调

Quality Is All You Need. SFT数据集的质量很重要,万级别的高质量效果就能达到很好的效果。使用供应商精标了27,540条(人工撰写prompt 和 answer,包括 helpfulness 和 safety 两大类 ),发现效果比几百万公开的还要好。

使用余弦学习率,初始学习速率为2e−5,权重衰减为0.1,批量大小为64,序列长度为4096个tokens。

对于微调过程,每个样本都包含一个提示和一个答案。为了确保模型序列长度正确填充,将训练集中的所有提示和答案连接起来。使用一个特殊的标记来分隔提示段和应答段。
使用自回归目标,并从用户提示中消除tokens的损失,因此只对回答tokens进行反向传播。最后,使用两个epoch微调模型。

人的反馈强化学习(RLHF)

RLHF是一种模型训练过程,应用于微调的语言模型,以进一步使模型行为与人类偏好和指令遵循相一致。文章收集的数据代表了经验上9个样本的人类偏好,通过这些数据,人类注释者可以选择他们喜欢的两个模型输出中的哪一个。这种人工反馈随后被用于训练奖励模型,该模型学习人工注释者偏好的模式,然后可以自动进行偏好决策。

人类偏好数据

在表6中,文章报告了随着时间的推移收集的奖励建模数据的统计数据,并将其与多个开源偏好数据集进行了比较,这些数据集包括Anthropic Helpful和Harmless、OpenAI Summary、OpenAI-WebGPT、StackExchange、Stanford Human Preferences和Synthetic GPT-J。作者收集了一个大型数据集,其中包含超过100万个基于人类应用我们指定的指南进行的二进制比较,称之为元奖励建模数据。请注意,提示和答案中的标记数量因文本域而异。摘要和在线论坛数据的提示通常较长,而对话式的提示通常较短。与现有的开源数据集相比,这个偏好数据的特点是对话次数更多,平均时间更长。

奖励模型

文章训练了两个独立的奖励模型,一个针对帮助性(称为helpfulness RM)进行了优化,另一个针对安全性(safety RM)进行了优化。并从预训练的聊天模型检查点初始化奖励模型,因为它确保两个模型都受益于预训练中获得的知识。简而言之,奖励模式“知道”聊天模型知道的是什么。

迭代微调

近端策略优化(PPO)、拒绝采样微调。

这两种强化学习算法的主要区别在于:
•广度-在拒绝抽样中,模型为给定提示探索K个样本,而PPO只进行一次生成。

•深度-在PPO中,在步骤t的训练期间,样本是在前一步梯度更新后从t−1更新的模型策略的函数。在拒绝采样微调中,在应用类似于SFT的微调之前,我们对给定模型初始策略的所有输出进行采样以收集新的数据集。然而,由于应用了迭代模型更新,两种强化学习算法之间的根本差异就不那么明显了

文章还有很多关于模型结果、安全性等方面的详细分析,这里就不再赘述了

中文LLaMA2模型Atom-7B-Chat实战

从Llama2的训练数据就能看得出,这个大模型对中文的支持并不好,因为模型在大量的英文数据集上进行预训练,对英文的支持较好。

因此,我们需要使用Llama2在大量中文数据集预训练微调后的模型上使用才会有比较好的中文效果

Atom就是在Llama2的基础上,采用大规模的中文数据进行持续预训练,包含百科、书籍、博客、新闻、公告、小说、金融数据、法律数据、医疗数据、代码数据、专业论文数据、中文自然语言处理竞赛数据集等

我们可以使用这个模型来做一个简单的对法律案件识别的任务。

数据集介绍

使用的是一个法律判决数据集,因为原来的数据集数量太多,因此我就抽取了其中的三类犯罪类型进行实验。

data_fact是对案件的描述
data_accusation_clear是案件涉及的犯罪名称
data_accusation_labels是犯罪名称的数字标签(0,1,2)

data_fact的一个例子如下:

  七台河市新兴区人民检察院指控:1.2017年3月20日下午,被告人司某某流窜至新兴区北山技工校后侧的被害人李某某家中,趁户内无人,××人民币2000.00元,所得赃款被其挥霍。2.2017年3月23日左右,被告人司某流窜至新兴区北山技工校附近被害人赵某某家中,趁户内无人,××冰柜内的两个羊腿、康师傅牌方便面12袋(鉴定价格人民币208.00元),所得赃物被其食用。3.2017年3月31日,被告人司某某再次流窜至新兴区北山技工校附近赵某某家中,趁户内无人,××康师傅方便面12袋(鉴定价格人民币30.00元)。所得赃物被其食用。4.2015年7-8月份的一天,被告人司某某流窜至新兴区北山技工校附近被害人黄某某家中,趁户内无人,××人民币800.00元。所得赃款被其挥霍。5.2016年7-8月份的一天,被告人司某某流窜至新兴区东升村被害人陈某某家中,趁户内无人,××大米20斤、鸡蛋7斤(鉴定价格人民币81.00元)。所得赃物被其丢失。6.2016年11月份下旬的一天,被告人司某某流窜至新兴区二O四中转站附近被害人刘某某家中,趁户内无人,××人民币300.00元。所得赃款被其挥霍。经侦查,被告人司某某于2017年4月27日被公安机关在七台河市新兴区抓获。,盗窃,0,七台河市新兴区人民检察院指控:1.2017年3月20日下午,被告人司某某流窜至新兴区北山技工校后侧的被害人李某某家中,趁户内无人,××人民币2000.00元,所得赃款被其挥霍。2.2017年3月23日左右,被告人司某流窜至新兴区北山技工校附近被害人赵某某家中,趁户内无人,××冰柜内的两个羊腿、康师傅牌方便面12袋(鉴定价格人民币208.00元),所得赃物被其食用。3.2017年3月31日,被告人司某某再次流窜至新兴区北山技工校附近赵某某家中,趁户内无人,××康师傅方便面12袋(鉴定价格人民币30.00元)。所得赃物被其食用。4.2015年7-8月份的一天,被告人司某某流窜至新兴区北山技工校附近被害人黄某某家中,趁户内无人,××人民币800.00元。所得赃款被其挥霍。5.2016年7-8月份的一天,被告人司某某流窜至新兴区东升村被害人陈某某家中,趁户内无人,××大米20斤、鸡蛋7斤(鉴定价格人民币81.00元)。所得赃物被其丢失。6.2016年11月份下旬的一天,被告人司某某流窜至新兴区二O四中转站附近被害人刘某某家中,趁户内无人,××人民币300.00元。所得赃款被其挥霍。经侦查,被告人司某某于2017年4月27日被公安机关在七台河市新兴区抓获。

其对应的data_accusation_clear如下:

盗窃

表明上述案件中,主要的罪名为盗窃。

其对应的data_accusation_labels如下:

0

表明盗窃这个标签,使用0来表示。

使用到的数据集罪名类型数量分布如下:
盗窃 4947
故意伤害 4407
抢劫 4071

因此可以看到,这是一个三分类的问题

提示模板

考虑到设备的原因,我们只进行prompt提示学习而不对模型本身进行微调,即在零样本的情况下,测试模型对法律判决的能力。

提示模板设置如下:

prompt_template = ['<s>Human: 请你明确指出以下案例是盗窃、故意伤害、抢劫哪一项罪名?请注意,你的答案只能是“盗窃、故意伤害、抢劫”三者其中之一并且答案不能超过4个字。',
                   '\n</s><s>Assistant: ']

这里对模型的评判也比较简单,当模型的结果中包含案件分类的类别时,就判断模型的预测结果为正确,否则就是错误。

实验结果

这是使用20条数据进行测试的结果,可以看到,效果还是不错的,但是提示的约束并不是完全生效,大部分情况都生效,但是也有不符合提示的情况,因此,之后可以尝试更好的提示模板或者使用提示微调等,提升模型的表现效果。

当使用全数据时,当输入数据的长度为512时,三分类准确率为:54.04%
当输入长度为1024时,三分类准确率为:62.32%
可以看到,当给模型的数据越多时,模型的表现确实会提升。

代码运行

安装必要的包

pip install -r requirements.txt  --default-timeout=120 -i https://pypi.tuna.tsinghua.edu.cn/simple

运行代码

python LLaMA-CLF.py

源码下载

更新时间 2024-07-07