当前位置:AIGC资讯 > AIGC > 正文

源码解析LLaMA-Factory/src/llmtuner/data/template.py + Qwen模板

@dataclass
class Template:
    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
    default_system: str
    stop_words: List[str]
    efficient_eos: bool
    replace_eos: bool
    force_system: bool

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids += query_ids + resp_ids
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query        resp
        Turn t: sep + query           resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                elements += self.format_system.apply(content=(system + tool_text))
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)

LLamafactory用作Formatter的方法Qwen/lib/python3.12/abc.py-CSDN博客

这段代码定义了一个名为 Template 的数据类,并包含了多个方法用于将消息编码为一系列的 token id。我们逐行解释这段代码:

数据类定义

from dataclasses import dataclass
from typing import List, Dict, Optional, Sequence, Tuple

@dataclass
class Template:
使用 @dataclass 装饰器定义了一个 Template 类。dataclass 会自动生成初始化方法和其他方法。

类属性定义

    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
    default_system: str
    stop_words: List[str]
    efficient_eos: bool
    replace_eos: bool
    force_system: bool
定义了多个类属性,这些属性主要是不同角色的格式化器(Formatter),以及一些控制编码行为的布尔值和字符串。

encode_oneturn 方法

python

复制

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
定义了一个用于编码单轮对话的 encode_oneturn 方法。

python

复制

        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
调用私有方法 _encode 进行实际的编码工作,返回编码后的消息对。

python

复制

        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids += query_ids + resp_ids
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids
将所有轮次的 query_ids 和 resp_ids 拼接成一个完整的 prompt_ids,并返回最后一轮的 response_ids 作为 answer_ids

encode_multiturn 方法

python

复制

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
定义了一个用于编码多轮对话的 encode_multiturn 方法。

python

复制

        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
直接调用 _encode 方法并返回其结果。

_encode 私有方法

python

