0. 简介
随着chatgpt的爆火,最近也有很多大模型在不断地出现,比如说Bloom系列以及以LLAMA为基础的ziya和baichuan。这些模型相较于chatglm来说,更加具有发展前景,因为其是完全可商用,并可以不断迭代更新的。最近作者在跟着hiyouga大佬的LLaMA-Efficient-Tuning进行学习,相较于其他的项目来说,该项目是非常适合跟着学习并入门的。
1. 什么是SFT
SFT(Scalable Fine-Tuning)是一种用于自然语言处理的技术,它通过对预训练的语言模型进行微调,使其适应特定任务。在大模型SFT中,使用的是大型的预训练语言模型,例如LLAMA、GPT等,这些模型具有数十亿甚至数百亿个参数,可以处理大量的文本数据。
SFT的主要思想是在一个大型的预训练模型的基础上,针对特定的任务对模型进行微调。在微调过程中,模型会根据任务的特点调整模型的参数和结构,以提高模型在该任务上的表现。在微调过程中,可以使用不同的技术,例如数据增强、正则化、优化算法等。
SFT的优点是可以快速地针对不同的任务进行微调,而无需重新训练整个模型。此外,由于使用的是大型的预训练模型,可以利用海量的文本数据进行训练,从而获得更好的性能。不过,SFT也有一些缺点,例如需要大量的计算资源和时间进行微调,以及可能会出现过拟合等问题。
目前常用的SFT方法有P-Tuning v2、LORA、QLoRA、冻结(Freeze)、全参数(full-parameter)等方法。我们先来看一看在LLaMA-Efficient-Tuning中是如何写SFT的
2. 代码阅读–train_sft.py
下面是sft对应大模型的脚本,主要包括模型和数据的准备,数据集的划分,训练和评估等步骤。
首先,代码导入了一些必要的模块和函数。这包括一些用于数据处理、训练、加载预训练模型和绘制损失图的工具函数。(这部分和pt中一样)
# Prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")# 用于准备各种参数,包括模型参数、数据参数、训练参数和微调参数。
dataset = prepare_data(model_args, data_args)# 用于准备数据集
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")# 用于加载sft微调的模型和分词器。
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")# 用于预处理数据,例如将文本转换为模型可以理解的格式。
data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)# 动态地对数据进行填充,使得每个batch中的数据长度一致。
下面的代码是用于Seq2SeqTrainer的解码参数进行覆盖
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length# 设置训练参数(training_args)中的生成最大长度
training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.eval_num_beams is not None else training_args.generation_num_beams # 设置训练参数中的生成束搜索数(generation_num_beams)
然后,根据是否进行训练,对数据集进行划分。如果进行训练,且开发集的比例大于0,那么数据集会被划分为训练集和开发集;否则,全部数据用于训练。如果不进行训练,那么全部数据用于评估或预测。
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
接着,初始化Seq2SeqPeftTrainer对象,传入微调参数、模型、训练参数、分词器、数据处理器、回调函数和计算度量等参数(都是继承自Seq2SeqTrainer),以及前面划分的数据集。这个我们下一节将会仔细阅读里面的操作