当前位置:AIGC资讯 > AIGC > 正文

LangGraph实战:从零分阶打造人工智能航空客服助手

客服助手机器人能够帮助团队更高效地处理日常咨询,但要打造一个能够稳定应对各种任务且不会让用户感到烦恼的机器人并非易事。

完成本教程后,你不仅会拥有一个功能完备的机器人,还将深入理解LangGraph的核心理念和架构设计。这些知识将帮助你在其他人工智能项目中运用相似的设计模式。

由于内容较多,本文将由浅入深,分四个阶段进行讲解,每个阶段都将打造出一个具备以上描述所有能力的机器人。但受限于LLM的能力,初期阶段的机器人的运行可能存在各类问题,但都将在后续阶段得到解决。

你最终完成的聊天机器人将类似于以下示意图:

最终示意图

现在,让我们开启这第一阶段段的学习之旅吧!

准备工作

在开始之前,我们需要搭建好环境。本教程将安装一些必要的先决条件,包括下载测试用的数据库,并定义一些在后续各部分中会用到的工具。

我们会使用 Claude 作为语言模型(LLM),并创建一些定制化的工具。这些工具大多数会连接到本地的 SQLite 数据库,无需额外依赖。此外,我们还会通过 Tavily 为代理提供网络搜索功能。

%%capture --no-stderr
% pip install -U langgraph langchain-community langchain-anthropic tavily-python pandas

数据库初始化

接下来,执行下面的脚本来获取我们为这个教程准备的 SQLite 数据库,并更新它以反映当前的数据状态。具体细节不是重点。

import os
import requests
import sqlite3
import pandas as pd
import shutil

# 下载数据库文件
db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
local_file = "travel2.sqlite"
backup_file = "travel2.backup.sqlite"
overwrite = False
if not os.path.exists(local_file) or overwrite:
    response = requests.get(db_url)
    response.raise_for_status()  # 确保请求成功
    with open(local_file, "wb") as file:
        file.write(response.content)

# 创建数据库备份,以便在每个教程部分开始时重置数据库状态
shutil.copy(local_file, backup_file)

# 将航班数据更新为当前时间,以适应我们的教程
conn = sqlite3.connect(local_file)
cursor = conn.cursor()

# 读取数据库中的所有表
tables = pd.read_sql(
    "SELECT name FROM sqlite_master WHERE type='table';", conn
).name.tolist()
tdf = {}
for table_name in tables:
    tdf[table_name] = pd.read_sql(f"SELECT * from {table_name}", conn)

# 找到最早的出发时间,并计算时间差
example_time = pd.to_datetime(
    tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
).max()
current_time = pd.to_datetime("now").tz_localize(example_time.tz)
time_diff = current_time - example_time

# 更新预订日期和航班时间
for column in ["book_date", "scheduled_departure", "scheduled_arrival", "actual_departure", "actual_arrival"]:
    tdf["flights"][column] = pd.to_datetime(
        tdf["flights"][column].replace("\\N", pd.NaT)
    ) + time_diff

# 将更新后的数据写回数据库
for table_name, df in tdf.items():
    df.to_sql(table_name, conn, if_exists="replace", index=False)
conn.commit()
conn.close()

# 在本教程中,我们将使用这个本地文件作为数据库
db = local_file

工具定义

现在,我们来定义一些工具,以便助手可以搜索航空公司的政策手册,以及搜索和管理航班、酒店、租车和远足活动的预订。这些工具将在教程的各个部分中重复使用,具体的实现细节不是关键。

查询公司政策

助手需要检索政策信息来回答用户的问题。请注意,这些政策的实施还需要在工具或 API 中进行,因为语言模型可能会忽略这些信息。以下工具受限于篇幅将仅提供定义及描述,详细代码[1]可在github上获取。

import re
import numpy as np
import openai
from langchain_core.tools import tool


@tool
def lookup_policy(query):
    """查询公司政策,以确定某些选项是否允许。"""

航班管理

定义一个工具来获取用户的航班信息,然后定义一些工具来搜索航班和管理用户的预订信息,这些信息存储在 SQL 数据库中。

