当前位置:AIGC资讯 > AIGC > 正文

Stable Diffusion | Gradio界面设计及API调用

Stability AI 2024年2月发布了 Stable Cascade 模型,但由于该模型较大(fp32格式的 Stage_A + Stage_B + Stage_C 模型超过20GB,ComfyUI 专用 Stage_B + Stage_C 模型也要14GB),对显卡要求较高,限制了大家体验 Stable Cascade 模型。

本文主要介绍如何在百度 AI Studio 平台上通过 Gradio 交互式界面运行 Stable Cascade 模型进行文生图。

项目链接:Stable Diffusion ComfyUI 在线体验 。

1. 项目体验

由于项目更新,以下示例图片与新版本有一定差异,但操作方法一样。

1.1 创建项目副本

打开项目链接 Stable Diffusion | Gradio界面设计及API调用

百度 AI Studio 平台需要登陆百度账号使用

点击右上角红色方框内的 fork 按钮

创建项目副本

运行项目

1.2 获取免费算力卡

百度 AI Studio 平台每日免费算力卡需要运行任意项目后发放,fork 项目后点击启动环境即可获取8点免费算力卡

1.3 启动GPU环境

选择 V100 32GB 运行环境并点击确定

进入环境

1.4 部署ComfyUI

双击打开左侧文件浏览器中的 1 部署ComfyUI.ipynb

点击红框内按钮解压 ComfyUI 部署包

解压完成(解压仅需1-2分钟):

1.5 启动ComfyUI-API

双击打开左侧文件浏览器中的 2 启动ComfyUI-API.ipynb 并点击红框内按钮启动 ComfyUI-API

ComfyUI-API 已启动:

1.6 启动Gradio界面

双击打开左侧文件浏览器中的 3 启动Gradio界面.gradio.py,等待 Gradio 界面加载完成后点击红框内按钮在新的浏览器页面打开 Gradio 界面

浏览器新页面中的 Gradio 界面:

点击右上角红框内按钮开始文生图,首次运行因为要加载约14GB的 Stable Cascade 模型到显存,第一张图片大约需要2分钟才能生成,后续生成一张 1024*1024 图片大约需要30秒。

文生图示例:

1.7 停止项目

由于 AI Studio 平台每天的免费算力卡只有8点,运行 V100 32GB 环境每小时消耗3点算力卡,不生图时应尽快关闭项目

依次关闭四个选项卡

无需保存修改

项目首页点击停止按钮

2. Gradio界面设计及API调用源码

源码随项目不断更新,最新版本见项目内部(Stable Diffusion | Gradio界面设计及API调用)。

import gradio as gr
import io
import json
import os
import random
import requests
import urllib.parse
import uuid
from PIL import Image
import sys
sys.path.append("/home/aistudio/work/ComfyUI/venv/lib/python3.10/site-packages")
import websocket

# 定义ComfyUI服务器地址
server_address = "127.0.0.1:8188"
# 定义SD模型所在文件夹路径,默认sd_models_path为该py文件所在路径+"/data"
sd_models_path = os.getcwd() + "/data"

# 定义默认正向提示词
default_prompt = "evening sunset scenery blue sky nature, glass bottle with a fizzy ice cold freezing rainbow liquid in it"
# 定义默认负向提示词
default_negative_prompt = "text, watermark"
# 定义可选择采样器和采样计划表类型(数组格式)
samplers = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
schedulers = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]

# 定义 获取指定SD模型(sd_model)在sd_models_path中的路径 的函数
def get_model_path(sd_model):
    # 当指定SD模型在sd_models_path根目录中时,模型路径与模型名称相同
    if os.path.exists(os.path.join(sd_models_path, sd_model)):
        sd_model_path = sd_model
    # 当指定SD模型在sd_models_path子目录中时,模型路径为"子目录名称/模型名称"
    else:
        for folder in os.listdir(sd_models_path):
            temp_sd_models_path = os.path.join(sd_models_path, folder)
            if os.path.exists(os.path.join(temp_sd_models_path, sd_model)):
                sd_model_path = os.path.join(folder, sd_model)
    return sd_model_path

# 定义客户端ID,用于和服务器建立websocket连接
client_id = str(uuid.uuid4())

# 定义 向服务器提交工作流并获取生成的图片 的函数
def generate_images(workflow):
    # 与服务器建立websocket连接
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    data = {"prompt":workflow, "client_id":client_id}
    prompt_id = requests.post(url="http://{}/prompt".format(server_address), json=data).json()["prompt_id"]
    while True:
        wsrecv = ws.recv()
        if isinstance(wsrecv, str):
            message = json.loads(wsrecv)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    break
        else:
            continue
    history = requests.get(url="http://{}/history/{}".format(server_address, prompt_id)).json()[prompt_id]
    for output in history["outputs"]:
        for node_id in history["outputs"]:
            node_output = history["outputs"][node_id]
            if "images" in node_output:
                images = []
                for image in node_output["images"]:
                    data = {"filename":image["filename"], "subfolder":image["subfolder"], "type":image["type"]}
                    url_values = urllib.parse.urlencode(data)
                    image_data = requests.get("http://{}/view?{}".format(server_address, url_values)).content
                    image = Image.open(io.BytesIO(image_data))
                    images.append(image)
    return images

