当前位置:AIGC资讯 > AIGC > 正文

LLaMA-Factory/scripts/length_cdf.py 源码解析

这段代码定义了一个函数 length_cdf,用来计算和打印数据集样本长度的累积分布函数(CDF),并在脚本直接运行时通过 fire 库将该函数暴露为命令行接口。我们逐行解释这段代码:

python

复制

from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_tokenizer
从 llmtuner 模块中导入 get_datasetget_train_args 和 load_tokenizer 函数。

python

复制

def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
定义了一个名为 length_cdf 的函数,接受以下参数: model_name_or_path: 模型的名称或路径(字符串类型)。 dataset: 数据集的名称,默认为 "alpaca_en"(可选)。 dataset_dir: 数据集的目录,默认为 "data"(可选)。 template: 模板名称,默认为 "default"(可选)。 interval: 计算长度分布的区间,默认为 1000(可选)。

python

复制

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
调用 get_train_args 函数获取模型参数、数据参数和训练参数。 传递一个字典作为参数,其中包括: stage: 训练阶段,这里是 "sft"(假设是某种微调)。 model_name_or_path: 模型的名称或路径。 dataset: 数据集名称。 dataset_dir: 数据集目录。 template: 模板名称。 cutoff_len: 截断长度,设置为 1,000,000output_dir: 输出目录,这里是 "dummy_dir"overwrite_cache: 是否覆盖缓存,设置为 True

python

复制

    tokenizer = load_tokenizer(model_args)
调用 load_tokenizer 函数加载分词器,传入 model_args

python

复制

    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
调用 get_dataset 函数获取训练数据集,传入 tokenizermodel_argsdata_args 和 training_args,指定 stage 为 "sft"

python

复制

    total_num = len(trainset)
计算训练数据集的总样本数,存储在 total_num 中。

python

复制

    length_dict = defaultdict(int)
创建一个默认值为 0 的字典 length_dict,用来统计样本长度的分布。   
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1
遍历训练数据集中的每个样本(input_ids)。 计算每个样本长度所在的区间(len(sample) // interval * interval),并在 length_dict 中进行计数。

python

复制

    length_tuples = list(length_dict.items())
    length_tuples.sort()
将 length_dict 转换为一个包含长度区间和计数的元组列表 length_tuples。 对列表进行排序。

python

复制

    count_accu, prob_accu = 0, 0
初始化累积计数 count_accu 和累积概率 prob_accu

python

复制

    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
遍历 length_tuples 中的每个长度区间和计数。 更新累积计数 count_accu 和累积概率 prob_accu

python

复制

        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
打印当前长度区间的累积计数和累积概率。输出格式为:count_accu (prob_accu%) samples have length < length + interval.

python

复制

if __name__ == "__main__":
    fire.Fire(length_cdf)
检查脚本是否作为主程序运行。 如果是,则使用 fire.Fire 将 length_cdf 函数暴露为命令行接口,使得可以通过命令行参数直接调用该函数。

总结

这段代码的主要功能是计算数据集样本长度的累积分布函数(CDF),并打印每个长度区间的累积计数和累积概率。具体步骤包括:

从命令行参数获取模型和数据集的相关参数。 加载分词器和数据集。 统计样本长度的分布。 计算并打印累积分布。

示例用法

假设你将这段代码保存为 length_cdf.py,可以通过以下命令运行并查看结果:

sh

复制

python length_cdf.py --model_name_or_path "path/to/model" --dataset "custom_dataset" --dataset_dir "path/to/data"

这将根据提供的模型和数据集路径,计算并打印数据集中样本长度的累积分布情况。

为什么要用fire库,有什么好处,用和不用有什么区别

GPT-4o

Poe

使用 fire 库的主要目的是简化命令行接口的创建,使得 Python 脚本可以更容易地通过命令行参数进行调用和控制。下面我们详细解释一下 fire 库的好处,以及使用和不使用 fire 库的区别。

fire 库的好处

自动生成命令行接口:

fire 可以自动将 Python 对象(函数、类、模块等)转换为命令行接口。开发者无需手动解析命令行参数,fire 会自动处理这些工作。

简化代码:

使用 fire 只需一行代码就能将函数或类暴露为命令行接口,极大地简化了代码量。

易用性:

fire 提供了一致且直观的方式来调用 Python 脚本,降低了使用命令行参数的复杂性。

灵活性:

fire 支持多种调用方式,包括函数、类和模块,使得开发者可以根据需要灵活选择如何组织代码。

使用和不使用 fire 的区别

使用 fire 库

使用 fire 库的代码示例如下:

python

复制

import fire

def greet(name="World"):
    print(f"Hello, {name}!")

if __name__ == "__main__":
    fire.Fire(greet)
当你运行这个脚本时,可以通过命令行参数直接调用 greet 函数并传递参数:

sh

复制

python script.py --name John
输出:

复制

Hello, John!
不使用 fire 库

如果不使用 fire 库,你需要手动解析命令行参数,例如使用 argparse 库:

import argparse

def greet(name="World"):
    print(f"Hello, {name}!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Greet someone.")
    parser.add_argument("--name", type=str, default="World", help="Name of the person to greet.")
    args = parser.parse_args()
    greet(args.name)
当你运行这个脚本时,同样可以通过命令行参数传递参数:

sh

复制

python script.py --name John
输出:

复制

Hello, John!

对比分析

代码简洁性:

使用 fire 库,代码更加简洁,只需一行代码即可实现命令行接口。 不使用 fire 库,需要更多的代码来设置和解析命令行参数。

功能和灵活性:

fire 提供了自动化和简化的功能,适合快速开发和测试。 argparse 等库提供了更多的自定义选项和复杂功能,适合复杂的命令行接口需求。

结论

如果你需要快速创建一个简单的命令行接口,fire 是一个很好的选择,因为它可以极大地简化开发过程。 如果你需要更多的控制和自定义选项,比如处理复杂的参数解析逻辑,argparse 或其他专门的命令行解析库可能更适合。

综上所述,选择是否使用 fire 取决于你的具体需求和项目的复杂性。

总结

**文章总结**
这篇文章描述了一个名为`length_cdf`的函数,它计算并打印数据集样本长度的累积分布函数(CDF)。这个函数接收模型的名称或路径、数据集名称、数据集目录、模板名称以及计算长度分布的区间作为参数。函数内部首先通过`llmtuner`模块的函数来获取模型和数据的相关参数,并加载分词器和数据集。接着,它创建一个字典用来统计每个长度区间的样本数量,并遍历数据集更新统计数据。随后,函数将统计数据转换为元组列表并排序,然后通过遍历元组列表来计算并打印每个长度区间的累积计数和累积概率。
当脚本作为主程序运行时,使用`fire`库将`length_cdf`函数暴露为一个命令行接口,使得通过命令行可以直接调用这个函数并传入参数。文章中还说明了使用`fire`库的好处,包括自动生成命令行接口、简化代码、提供直观和一致的命令行使用方式,以及支持多种调用方式。最后,对比了使用和不使用`fire`库时创建命令行接口的差异,并得出结论:`fire`库适合快速创建简单的命令行接口,而在需要更多控制和自定义选项时,`argparse`等专门的库可能更为适合。
**示例用法**
通过下面的命令可以直接运行`length_cdf.py`脚本,并传入相应的参数来计算和打印数据集的样本长度累积分布情况:
```bash
python length_cdf.py --model_name_or_path "path/to/model" --dataset "custom_dataset" --dataset_dir "path/to/data"
```

更新时间 2024-07-21