我们使用 ensure_config 来通过配置参数传递 passenger_id。语言模型不需要显式提供这些信息,它们会在图的每次调用中提供,以确保每个用户无法访问其他乘客的预订信息。

from langchain_core.runnables import ensure_config
from typing import Optional
import sqlite3
import pytz
from datetime import datetime, timedelta, date

@tool
def fetch_user_flight_information():
    """获取用户的所有机票信息,包括航班详情和座位分配。"""

@tool
def search_flights(
    departure_airport=None,
    arrival_airport=None,
    start_time=None,
    end_time=None,
    limit=20,
):
    """根据出发机场、到达机场和出发时间范围来搜索航班。"""

@tool
def update_ticket_to_new_flight(ticket_no, new_flight_id):
    """将用户的机票更新到一个新的有效航班上。"""

@tool
def cancel_ticket(ticket_no):
    """取消用户的机票,并从数据库中移除。"""

租车服务

用户预订了航班后,可能需要租车服务。定义一些工具,让用户能够在目的地搜索和预订汽车。

from typing import Optional, Union
from datetime import datetime, date

@tool
def search_car_rentals(
    locatinotallow=None,
    name=None,
    price_tier=None,
    start_date=None,
    end_date=None,
):
    """
    根据位置、公司名称、价格等级、开始日期和结束日期来搜索租车服务。

    参数:
        location (Optional[str]): 租车服务的位置。
        name (Optional[str]): 租车公司的名称。
        price_tier (Optional[str]): 租车的价格等级。
        start_date (Optional[Union[datetime, date]]): 租车的开始日期。
        end_date (Optional[Union[datetime, date]]): 租车的结束日期。

    返回:
        list[dict]: 匹配搜索条件的租车服务列表。
    """

@tool
def book_car_rental(rental_id):
    """
    通过租车ID来预订租车服务。

    参数:
        rental_id (int): 要预订的租车服务的ID。

    返回:
        str: 预订成功与否的消息。
    """

@tool
def update_car_rental(
    rental_id,
    start_date=None,
    end_date=None,
):
    """
    通过租车ID来更新租车服务的开始和结束日期。

    参数:
        rental_id (int): 要更新的租车服务的ID。
        start_date (Optional[Union[datetime, date]]): 新的租车开始日期。
        end_date (Optional[Union[datetime, date]]): 新的租车结束日期。

    返回:
        str: 更新成功与否的消息。
    """

@tool
def cancel_car_rental(rental_id):
    """
    通过租车ID来取消租车服务。

    参数:
        rental_id (int): 要取消的租车服务的ID。

    返回:
        str: 取消成功与否的消息。
    """

酒店预订

用户需要住宿,因此定义一些工具来搜索和管理酒店预订。

@tool
def search_hotels(
    locatinotallow=None,
    name=None,
    price_tier=None,
    checkin_date=None,
    checkout_date=None,
):
    """
    根据位置、名称、价格等级、入住日期和退房日期来搜索酒店。

    参数:
        location (Optional[str]): 酒店的位置。
        name (Optional[str]): 酒店的名称。
        price_tier (Optional[str]): 酒店的价格等级。
        checkin_date
        
        # 入住日期和退房日期,用于搜索酒店
        checkin_date (Optional[Union[datetime, date]]): 酒店的入住日期。
        checkout_date (Optional[Union[datetime, date]]): 酒店的退房日期。

    返回:
        list[dict]: 符合搜索条件的酒店列表。
    """
    
@tool
def book_hotel(hotel_id):
    """
    通过酒店ID进行预订。

    参数:
        hotel_id (int): 要预订的酒店的ID。

    返回:
        str: 预订成功与否的消息。
    """
    
@tool
def update_hotel(
    hotel_id,
    checkin_date=None,
    checkout_date=None,
):
    """
    通过酒店ID更新酒店预订的入住和退房日期。

    参数:
        hotel_id (int): 要更新预订的酒店的ID。
        checkin_date (Optional[Union[datetime, date]]): 新的入住日期。
        checkout_date (Optional[Union[datetime, date]]): 新的退房日期。

    返回:
        str: 更新成功与否的消息。
    """
    
