当前位置:AIGC资讯 > AIGC > 正文

Lawyer LLaMA(中文法律大模型本地部署)

Lawyer LLaMA(中文法律大模型本地部署)

1.模型选择(lawyer-llama-13b-v2
2.运行环境

​ 1.建议使用Python 3.8及以上版本。

​ 2.主要依赖库如下:

transformers >= 4.28.0 注意:检索模块需要使用transformers <= 4.30 sentencepiece >= 0.1.97 gradio
3.使用步骤

​ 1.从HuggingFace下载 **Lawyer LLaMA 2 (lawyer-llama-13b-v2)**模型参数。(需要的torch )

# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="pkupie/lawyer-llama-13b-v2")

2.从HuggingFace下载法条检索模块,并运行其中的python server.py启动法条检索服务,默认挂在9098端口。(注意事项,拉取的代码有可能少labels2id.pkl,pytorch_model.bin等文件)

​ 1.git lfs install

​ 2.git clone https://huggingface.co/pkupie/marriage_law_retrieval

​ 3.GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/pkupie/marriage_law_retrieval

​ 4.server.py代码这样的,模型路径手动更改

import json
import subprocess
import os
import codecs
import logging
import os
import math

import json
import random
from tqdm import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from flask import Flask, request, jsonify
import json
import random
from tqdm import tqdm
import os
import pickle as pkl
from argparse import Namespace

from models import Elect

import torch
from transformers import AutoModel, AutoTokenizer

from sklearn.preprocessing import MultiLabelBinarizer

logger = logging.getLogger(__name__)

app = Flask(__name__)

hunyin_classifier = None

fatiao_args = Namespace()
fatiao_tokenizer = None
fatiao_model = None


@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    input_text = request.json['input'].strip()
    force_return = request.json['force_return'] if 'force_return' in request.json else False

    print("input_text:", input_text)

    if len(input_text) == 0:
        json_result = {
            "output": []
        }
        return jsonify(json_result)

    if not force_return:
        classifier_result = hunyin_classifier(input_text[:500])
        print(classifier_result)
        classifier_result = classifier_result[0]['label']

        # 加一条规则,如果输入文本中包含“婚”字,那么直接判定为婚姻相关
        if '婚' in input_text:
            classifier_result = True

        # 如果不是婚姻相关的,直接返回空
        if classifier_result == False:
            json_result = {
                "output": []
            }
            return jsonify(json_result)

    inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
    batch = {
        'ids': inputs['input_ids'],
        'mask': inputs['attention_mask'],
        'token_type_ids': inputs["token_type_ids"]
    }
    model_output = fatiao_model(batch)
    pred = torch.sigmoid(model_output).cpu().detach().numpy()[0]
    pred_laws = []
    for law_id, score in sorted(enumerate(pred), key=lambda x: x[1], reverse=True):
        pred_laws.append({
            'id': law_id,
            'score': float(score),
            'text': fatiao_args.mlb.classes_[law_id]
        })

    json_result = {
        "output": pred_laws[:3]
    }

    print("json_result:", json_result)
    return jsonify(json_result)


if __name__ == "__main__":
    # 加载咨询分类模型,用于判断是否与婚姻有关
    hunyin_classifier_path = "C:/Users/win10/PycharmProjects/lawyer-llama_/marriage_law_retrieval/pretrained_models/roberta_wwm_ext_hunyin_2epoch/"

    # 检查模型文件是否存在
    model_file = os.path.join(hunyin_classifier_path, "pytorch_model.bin")

    # 打印目录内容
    print("Files in directory:")
    for filename in os.listdir(hunyin_classifier_path):
        print(filename)

    if not os.path.exists(model_file):
        print(f"Model file not found at {model_file}")
    else:
        print(f"Model file found at {model_file}")

    hunyin_config = AutoConfig.from_pretrained(
        hunyin_classifier_path,
        num_labels=2,
    )
    hunyin_tokenizer = AutoTokenizer.from_pretrained(
        hunyin_classifier_path
    )
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)

    print("Model loaded successfully")

    # 加载法条检索模型
    fatiao_args.ckpt_dir = r"C:\Users\win10\PycharmProjects\lawyer-llama_\marriage_law_retrieval\pretrained_models\chinese-roberta-wwm-ext"
    fatiao_args.device = "cuda:0"

    # 确认路径是否正确
    labels2id_path = os.path.join("data", "labels2id.pkl")
    if not os.path.exists(labels2id_path):
        print(f"Labels2id file not found at {labels2id_path}")
    else:
        print(f"Labels2id file found at {labels2id_path}")

    with open(labels2id_path, "rb") as f:
        laws2id = pkl.load(f)
        fatiao_args.labels = list(laws2id.keys())

    id2laws = {}
    for k, v in laws2id.items():
        id2laws[v] = k
    print("法条个数:", len(id2laws))

    fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)

    fatiao_args.tokenizer = fatiao_tokenizer
    fatiao_model = Elect(fatiao_args, "cuda:0").to("cuda:0")
    fatiao_model.eval()

    mlb = MultiLabelBinarizer()
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    with torch.no_grad():
        for idx, l in enumerate(fatiao_args.labels):
            text = ':'.join(l.split(':')[1:]).lower()
            la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    fatiao_model.load_state_dict(torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device)))
    fatiao_model.to(fatiao_args.device)

    logger.info("model loaded")
    app.run(host="0.0.0.0", port=9098, debug=False)