# 定义 Stable Cascade文生图 函数
def cascade_txt2img(positive_prompt, negative_prompt, width, height, compression, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b):
    if seed_c == "-1":
        seed_c = random.randint(0, 9223372036854775807)
    if seed_b == "-1":
        seed_b = random.randint(0, 9223372036854775807)
    # 定义Stable Cascade txt2img工作流
    cascade_txt2img_workflow = {
    "1":{"inputs":{"ckpt_name":get_model_path("stable_cascade_stage_c.safetensors")}, "class_type":"CheckpointLoaderSimple"}, 
    "2":{"inputs":{"text":positive_prompt, "clip":["1", 1]}, "class_type":"CLIPTextEncode"}, 
    "3":{"inputs":{"text":negative_prompt, "clip":["1", 1]}, "class_type":"CLIPTextEncode"}, 
    "4":{"inputs":{"width":width, "height":height, "compression":compression, "batch_size":batch_size}, "class_type":"StableCascade_EmptyLatentImage"}, 
    "5":{"inputs":{"seed":seed_c, "steps":steps_c, "cfg":cfg_c, "sampler_name":sampler_name_c, "scheduler":scheduler_c, "denoise":denoise_c, "model":["1", 0], "positive":["2", 0], "negative":["3", 0], "latent_image":["4", 0]}, "class_type":"KSampler"}, 
    "6":{"inputs":{"conditioning":["2", 0], "stage_c":["5", 0]}, "class_type":"StableCascade_StageB_Conditioning"}, 
    "7":{"inputs":{"ckpt_name":get_model_path("stable_cascade_stage_b.safetensors")}, "class_type":"CheckpointLoaderSimple"}, 
    "8":{"inputs":{"seed":seed_b, "steps":steps_b, "cfg":cfg_b, "sampler_name":sampler_name_b, "scheduler":scheduler_b, "denoise":denoise_b, "model":["7", 0], "positive":["6", 0], "negative":["3", 0], "latent_image":["4", 1]}, "class_type":"KSampler"}, 
    "9":{"inputs":{"samples":["8", 0], "vae":["7", 2]}, "class_type":"VAEDecode"}, 
    "10":{"inputs":{"filename_prefix":"Cascade", "images":["9", 0]}, "class_type":"SaveImage"}
    }
    images = generate_images(cascade_txt2img_workflow)
    return images

# Gradio界面设计
with gr.Blocks() as demo:
    # 以下模块按行排列
    with gr.Row():
        # 以下模块按列排列
        with gr.Column():
            # gr.Textbox()为可输入文本框,label为该模块的标签,value为默认值
            positive_prompt = gr.Textbox(label="Positive prompt | 正向提示词", value=default_prompt)
            negative_prompt = gr.Textbox(label="Negative prompt | 负向提示词", value=default_negative_prompt)
            # gr.Tab()为选项卡模块,label为该模块的标签
            with gr.Tab(label="Stage C 采样阶段设置"):
                # 以下模块合并为组
                with gr.Group():
                    with gr.Row():
                        # gr.Dropdown()为可下拉选择框,第一个参数必须为包含下拉选项的数组["..."],label为该模块的标签,value为默认值
                        sampler_name_c = gr.Dropdown(samplers, label="Sampling method | 采样方法", value=samplers[12])
                        scheduler_c = gr.Dropdown(schedulers, label="Schedule type | 采样计划表类型", value=schedulers[1])
                    with gr.Row():
                        # gr.Slider()为滑块模块,minimum为最小数值,maximum为最大数值,step为最小滑动步长,label为该模块的标签,value为默认值
                        width = gr.Slider(minimum=512, maximum=2048, step=128, label="Width | 图像宽度", value=1024)
                        steps_c = gr.Slider(minimum=10, maximum=30, step=1, label="Sampling steps | 采样次数", value=20)
                    with gr.Row():
                        height = gr.Slider(minimum=512, maximum=2048, step=128, label="Height | 图像高度", value=1024)
                        batch_size = gr.Slider(minimum=1, maximum=8, step=1, label="Batch size | 单批次生成图像数", value=1)
                    with gr.Row():
                        denoise_c = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
                        compression = gr.Slider(minimum=8, maximum=42, step=1, label="Compression | 压缩倍率", value=42)
                    with gr.Row():
                        cfg_c = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=4.0)
                        seed_c = gr.Textbox(label="Seed | 种子数(-1表示随机种子数)", value=-1)
            with gr.Tab(label="Stage B 采样阶段设置", open=False):
                    with gr.Row():
                        sampler_name_b = gr.Dropdown(samplers, label="Sampling method | 采样方法", value=samplers[12])
                        scheduler_b = gr.Dropdown(schedulers, label="Schedule type | 采样计划表类型", value=schedulers[1])
                    with gr.Row():
                        denoise_b = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
                        steps_b = gr.Slider(minimum=4, maximum=12, step=1, label="Sampling steps | 采样次数", value=10)
                    with gr.Row():
                        cfg_b = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=1.1)
                        seed_b = gr.Textbox(label="Seed | 种子数(-1表示随机种子数)", value=-1)
        with gr.Column():
            # gr.Button()为按键模块,仅显示一个按键,需搭配.Click()定义该按键功能
            btn = gr.Button("Generate | 生成")
            # gr.Gallery()为画廊模块,可以显示一张或多张生成图片,设置preview=True可开启预览模式,height参数为画廊模块的高度,单位为像素
            gallery = gr.Gallery(preview=True, height=640)
    # .Click()模块,可用于定义按键功能,fn为按下该按键后调用的函数,inputs为该函数的输入值,outputs为Gradio输出内容,outputs内模块会依次读取函数返回的值,注意顺序和数值类型!!!
    btn.click(fn=cascade_txt2img, inputs=[positive_prompt, negative_prompt, width, height, compression, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b], outputs=[gallery])
# .queue()可指定队列相关参数,此处status_update_rate=30为每30秒给客户端发送队列完成状态,用于防止Gradio超时60秒后自动报错并退出,此处inbrowser=True可在Gradio启动后自动打开网页,受AI Studio平台限制,该参数无法打开网页~
demo.queue(status_update_rate=30).launch(inbrowser=True)

更新时间 2024-06-22