0. 简介
随着chatgpt的爆火,最近也有很多大模型在不断地出现,比如说Bloom系列以及以LLAMA为基础的ziya和baichuan。这些模型相较于chatglm来说,更加具有发展前景,因为其是完全可商用,并可以不断迭代更新的。最近作者在跟着hiyouga大佬的LLaMA-Efficient-Tuning进行学习,相较于其他的项目来说,该项目是非常适合跟着学习并入门的。
1. 二次预训练的目的
最近几年来,大量的研究工作表明,大型语料库上的预训练模型(PTM)可以学习通用的语言表征,这对于下游的 NLP 任务是非常有帮助的,可以避免从零开始训练新模型。而随着算力的发展、深层模型(Transformer)出现以及训练技能的不断提高,PTM 体系结构从浅层发展到了深层。
对于大模型而言,其基本训练流程一般可以分为两个阶段:预训练和微调。预训练阶段是模型学习语言知识的阶段,在这个阶段,模型会从大量的文本数据中学习如何理解和生成文本。这一阶段生成的模型是通用的,可以用于处理各种类型的文本任务。微调阶段是在预训练模型的基础上,针对特定任务进行训练,使模型能更好地完成该任务。
然而,由于预训练模型是在大规模的、通用的语料库上进行训练的,所以其学到的知识可能并不是特定任务所需的全部知识。例如,预训练模型可能没有足够的能力理解医学或法律文本,因为这些领域的专业知识可能在预训练语料库中并不充足。
这就是我们需要进行二次预训练的地方。在二次预训练中,我们将模型在特定领域的语料库上进行再训练,这样模型就能学习到更多这个领域的知识,从而更好地完成特定任务。二次预训练可以看作是预训练和微调之间的一个过渡阶段,它既保留了预训练的广泛性,又添加了针对特定任务的专业性。
除此之外,二次预训练还有一个重要的作用,就是可以有效地利用小规模的、高质量的领域数据(通常仍然是GB级别)。在许多情况下,领域数据是昂贵的、难以获得的,因此我们需要尽可能地利用这些数据。通过在大模型上进行二次预训练,我们可以将领域数据的知识充分地转化为模型的性能提升。
2. 代码阅读–train_pt.py
下面是一个预训练模型的脚本,主要包括模型和数据的准备,数据集的划分,训练和评估等步骤。
首先,代码导入了一些必要的模块和函数。这包括一些用于数据处理、训练、加载预训练模型和绘制损失图的工具函数。
# Prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")# 用于准备各种参数,包括模型参数、数据参数、训练参数和微调参数。
dataset = prepare_data(model_args, data_args)# 用于准备数据集
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt") # 用于加载预训练的模型和分词器。
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")# 用于预处理数据,例如将文本转换为模型可以理解的格式。
data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)# 动态地对数据进行填充,使得每个batch中的数据长度一致。
然后,根据是否进行训练,对数据集进行划分。如果进行训练,且开发集的比例大于0,那么数据集会被划分为训练集和开发集;否则,全部数据用于训练。如果不进行训练,那么全部数据用于评估或预测。
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
接着,初始化PeftTrainer对象,传入微调参数、模型、训练参数、分词器、数据处理器和回调函数等参数,以及前面划分的数据集。这个我们下一节将会仔细阅读里面的操作
trainer = PeftTrainer(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[LogCallback()],
**trainer_kwargs
)
在进行训练后,代码会记录训练的结果,并保存模型和训练结果。如果模型在所有进程中的进程号为0,并且设定了绘制损失图,那么会绘制训练损失和评估损失的图。
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
在进行评估后,代码会计算模型的困惑度(一个常用的语言模型评估指标),并记录评估结果。
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
3. 预训练处理–peft_trainer.py
这段代码定义了两个类,LogCallback和PeftTrainer。LogCallback类用于在训练过程中记录日志,PeftTrainer类是一个自定义的训练器,用于支持参数效率的检查点。
3.1 LogCallback函数
首先我们来先看一下LogCallback这个函数,它是从TrainerCallback类继承的,主要用于记录训练过程中的信息,比如损失、学习率、训练周期、当前的进度百分比和估计剩余时间等。这些信息会被写入到文件"trainer_log.jsonl"中。
首先存在一个__init__
函数,它在类的实例被创建时执行。这里记录了创建时的时间戳作为训练的开始时间。
def __init__(self):
self.start_time = time.time()
然后会进入on_log
方法,它在当训练过程需要记录日志的时候会被调用。这个方法接收四个参数:args
(训练参数),state
(训练的当前状态),control
(用于控制训练流程的对象),以及kwargs
(其他参数)。
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
首先在方法体中,首先检查最后一次日志记录中是否包含"loss"这个键。如果不包含,就直接返回。其目的是为了查看是否有训练的loss输出
if "loss" not in state.log_history[-1]:
return
接着,计算从训练开始到现在的总时间(秒),以及到目前为止每步训练所用的平均时间。
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
根据平均时间和剩余的步数,预计训练的剩余时间。
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
然后,将这些信息,包括当前步数、总步数、损失、奖励、学习率、训练周期、完成的百分比、已用的时间以及预计的剩余时间,保存到一个字典中。
log_dict = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
如果输出目录不存在,创建该目录,然后将上述字典以JSON格式写入到名为"trainer_log.jsonl"的文件中。这个文件的每一行都是一个JSON对象,记录了一次日志事件的信息。
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
f.write(json.dumps(log_dict) + "\n")
3.2 PeftTrainer函数(同chatglm方法)
PeftTrainer类是从Seq2SeqTrainer类继承的,专门用于处理序列到序列的模型。这个类的构造函数接收一个FinetuningArguments对象,该对象包含微调过程的参数。
首先是__init__
函数,它在类的实例被创建时执行。这里首先调用父类的构造函数,并存储微调参数。然后,如果当前进程是主进程(编号为0),并且输出目录中已经存在"log file"文件,那么就删除这个文件。(这个和上面不冲突,因为在主函数里面LogCallback函数是作为一个callback函数返回的)