@dataclass
class Template:
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
LLamafactory用作Formatter的方法Qwen/lib/python3.12/abc.py-CSDN博客
这段代码定义了一个名为 Template
的数据类,并包含了多个方法用于将消息编码为一系列的 token id。我们逐行解释这段代码:
数据类定义
from dataclasses import dataclass
from typing import List, Dict, Optional, Sequence, Tuple
@dataclass
class Template:
使用 @dataclass
装饰器定义了一个 Template
类。dataclass
会自动生成初始化方法和其他方法。
类属性定义
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
定义了多个类属性,这些属性主要是不同角色的格式化器(Formatter
),以及一些控制编码行为的布尔值和字符串。
encode_oneturn
方法
python
复制
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
定义了一个用于编码单轮对话的 encode_oneturn
方法。
python
复制
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
调用私有方法 _encode
进行实际的编码工作,返回编码后的消息对。
python
复制
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
将所有轮次的 query_ids
和 resp_ids
拼接成一个完整的 prompt_ids
,并返回最后一轮的 response_ids
作为 answer_ids
。
encode_multiturn
方法
python
复制
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
定义了一个用于编码多轮对话的 encode_multiturn
方法。
python
复制
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
直接调用 _encode
方法并返回其结果。
_encode
私有方法
python
复制
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
定义了一个私有方法 _encode
,用于实际的编码过程。
system = system or self.default_system
如果 system
参数为空,使用类属性 default_system
。
python
复制
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
循环遍历消息列表,根据消息位置和内容组装元素列表。
对于第一条消息,如果有 system
或 tools
或者 force_system
为真,则应用 format_system
和(如果有的话) format_tools
。
从 _encode
方法的消息处理循环部分:
python
复制
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
根据消息的角色(role
)应用不同的格式化器:
如果角色是用户(Role.USER
),使用 format_user
。
如果角色是助手(Role.ASSISTANT
),使用 format_assistant
。
如果角色是观察(Role.OBSERVATION
),使用 format_observation
。
如果角色是功能(Role.FUNCTION
),使用 format_function
。
如果遇到未实现的角色,抛出 NotImplementedError
。
python
复制
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
将格式化后的元素转换为 token id 并添加到 encoded_messages
列表中。
python
复制
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
调用 _make_pairs
方法,将编码后的消息组装成 token id 对并返回。
其他私有方法(假设)
以下是一些假设的私有方法,用于完成编码和转换的具体操作:
python
复制
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: List[str]) -> List[int]:
"""
Converts a list of formatted elements into token ids using the tokenizer.
"""
token_ids = []
for element in elements:
token_ids.extend(tokenizer.encode(element, add_special_tokens=False))
return token_ids
def _make_pairs(self, encoded_messages: List[List[int]], cutoff_len: int, reserved_label_len: int) -> Sequence[Tuple[List[int], List[int]]]:
"""
Creates pairs of prompt and response token ids with cutoff and reserved lengths.
"""
pairs = []
for i in range(0, len(encoded_messages), 2):
query_ids = encoded_messages[i]
resp_ids = encoded_messages[i + 1] if i + 1 < len(encoded_messages) else []
if len(query_ids) + len(resp_ids) > cutoff_len:
query_ids = query_ids[:cutoff_len - len(resp_ids)]
resp_ids = resp_ids[:reserved_label_len]
pairs.append((query_ids, resp_ids))
return pairs
_convert_elements_to_ids
方法:将格式化的元素列表转换为 token id 列表。
_make_pairs
方法:根据 cutoff_len
和 reserved_label_len
将编码后的消息组合成 token id 对。
总结
这段代码定义了一个 Template
类,通过不同的格式化器和编码方法,将一系列消息转换为 token id。主要方法包括 encode_oneturn
和 encode_multiturn
,分别用于编码单轮和多轮对话。私有方法 _encode
负责实际的编码过程,利用 _convert_elements_to_ids
和 _make_pairs
方法完成元素转换和对话对的创建。
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
else:
template = templates.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
return template
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
),
)
逐行解释这段代码。
get_template_and_fix_tokenizer
函数
python
复制
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
) -> Template:
定义一个函数 get_template_and_fix_tokenizer
,用于获取模板并修正 tokenizer。参数包括 tokenizer
和可选的模板名称 name
。
python
复制
if name is None:
template = templates["vanilla"] # placeholder
else:
template = templates.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
如果 name
参数为空,默认使用 vanilla
模板。
否则,尝试获取指定名称的模板。如果模板不存在,抛出 ValueError
。
python
复制
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
获取模板中的 stop_words
。
如果模板要求替换 EOS(End Of Sentence)标记,但 stop_words
为空,抛出 ValueError
。
否则,用 stop_words
中的第一个词替换 EOS 标记,并移除已使用的词。
解释 get_template_and_fix_tokenizer
函数:
python
复制
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="
二、Qwen模板
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
新代码解释:
_register_template(name="qwen", ...)
: 这行代码注册了一个名为"qwen"的模板。
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"])
:
定义了用户输入的格式。它使用StringFormatter,将用户的内容包装在特定的标记中。
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"])
:
定义了系统消息的格式,同样使用StringFormatter。
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"])
:
定义了工具观察结果的格式,这在之前的代码中没有。
format_separator=EmptyFormatter(slots=["\n"])
:
定义了分隔符,这里使用了一个空行。
default_system="You are a helpful assistant."
:
设置了默认的系统消息。
stop_words=["<|im_end|>"]
:
定义了停止词,当模型生成这个标记时会停止继续生成。
replace_eos=True
:
指示是否替换结束标记。
与之前代码的主要区别: 旧代码指swift
实现方式:
新代码使用了函数式的方法来注册模板。 旧代码定义了一个类QwenTemplate
,继承自 Template
。
格式化方法:
新代码使用StringFormatter
和 EmptyFormatter
来定义各种消息的格式。
旧代码直接在 __init__
方法中定义字符串格式。
灵活性:
新代码提供了更多的格式化选项,如format_observation
用于工具观察结果。
新代码明确定义了 format_separator
,而旧代码没有。
默认系统消息:
新代码直接指定了默认系统消息。 旧代码使用了一个未显示的DEFAULT_SYSTEM
常量。
停止词和EOS替换:
新代码明确定义了停止词和EOS替换选项。 旧代码没有这些明确的定义。参数处理:
新代码没有auto_add_bos
参数。
旧代码包含了 auto_add_bos
参数,用于控制是否自动添加开始标记。
总的来说,新代码提供了更灵活和详细的模板定义方式,包括了更多的格式化选项和控制参数。它似乎是为了提供更通用和可配置的模板注册方法,而旧代码更像是一个特定的模板实现。新代码可能更容易扩展和修改,以适应不同的需求。
Qwen范例
这段代码是在 LLaMA Factory 中注册 Qwen 模型的对话模板。让我为您渲染一下这个模板的真实输出效果,假设我们有一个包含系统消息、用户输入、助手回复和工具观察的对话:
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
你好,请告诉我今天的天气如何。
<|im_end|>
<|im_start|>assistant
当然,我很乐意为您提供今天的天气信息。不过,作为一个AI助手,我没有实时访问天气数据的能力。为了给您最准确的信息,我需要使用一个天气查询工具。让我为您查询一下。
<|im_start|>tool
正在查询当前位置的天气信息...
查询结果:今天天气晴朗,气温在20°C到25°C之间,微风,适合户外活动。
<|im_end|>
参考文献或函数引用:
LLamafactory用作Formatter的方法Qwen/lib/python3.12/abc.py-CSDN博客
总结