​ 5.如需使用nginx反向代理访问此服务,可参考https://github.com/LeetJoe/lawyer-llama/blob/main/demo/nginx_proxy.md (Credit to @LeetJoe)

​ 1.启动命令 python demo_web.py --port 7863 --checkpoint “C:/Users/win10/.cache/huggingface/hub/models–pkupie–lawyer-llama-13b-v2/snapshots/f61a4a16c97b6bd546790d88eaec7bc7fcd7344b” --classifier_url “http://127.0.0.1:9098/check_hunyin” --offload_folder “C:/path/to/offload/folder”(内存不够时启动的命令在这个命令中,--offload_folder "C:/path/to/offload/folder" 用于指定一个目录,用来存储模型的部分数据,从而减轻内存负担。这通常是在处理大模型时的一种策略,通过将一些不常用的模型部分卸载到磁盘上,可以节省系统内存(RAM)的使用。)

​ 2.python demo_web.py --port 7863 --checkpoint “C:/Users/win10/.cache/huggingface/hub/models–pkupie–lawyer-llama-13b-v2/snapshots/f61a4a16c97b6bd546790d88eaec7bc7fcd7344b” --classifier_url “http://127.0.0.1:9098/check_hunyin”(内存够的时候启动命令)

demo_web.py代码

import gradio as gr
import requests
import json
from transformers import LlamaForCausalLM, LlamaTokenizer, TextIteratorStreamer
import torch
import threading
import argparse

class StoppableThread(threading.Thread):
    """Thread class with a stop() method. The thread itself has to check
    regularly for the stopped() condition."""

    def __init__(self,  *args, **kwargs):
        super(StoppableThread, self).__init__(*args, **kwargs)
        self._stop_event = threading.Event()

    def stop(self):
        self._stop_event.set()

    def stopped(self):
        return self._stop_event.is_set()

