当前位置:AIGC资讯 > AIGC > 正文

LLama-Factory使用教程

本文是github项目llama-factory的使用教程
注意,最新的llama-factory的github中训练模型中,涉及到本文中的操作全部使用了.yaml配置。
新的.yaml的方式很简洁但不太直观,本质上是一样的。新的readme中的.yaml文件等于下文中的bash指令

PS: 大模型基础和进阶付费课程(自己讲的):《AIGC大模型理论与工业落地实战》-CSDN学院 或者《AIGC大模型理论与工业落地实战》-网易云课堂。感谢支持!

一,数据准备和模型训练

step1-下载项目:

​ 从github中克隆LLaMa-Factory项目到本地

step2-准备数据:

​ 将原始LLaMA-Factory/data/文件夹下的dataset_info.json,增加本地的数据。注意,本地数据只能改成LLama-Factory接受的形式,即本地数据只能支持”promtp/input/output“这种对话的格式,不支持传统的文本分类/实体抽取/关系抽取等等schema数据,如果需要,请想办法改成对话形式的数据。

​ 你需要参考其中的一个文件和它的配置,例如:alpaca_gpt4_data_zh.json,训练和验证数据同样改成这种格式,并在dataset_info.json中新增一个你自己的字典:

{
  "alpaca_en": {
    "file_name": "alpaca_data_en_52k.json",
    "file_sha1": "607f94a7f581341e59685aef32f531095232cf23"
},  
...

"your_train": {
    "file_name": "/path/to/your/train.json",
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output"
    }
  },
...

​ 其中的key,your_train,将在训练/测试的shell命令中使用

step3-模型训练:

​ 数据准备好之后,编写shell脚本训练模型,以mixtral为例根目录下新建run_mixtral.sh

需要改动的主要是:model_name_or_path,dataset,output_dir;和其他可选的改动信息,例如save_steps,per_device_train_batch_size等等。

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path /path/to/your/Mixtral-8x7B-Instruct-v0.1 \
    --dataset my_train \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ./output/mixtral_train \
    --overwrite_output_dir \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 200 \
    --learning_rate 5e-5 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --quantization_bit 4 \
    --fp16
step4-模型融合

​ 模型融合的意义在于合并训练后的lora权重,保持参数和刚从huggingface中下载的一致,以便更加方便地适配一些推理和部署框架

​ 基本流程/原理:将微调之后的lora参数,融合到原始模型参数中,以mixtral为例新建:LLama-Factory/run_mixtral_fusion.sh:

python src/export_model.py \
    --model_name_or_path path_to_huggingface_model \
    --adapter_name_or_path path_to_mixtral_checkpoint \
    --template default \
    --finetuning_type lora \
    --export_dir path_to_your_defined_export_dir \
    --export_size 2 \
    --export_legacy_format False
step5-模型推理

​ 模型推理即模型在新的验证集上的推理和验证过程

​ 指令和训练的基本一致,只是差别几个参数:

        1.增加了do_predict,2.数据集改成一个新的eval数据集

​ LLama-Factory/runs/run_mixtral_predict.sh

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path /path/to/huggingface/Mixtral-8x7B-Instruct-v0.1 \
    --adapter_name_or_path /path/to/mixtral_output/checkpoint-200 \
    --dataset my_eval \
    --template default \
    --finetuning_type lora \
    --output_dir ./output/mixtral_predict \
    --per_device_eval_batch_size 4 \
    --predict_with_generate \
    --quantization_bit 4 \
    --fp16
step6-API接口部署

部署接口的作用是可以让你把接口开放出去给到外部调用

CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora

总结

# LLaMa-Factory 项目使用教程总结
## 引言
本文是GitHub项目**LLaMa-Factory**的使用教程,介绍了如何使用该项目进行数据准备、模型训练、融合、推理以及API接口部署的过程。最新版本的LLaMa-Factory中,训练模型操作全部采用`.yaml`配置文件,这些文件虽然简洁但可能不直观,但本质上与之前的bash指令相同。
## 一、数据准备和模型训练
### Step 1: 下载项目
- 从GitHub克隆LLaMa-Factory项目到本地。
### Step 2: 准备数据
- 在`LLaMA-Factory/data/`文件夹下的`dataset_info.json`文件中增加本地数据,数据格式为`"promtp/input/output"`的对话形式,不支持传统的文本分类/实体抽取等格式。
- 修改或添加本地数据配置文件,如`your_train`,确保路径和列名(`prompt`、`query`、`response`)正确。
### Step 3: 模型训练
- 编写shell脚本(如`run_mixtral.sh`)进行模型训练,主要配置包括模型路径、数据集、输出目录等。
- 示例:使用mixtral模型进行训练,包括设置CUDA、训练参数、学习率、批大小等配置。
### Step 4: 模型融合
- 模型融合是将训练后的LORA权重合并到原始模型中,以适应不同的推理和部署框架。
- 编写融合脚本(如`run_mixtral_fusion.sh`),指定原始模型路径、LORA权重路径、输出路径等。
### Step 5: 模型推理
- 模型推理是在新的数据集上进行验证和推理的过程。
- 修改训练脚本的参数,包括增加`do_predict`、更换数据集为评估数据集。
### Step 6: API接口部署
- API部署允许开放接口给外部调用。
- 编写API部署脚本(如`api_demo.py`),设置模型路径、检查点路径、端口号等配置。
## 注意事项
- 最新版本的LLaMa-Factory使用`.yaml`配置文件来进行模型训练,与之前的bash指令相比更为简洁但可能不够直观。
- 推荐的大模型基础和进阶付费课程包括《AIGC大模型理论与工业落地实战》在CSDN学院和网易云课堂有售。
通过以上步骤,你可以顺利完成LLaMa-Factory项目的数据准备、模型训练、融合、推理以及API接口部署工作。

更新时间 2024-08-01