复制

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query        resp
        Turn t: sep + query           resp
        """
定义了一个私有方法 _encode,用于实际的编码过程。
        system = system or self.default_system
如果 system 参数为空,使用类属性 default_system

python

复制

        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                elements += self.format_system.apply(content=(system + tool_text))
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()
循环遍历消息列表,根据消息位置和内容组装元素列表。 对于第一条消息,如果有 system 或 tools 或者 force_system 为真,则应用 format_system 和(如果有的话) format_tools

从 _encode 方法的消息处理循环部分:

python

复制

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
根据消息的角色(role)应用不同的格式化器: 如果角色是用户(Role.USER),使用 format_user。 如果角色是助手(Role.ASSISTANT),使用 format_assistant。 如果角色是观察(Role.OBSERVATION),使用 format_observation。 如果角色是功能(Role.FUNCTION),使用 format_function。 如果遇到未实现的角色,抛出 NotImplementedError

python

复制

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
将格式化后的元素转换为 token id 并添加到 encoded_messages 列表中。

python

复制

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
调用 _make_pairs 方法,将编码后的消息组装成 token id 对并返回。

其他私有方法(假设)

以下是一些假设的私有方法,用于完成编码和转换的具体操作:

python

复制

    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: List[str]) -> List[int]:
        """
        Converts a list of formatted elements into token ids using the tokenizer.
        """
        token_ids = []
        for element in elements:
            token_ids.extend(tokenizer.encode(element, add_special_tokens=False))
        return token_ids

    def _make_pairs(self, encoded_messages: List[List[int]], cutoff_len: int, reserved_label_len: int) -> Sequence[Tuple[List[int], List[int]]]:
        """
        Creates pairs of prompt and response token ids with cutoff and reserved lengths.
        """
        pairs = []
        for i in range(0, len(encoded_messages), 2):
            query_ids = encoded_messages[i]
            resp_ids = encoded_messages[i + 1] if i + 1 < len(encoded_messages) else []
            if len(query_ids) + len(resp_ids) > cutoff_len:
                query_ids = query_ids[:cutoff_len - len(resp_ids)]
            resp_ids = resp_ids[:reserved_label_len]
            pairs.append((query_ids, resp_ids))
        return pairs
_convert_elements_to_ids 方法:将格式化的元素列表转换为 token id 列表。 _make_pairs 方法:根据 cutoff_len 和 reserved_label_len 将编码后的消息组合成 token id 对。

总结

这段代码定义了一个 Template 类,通过不同的格式化器和编码方法,将一系列消息转换为 token id。主要方法包括 encode_oneturn 和 encode_multiturn,分别用于编码单轮和多轮对话。私有方法 _encode 负责实际的编码过程,利用 _convert_elements_to_ids 和 _make_pairs 方法完成元素转换和对话对的创建。

def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
) -> Template:
    if name is None:
        template = templates["vanilla"]  # placeholder
    else:
        template = templates.get(name, None)
        if template is None:
            raise ValueError("Template {} does not exist.".format(name))

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

    try:
        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
    except ValueError:
        logger.info("Cannot add this chat template to tokenizer.")

    return template


_register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    default_system=(
        "Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
    ),
)

逐行解释这段代码。

get_template_and_fix_tokenizer 函数

python

复制

def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
) -> Template:
定义一个函数 get_template_and_fix_tokenizer,用于获取模板并修正 tokenizer。参数包括 tokenizer 和可选的模板名称 name

python

复制

    if name is None:
        template = templates["vanilla"]  # placeholder
    else:
        template = templates.get(name, None)
        if template is None:
            raise ValueError("Template {} does not exist.".format(name))
如果 name 参数为空,默认使用 vanilla 模板。 否则,尝试获取指定名称的模板。如果模板不存在,抛出 ValueError

python

复制

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]
获取模板中的 stop_words。 如果模板要求替换 EOS(End Of Sentence)标记,但 stop_words 为空,抛出 ValueError。 否则,用 stop_words 中的第一个词替换 EOS 标记,并移除已使用的词。

解释 get_template_and_fix_tokenizer 函数:

python

复制

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="

二、Qwen模板

_register_template(
    name="qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
)

新代码解释:

_register_template(name="qwen", ...): 这行代码注册了一个名为"qwen"的模板。

format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]):
定义了用户输入的格式。它使用StringFormatter,将用户的内容包装在特定的标记中。

format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]):
定义了系统消息的格式,同样使用StringFormatter。

format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]):
定义了工具观察结果的格式,这在之前的代码中没有。

format_separator=EmptyFormatter(slots=["\n"]):
定义了分隔符,这里使用了一个空行。

default_system="You are a helpful assistant.":
设置了默认的系统消息。

stop_words=["<|im_end|>"]:
定义了停止词,当模型生成这个标记时会停止继续生成。

replace_eos=True:
指示是否替换结束标记。

与之前代码的主要区别: 旧代码指swift

实现方式:

新代码使用了函数式的方法来注册模板。 旧代码定义了一个类 QwenTemplate,继承自 Template

格式化方法:

新代码使用 StringFormatter 和 EmptyFormatter 来定义各种消息的格式。 旧代码直接在 __init__ 方法中定义字符串格式。

灵活性:

新代码提供了更多的格式化选项,如 format_observation 用于工具观察结果。 新代码明确定义了 format_separator,而旧代码没有。

默认系统消息:

新代码直接指定了默认系统消息。 旧代码使用了一个未显示的 DEFAULT_SYSTEM 常量。

停止词和EOS替换:

新代码明确定义了停止词和EOS替换选项。 旧代码没有这些明确的定义。

参数处理:

新代码没有 auto_add_bos 参数。 旧代码包含了 auto_add_bos 参数,用于控制是否自动添加开始标记。

总的来说,新代码提供了更灵活和详细的模板定义方式,包括了更多的格式化选项和控制参数。它似乎是为了提供更通用和可配置的模板注册方法,而旧代码更像是一个特定的模板实现。新代码可能更容易扩展和修改,以适应不同的需求。

Qwen范例

这段代码是在 LLaMA Factory 中注册 Qwen 模型的对话模板。让我为您渲染一下这个模板的真实输出效果,假设我们有一个包含系统消息、用户输入、助手回复和工具观察的对话:

<|im_start|>system
You are a helpful assistant.
<|im_end|>

<|im_start|>user
你好,请告诉我今天的天气如何。
<|im_end|>

<|im_start|>assistant
当然,我很乐意为您提供今天的天气信息。不过,作为一个AI助手,我没有实时访问天气数据的能力。为了给您最准确的信息,我需要使用一个天气查询工具。让我为您查询一下。

<|im_start|>tool
正在查询当前位置的天气信息...
查询结果:今天天气晴朗,气温在20°C到25°C之间,微风,适合户外活动。
<|im_end|>

参考文献或函数引用:

LLamafactory用作Formatter的方法Qwen/lib/python3.12/abc.py-CSDN博客

总结

更新时间 2024-09-14