当前位置:AIGC资讯 > AIGC > 正文

llama大模型提前停止策略,实现工具调用——以Llama3为例

在大模型的generate过程中为了实现工具调用功能,模型输出到了指定的token需要执行工具的时候,需要模型的generate停止。model.generate()的参考链接:https://github.com/huggingface/transformers/blob/v4.20.1/src/transformers/generation_utils.py#L844

其中有一个参数:stopping_criteria,用于判断模型输出是否应该停止,下面介绍具体方案。

Step1:重新定义一个类继承StoppingCriteria并定义停止逻辑

class KeyWordOne_StoppingCriteria(StoppingCriteria):
    def __init__(self,keyword,tokenizer,device):
        self.keyword = tokenizer.encode(keyword,add_special_tokens = False,return_tensors = 'pt').squeeze().to(device)
    def __call__(self,input_ids,scores,**kwards):
        if len(input_ids[0]) < len(self.keyword):
            return False
        if input_ids[0][len(input_ids[0] - len(self.keyword):].equal(self.keyword):
            return True
        else:
            return False

上面的代码中,构造函数一般用来确定我们要停止的tokens
call调用函数是用来确定停止逻辑的。

input_ids:代表的是模型单步generate的所有batch的tokens。 scores:暂时不知道是什么。

Step2:实例化类,并将类传入到generate函数中即可实现提前停止。

from self import KeyWordOne_StoppingCriteri
from transformers import StoppingCriteriaList
stopcrieria = KeyWordOne_StoppingCriteri("<Retrieval>",tokenizer=tokenizer,device=device)
data = [{"content": "针对下列问题,请结合你的知识判断是否需要检索外部数据库以协助回答。如果需要检索,请你给出<Retrieval>标签代表需要检索外部数据库,如果你给出了<Retrieval>标签,那么在它和<Retrieval/>标签中间代表的是从外部数据库检索到的数据,不需要你进行预测。请你根据给出的知识或者自身的能力回答问题。请注意:1、有些问题可能无需外部资料即可回答。2、所提供的资料可能包含与问题无关的信息,请忽略并基于你已有的知识回答问题。", "role": "system"}, 
        {"content": "请在法律条文中上仔细的告知我,如果我在澳门特别行政区设立一家公司,会受到中国大陆财政政策的影响吗,还是澳门有自己独立的财政政策?", "role": "user"}]
data = tokenizer.apply_chat_template(data,add_generation_prompt=True,tokenize=False)
generate_ids = tokenizer(data,return_tensors='pt').to(device)
stop_ids = tokenizer.encode("<Retrieval>",return_tensors='pt')
generated_ids = model.generate(**generate_ids,stopping_criteria=StoppingCriteriaList([stopcrieria]))
tokenizer.batch_decode(generated_ids,skip_special_tokens=True)[0]

上面的代码中,使用StoppingCriteriaList方法包裹了我们实例化的自定义的StoppingCriteria的子类,并用列表包裹起来。

缺点:感觉这样的方案使得模型的推理失去了并行性。

总结

这篇文章主要介绍了在大模型生成(generate)过程中实现工具调用功能的方式,特别是通过自定义停止条件来提前结束文本生成过程。下面是对文章核心内容的总结:
### 背景与目标
在使用大模型(如通过Hugging Face Transformers库实现的模型)进行文本生成时,需要在特定条件下停止生成过程,以触发特定工具或功能的调用。这通常通过模型的`generate`函数中的`stopping_criteria`参数实现。
### 解决方案步骤
1. **定义自定义停止条件**:
- 通过继承`StoppingCriteria`基类,定义一个名为`KeyWordOne_StoppingCriteria`的类,专门用于检测特定的关键词(如``)是否出现在生成的文本中,从而确定是否需要停止生成过程。
- 在类的构造函数中,使用编码器(tokenizer)将关键词编码成模型可理解的格式,并保存到成员变量中。
- 实现`__call__`方法,该方法接收生成的tokens(`input_ids`)和可能的其他参数(如`scores`),通过比较当前的生成文本是否包含预设的关键词来决定是否返回停止信号。
2. **实例化并使用停止条件**:
- 实例化`KeyWordOne_StoppingCriteria`类,传入关键词、编码器和设备信息。
- 将该实例通过`StoppingCriteriaList`包装,并作为`stopping_criteria`参数传递给`generate`函数。这样做可以在生成过程中实时检查生成文本,一旦检测到预设的关键词就停止生成。
3. **数据准备与生成**:
- 准备待生成的输入数据,使用编码器适当地格式化数据。
- 通过模型的`generate`方法进行文本生成,并结合自定义的停止条件来控制生成过程的结束。
### 优缺点分析
- **优点**:
- 灵活控制生成过程,可以根据实际需要自定义停止条件。
- 提高生成效率,避免无用文本的生成。
- **缺点**:
- 由于需要逐步检查生成文本是否满足停止条件,可能导致模型推理的并行性受到影响,从而可能影响整体性能。
### 注意事项
- 需要确保传入给停止条件的设备信息与模型运行的设备信息一致。
- 在`__call__`方法的实现中,注意比较生成的tokens和关键词编码的逻辑是否正确,特别是切片操作和tensor的比较操作。
### 结论
通过自定义`StoppingCriteria`类并在模型生成过程中使用这一机制,可以有效地控制生成文本的输出并在特定条件下停止生成。这为大模型的文本生成功能增加了灵活性和实用性。然而,也应注意其对计算性能可能产生的影响,并在实际应用中进行适当的优化和调整。

更新时间 2024-08-01