@tool
def cancel_hotel(hotel_id):
    """
    通过酒店ID取消酒店预订。

    参数:
        hotel_id (int): 要取消预订的酒店的ID。

    返回:
        str: 取消成功与否的消息。
    """

远足活动

最后,定义一些工具,让用户在到达目的地后搜索活动并进行预订。

@tool
def search_trip_recommendations(
    locatinotallow=None,
    name=None,
    keywords=None,
):
    """
    根据位置、名称和关键词搜索旅行推荐。

    参数:
        location (Optional[str]): 旅行推荐的地点。
        name (Optional[str]): 旅行推荐的名字。
        keywords (Optional[str]): 与旅行推荐相关的关键词。

    返回:
        list[dict]: 符合搜索条件的旅行推荐列表。
    """
    
@tool
def book_excursion(recommendation_id):
    """
    通过推荐ID预订远足活动。

    参数:
        recommendation_id (int): 要预订的旅行推荐的ID。

    返回:
        str: 预订成功与否的消息。
    """
    
@tool
def update_excursion(recommendation_id, details):
    """
    通过推荐ID更新旅行推荐的细节。

    参数:
        recommendation_id (int): 要更新的旅行推荐的ID。
        details (str): 旅行推荐的新细节。

    返回:
        str: 更新成功与否的消息。
    """
    
@tool
def cancel_excursion(recommendation_id):
    """
    通过推荐ID取消旅行推荐。

    参数:
        recommendation_id (int): 要取消的旅行推荐的ID。

    返回:
        str: 取消成功与否的消息。
    """

实用工具

定义一些辅助函数,以便在调试过程中美化图形中的消息显示,并为工具节点添加错误处理(通过将错误添加到聊天记录中)。

from langgraph.prebuilt import ToolNode
from langchain_core.runnables import RunnableLambda

def handle_tool_error(state):
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                cnotallow=f"错误: {repr(error)}\n请修正你的错误。",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }

def create_tool_node_with_fallback(tools):
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )

def _print_event(event, _printed, max_length=1500):
    current_state = event.get("dialog_state")
    if current_state:
        print(f"当前状态: ", current_state[-1])
    message = event.get("messages")
    if message:
        if isinstance(message, list):
            message = message[-1]
        if message.id not in _printed:
            msg_repr = message.pretty_repr(html=True)
            if len(msg_repr) > max_length:
                msg_repr = msg_repr[:max_length] + " ... (内容已截断)"
            print(msg_repr)
            _printed.add(message.id)

第一部分:零样本代理

在构建任何系统时,最佳实践是从最简单的可行方案开始,并通过使用类似LangSmith这样的评估工具来测试其有效性。在条件相同的情况下,我们倾向于选择简单且可扩展的解决方案,而不是复杂的方案。然而,单一图谱方法存在一些限制,比如机器人可能在未经用户确认的情况下执行不希望的操作,处理复杂查询时可能遇到困难,或者在回答时缺乏针对性。这些问题我们会在后续进行改进。 在这部分,我们将定义一个简单的零样本代理作为用户的助手,并将所有工具赋予给它。我们的目标是引导它明智地使用这些工具来帮助用户。 我们的简单两节点图如下所示:

第一部分图解

首先,我们定义状态。

状态

我们将StateGraph的状态定义为一个包含消息列表的类型化字典。这些消息构成了聊天的记录,也就是我们简单助手所需要的全部状态信息。

from langgraph.graph.message import add_messages, AnyMessage
from typing_extensions import TypedDict
from typing import Annotated


class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

代理

然后,我们定义助手函数。这个函数接收图的状态,将其格式化为提示,然后调用一个大型语言模型(LLM)来预测最佳的响应。

from langchain_core.runnables import Runnable, RunnableConfig
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate


class Assistant:
    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        while True:
            passenger_id = config.get("passenger_id", None)
            state = {**state, "user_info": passenger_id}
            result = self.runnable.invoke(state)
            # 如果大型语言模型返回了一个空响应,我们将重新提示它给出一个实际的响应。
            if (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "请给出一个真实的输出。")]
                state = {**state, "messages": messages}
            else:
                break
        return {"messages": result}