def json_send(url, data=None, method="POST"):
    headers = {"Content-type": "application/json", "Accept": "text/plain", "charset": "UTF-8"}
    try:
        if method == "POST":
            if data is not None:
                response = requests.post(url=url, headers=headers, data=json.dumps(data))
            else:
                response = requests.post(url=url, headers=headers)
        elif method == "GET":
            response = requests.get(url=url, headers=headers)
        response.raise_for_status()  # Ensure we notice bad responses
        return response.json()  # Return the response as a JSON object
    except requests.exceptions.RequestException as e:
        print(f"HTTP Request failed: {e}")
        return {}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--checkpoint", type=str, default="")
    parser.add_argument("--classifier_url", type=str, default="")
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--offload_folder", type=str, default="./offload")
    args = parser.parse_args()
    checkpoint = args.checkpoint
    classifier_url = args.classifier_url

    print("Loading model...")
    tokenizer = LlamaTokenizer.from_pretrained(checkpoint)
    if args.load_in_8bit:
        model = LlamaForCausalLM.from_pretrained(checkpoint, device_map="auto", load_in_8bit=True, offload_folder=args.offload_folder)
    else:
        model = LlamaForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16, offload_folder=args.offload_folder)
    print("Model loaded.")

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        input_msg = gr.Textbox(label="Input")
        with gr.Row():
            generate_button = gr.Button('Generate', elem_id='generate', variant='primary')
            clear_button = gr.Button('Clear', elem_id='clear', variant='secondary')

        def user(user_message, chat_history):
            user_message = user_message.strip()
            return "", chat_history + [[user_message, None]]

        def bot(chat_history):
            # extract user inputs from chat history and retrieve law articles
            current_user_input = chat_history[-1][0]

            if len(current_user_input) == 0:
                yield chat_history[:-1]
                return

            # 检索法条
            history_user_input = [x[0] for x in chat_history]
            input_to_classifier = " ".join(history_user_input)
            data = {"input": input_to_classifier}
            result = json_send(classifier_url, data, method="POST")
            retrieve_output = result.get('output', [])

            # 构造输入
            if len(retrieve_output) == 0:
                input_text = "你是人工智能法律助手“Lawyer LLaMA”,能够回答与中国法律相关的问题。\n"
                for history_pair in chat_history[:-1]:
                    input_text += f"### Human: {history_pair[0]}\n### Assistant: {history_pair[1]}\n"
                input_text += f"### Human: {current_user_input}\n### Assistant: "
            else:
                input_text = f"你是人工智能法律助手“Lawyer LLaMA”,能够回答与中国法律相关的问题。请参考给出的\"参考法条\",回复用户的咨询问题。\"参考法条\"中可能存在与咨询无关的法条,请回复时不要引用这些无关的法条。\n"
                for history_pair in chat_history[:-1]:
                    input_text += f"### Human: {history_pair[0]}\n### Assistant: {history_pair[1]}\n"
                input_text += f"### Human: {current_user_input}\n### 参考法条: {retrieve_output[0]['text']}\n{retrieve_output[1]['text']}\n{retrieve_output[2]['text']}\n### Assistant: "

            print("=== Input ===")
            print("input_text: ", input_text)

            inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
            generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=400, do_sample=False, repetition_penalty=1.1)
            thread = StoppableThread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            # 开始流式生成
            chat_history[-1][1] = ""
            for new_text in streamer:
                chat_history[-1][1] += new_text
                yield chat_history

            streamer.end()
            thread.stop()
            print("Output: ", chat_history[-1][1])

        input_msg.submit(user, [input_msg, chatbot], [input_msg, chatbot], queue=False).then(
            bot, [chatbot], chatbot
        )
        generate_button.click(user, [input_msg, chatbot], [input_msg, chatbot], queue=False).then(
            bot, [chatbot], chatbot
        )

    demo.queue()
    demo.launch(share=False, server_port=args.port, server_name='0.0.0.0')

总结

**总结:Lawyer LLaMA (中文法律大模型本地部署)**
本文详细介绍了如何在本地环境部署和使用“Lawyer LLaMA”——一款基于LLaMA的中文法律大模型。部署过程分为选择模型、准备运行环境和具体使用步骤:
### 1. 模型选择
选择使用`lawyer-llama-13b-v2`模型,这是专为法律领域定制的LLaMA模型。
### 2. 运行环境
- **语言支持**:推荐使用Python 3.8及以上版本。
- **主要依赖库**:
- `transformers`(建议版本>=4.28.0,且检索模块需要<4.30)
- `sentencepiece`(>=0.1.97)
- `gradio`
### 3. 使用步骤
1. **下载模型**:
- 从Hugging Face下载`lawyer-llama-13b-v2`模型参数,并使用Transformers库构建pipeline。
2. **配置法条检索模块**:
- 通过Git LFS克隆法条检索模块(可能需处理缺少的文件如`labels2id.pkl`等)。
- 修改`server.py`中的模型路径并启动服务,默认为9098端口。
3. **启动法条检索服务**:
- 下载法条检索模块并配置Flask应用,同时检查模型和相关文件的完整性并加载模型。
4. **整合使用与部署Web界面**:
- 使用`demo_web.py`脚本,可通过命令行参数指定端口、模型检查点、法条检索服务的URL以及内存不足时的卸载文件夹。
- Gradio库用于构建交互式聊天界面,用户可通过输入文本获得法律相关的回复和参考法条。
- 流式文本生成技术用于提高回复的速度和灵活性。
5. **可选:Nginx反向代理**:
- 提供Nginx反向代理的配置示例,可用于在生产环境中安全访问和分发服务请求。
整个过程展示了如何将LLaMA模型定制为法律领域的专业助手,并通过Web界面提供给用户交互的便利。这对于需要法律咨询的用户和法律从业者来说,是一个有效且易于使用的工具。

更新时间 2024-09-23