# Haiku模型更快、成本更低,但准确性稍差
# llm = ChatAnthropic(model="claude-3-haiku-20240307")
llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=1)

primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你是一个为瑞士航空提供帮助的客户支持助手。"
            "使用提供的工具来搜索航班、公司政策和其他信息以帮助回答用户的查询。"
            "在搜索时,要有毅力。如果第一次搜索没有结果,就扩大你的查询范围。"
            "如果搜索结果为空,不要放弃,先扩大搜索范围。"
            "\n\n当前用户:\n<User>\n{user_info}\n</User>"
            "\n当前时间:{time}。",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now())

part_1_tools = [
    TavilySearchResults(max_results=1),
    fetch_user_flight_information,
    search_flights,
    lookup_policy,
    update_ticket_to_new_flight,
    cancel_ticket,
    search_car_rentals,
    book_car_rental,
    update_car_rental,
    cancel_car_rental,
    search_hotels,
    book_hotel,
    update_hotel,
    cancel_hotel,
    search_trip_recommendations,
    book_excursion,
    update_excursion,
    cancel_excursion,
]
part_1_assistant_runnable = primary_assistant_prompt | llm.bind_tools(part_1_tools)

定义图

现在,我们来创建图。这张图是我们这部分的最终助手。

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import tools_condition, ToolNode

builder = StateGraph(State)


# 定义节点:这些节点执行具体的工作
builder.add_node("assistant", Assistant(part_1_assistant_runnable))
builder.add_node("action", create_tool_node_with_fallback(part_1_tools))
# 定义边:这些边决定了控制流程如何移动
builder.set_entry_point("assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
    # "action"调用我们的工具之一。END导致图终止(并向用户做出响应)
    {"action": "action", END: END},
)
builder.add_edge("action", "assistant")

# 检查点器允许图保存其状态
# 这是整个图的完整记忆。
memory = SqliteSaver.from_conn_string(":memory:")
part_1_graph = builder.compile(checkpointer=memory)

from IPython.display import Image, display

try:
    display(Image(part_1_graph.get_graph(xray=True).draw_mermaid_png()))
except:
    # 这需要一些额外的依赖项,是可选的
    pass

示例对话

现在,让我们通过一系列对话示例来测试我们的聊天机器人。

import uuid
import shutil

# 假设这是用户与助手之间可能发生的对话示例
tutorial_questions = [
    "你好,我的航班是什么时候?",
    "我可以把我的航班改签到更早的时间吗?我想今天晚些时候离开。",
    "那就把我的航班改签到下周某个时间吧",
    "下一个可用的选项很好",
    "住宿和交通方面有什么建议?",
    "我想在为期一周的住宿中选择一个经济实惠的酒店(7天),并且我还想租一辆车。",
    "好的,你能为你推荐的酒店预订吗?听起来不错。",
    "是的,去预订任何中等价位且有可用性的酒店。",
    "对于汽车,我有哪些选择?",
    "太棒了,我们只选择最便宜的选项。预订7天。",
    "那么,你对我的旅行有什么建议?",
    "在我在那里的时候,有哪些活动是可用的?",
    "有趣 - 我喜欢博物馆,有哪些选择?",
    "好的,那就为我在那里的第二天预订一个。",
]

# 使用备份文件以便我们可以从每个部分的原始位置重新启动
shutil.copy(backup_file, db)
thread_id = str(uuid.uuid4())

config = {
    "configurable": {
        # passenger_id 在我们的航班工具中使用
        # 以获取用户的航班信息
        "passenger_id": "3442 587242",
        # 检查点通过 thread_id 访问
        "thread_id": thread_id,
    }
}


_printed = set()
for question in tutorial_questions:
    events = part_1_graph.stream(
        {"messages": ("user", question)}, config, stream_mode="values"
    )
    for event in events:
        _print_event(event, _printed)

更新时间 2024-05-09