Stability AI 2024年2月发布了 Stable Cascade 模型,但由于该模型较大(fp32格式的 Stage_A + Stage_B + Stage_C 模型超过20GB,ComfyUI 专用 Stage_B + Stage_C 模型也要14GB),对显卡要求较高,限制了大家体验 Stable Cascade 模型。
本文主要介绍如何在百度 AI Studio 平台上通过 Gradio 交互式界面运行 Stable Cascade 模型进行文生图。
项目链接:Stable Diffusion ComfyUI 在线体验 。
1. 项目体验
由于项目更新,以下示例图片与新版本有一定差异,但操作方法一样。
1.1 创建项目副本
打开项目链接 Stable Diffusion | Gradio界面设计及API调用
百度 AI Studio 平台需要登陆百度账号使用
点击右上角红色方框内的 fork 按钮
创建项目副本
运行项目
1.2 获取免费算力卡
百度 AI Studio 平台每日免费算力卡需要运行任意项目后发放,fork 项目后点击启动环境即可获取8点免费算力卡
1.3 启动GPU环境
选择 V100 32GB 运行环境并点击确定
进入环境
1.4 部署ComfyUI
双击打开左侧文件浏览器中的 1 部署ComfyUI.ipynb
点击红框内按钮解压 ComfyUI 部署包
解压完成(解压仅需1-2分钟):
1.5 启动ComfyUI-API
双击打开左侧文件浏览器中的 2 启动ComfyUI-API.ipynb 并点击红框内按钮启动 ComfyUI-API
ComfyUI-API 已启动:
1.6 启动Gradio界面
双击打开左侧文件浏览器中的 3 启动Gradio界面.gradio.py,等待 Gradio 界面加载完成后点击红框内按钮在新的浏览器页面打开 Gradio 界面
浏览器新页面中的 Gradio 界面:
点击右上角红框内按钮开始文生图,首次运行因为要加载约14GB的 Stable Cascade 模型到显存,第一张图片大约需要2分钟才能生成,后续生成一张 1024*1024 图片大约需要30秒。
文生图示例:
1.7 停止项目
由于 AI Studio 平台每天的免费算力卡只有8点,运行 V100 32GB 环境每小时消耗3点算力卡,不生图时应尽快关闭项目
依次关闭四个选项卡
无需保存修改
项目首页点击停止按钮
2. Gradio界面设计及API调用源码
源码随项目不断更新,最新版本见项目内部(Stable Diffusion | Gradio界面设计及API调用)。
import gradio as gr
import io
import json
import os
import random
import requests
import urllib.parse
import uuid
from PIL import Image
import sys
sys.path.append("/home/aistudio/work/ComfyUI/venv/lib/python3.10/site-packages")
import websocket
# 定义ComfyUI服务器地址
server_address = "127.0.0.1:8188"
# 定义SD模型所在文件夹路径,默认sd_models_path为该py文件所在路径+"/data"
sd_models_path = os.getcwd() + "/data"
# 定义默认正向提示词
default_prompt = "evening sunset scenery blue sky nature, glass bottle with a fizzy ice cold freezing rainbow liquid in it"
# 定义默认负向提示词
default_negative_prompt = "text, watermark"
# 定义可选择采样器和采样计划表类型(数组格式)
samplers = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
schedulers = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
# 定义 获取指定SD模型(sd_model)在sd_models_path中的路径 的函数
def get_model_path(sd_model):
# 当指定SD模型在sd_models_path根目录中时,模型路径与模型名称相同
if os.path.exists(os.path.join(sd_models_path, sd_model)):
sd_model_path = sd_model
# 当指定SD模型在sd_models_path子目录中时,模型路径为"子目录名称/模型名称"
else:
for folder in os.listdir(sd_models_path):
temp_sd_models_path = os.path.join(sd_models_path, folder)
if os.path.exists(os.path.join(temp_sd_models_path, sd_model)):
sd_model_path = os.path.join(folder, sd_model)
return sd_model_path
# 定义客户端ID,用于和服务器建立websocket连接
client_id = str(uuid.uuid4())
# 定义 向服务器提交工作流并获取生成的图片 的函数
def generate_images(workflow):
# 与服务器建立websocket连接
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
data = {"prompt":workflow, "client_id":client_id}
prompt_id = requests.post(url="http://{}/prompt".format(server_address), json=data).json()["prompt_id"]
while True:
wsrecv = ws.recv()
if isinstance(wsrecv, str):
message = json.loads(wsrecv)
if message["type"] == "executing":
data = message["data"]
if data["node"] is None and data["prompt_id"] == prompt_id:
break
else:
continue
history = requests.get(url="http://{}/history/{}".format(server_address, prompt_id)).json()[prompt_id]
for output in history["outputs"]:
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
if "images" in node_output:
images = []
for image in node_output["images"]:
data = {"filename":image["filename"], "subfolder":image["subfolder"], "type":image["type"]}
url_values = urllib.parse.urlencode(data)
image_data = requests.get("http://{}/view?{}".format(server_address, url_values)).content
image = Image.open(io.BytesIO(image_data))
images.append(image)
return images
# 定义 Stable Cascade文生图 函数
def cascade_txt2img(positive_prompt, negative_prompt, width, height, compression, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b):
if seed_c == "-1":
seed_c = random.randint(0, 9223372036854775807)
if seed_b == "-1":
seed_b = random.randint(0, 9223372036854775807)
# 定义Stable Cascade txt2img工作流
cascade_txt2img_workflow = {
"1":{"inputs":{"ckpt_name":get_model_path("stable_cascade_stage_c.safetensors")}, "class_type":"CheckpointLoaderSimple"},
"2":{"inputs":{"text":positive_prompt, "clip":["1", 1]}, "class_type":"CLIPTextEncode"},
"3":{"inputs":{"text":negative_prompt, "clip":["1", 1]}, "class_type":"CLIPTextEncode"},
"4":{"inputs":{"width":width, "height":height, "compression":compression, "batch_size":batch_size}, "class_type":"StableCascade_EmptyLatentImage"},
"5":{"inputs":{"seed":seed_c, "steps":steps_c, "cfg":cfg_c, "sampler_name":sampler_name_c, "scheduler":scheduler_c, "denoise":denoise_c, "model":["1", 0], "positive":["2", 0], "negative":["3", 0], "latent_image":["4", 0]}, "class_type":"KSampler"},
"6":{"inputs":{"conditioning":["2", 0], "stage_c":["5", 0]}, "class_type":"StableCascade_StageB_Conditioning"},
"7":{"inputs":{"ckpt_name":get_model_path("stable_cascade_stage_b.safetensors")}, "class_type":"CheckpointLoaderSimple"},
"8":{"inputs":{"seed":seed_b, "steps":steps_b, "cfg":cfg_b, "sampler_name":sampler_name_b, "scheduler":scheduler_b, "denoise":denoise_b, "model":["7", 0], "positive":["6", 0], "negative":["3", 0], "latent_image":["4", 1]}, "class_type":"KSampler"},
"9":{"inputs":{"samples":["8", 0], "vae":["7", 2]}, "class_type":"VAEDecode"},
"10":{"inputs":{"filename_prefix":"Cascade", "images":["9", 0]}, "class_type":"SaveImage"}
}
images = generate_images(cascade_txt2img_workflow)
return images
# Gradio界面设计
with gr.Blocks() as demo:
# 以下模块按行排列
with gr.Row():
# 以下模块按列排列
with gr.Column():
# gr.Textbox()为可输入文本框,label为该模块的标签,value为默认值
positive_prompt = gr.Textbox(label="Positive prompt | 正向提示词", value=default_prompt)
negative_prompt = gr.Textbox(label="Negative prompt | 负向提示词", value=default_negative_prompt)
# gr.Tab()为选项卡模块,label为该模块的标签
with gr.Tab(label="Stage C 采样阶段设置"):
# 以下模块合并为组
with gr.Group():
with gr.Row():
# gr.Dropdown()为可下拉选择框,第一个参数必须为包含下拉选项的数组["..."],label为该模块的标签,value为默认值
sampler_name_c = gr.Dropdown(samplers, label="Sampling method | 采样方法", value=samplers[12])
scheduler_c = gr.Dropdown(schedulers, label="Schedule type | 采样计划表类型", value=schedulers[1])
with gr.Row():
# gr.Slider()为滑块模块,minimum为最小数值,maximum为最大数值,step为最小滑动步长,label为该模块的标签,value为默认值
width = gr.Slider(minimum=512, maximum=2048, step=128, label="Width | 图像宽度", value=1024)
steps_c = gr.Slider(minimum=10, maximum=30, step=1, label="Sampling steps | 采样次数", value=20)
with gr.Row():
height = gr.Slider(minimum=512, maximum=2048, step=128, label="Height | 图像高度", value=1024)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label="Batch size | 单批次生成图像数", value=1)
with gr.Row():
denoise_c = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
compression = gr.Slider(minimum=8, maximum=42, step=1, label="Compression | 压缩倍率", value=42)
with gr.Row():
cfg_c = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=4.0)
seed_c = gr.Textbox(label="Seed | 种子数(-1表示随机种子数)", value=-1)
with gr.Tab(label="Stage B 采样阶段设置", open=False):
with gr.Row():
sampler_name_b = gr.Dropdown(samplers, label="Sampling method | 采样方法", value=samplers[12])
scheduler_b = gr.Dropdown(schedulers, label="Schedule type | 采样计划表类型", value=schedulers[1])
with gr.Row():
denoise_b = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
steps_b = gr.Slider(minimum=4, maximum=12, step=1, label="Sampling steps | 采样次数", value=10)
with gr.Row():
cfg_b = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=1.1)
seed_b = gr.Textbox(label="Seed | 种子数(-1表示随机种子数)", value=-1)
with gr.Column():
# gr.Button()为按键模块,仅显示一个按键,需搭配.Click()定义该按键功能
btn = gr.Button("Generate | 生成")
# gr.Gallery()为画廊模块,可以显示一张或多张生成图片,设置preview=True可开启预览模式,height参数为画廊模块的高度,单位为像素
gallery = gr.Gallery(preview=True, height=640)
# .Click()模块,可用于定义按键功能,fn为按下该按键后调用的函数,inputs为该函数的输入值,outputs为Gradio输出内容,outputs内模块会依次读取函数返回的值,注意顺序和数值类型!!!
btn.click(fn=cascade_txt2img, inputs=[positive_prompt, negative_prompt, width, height, compression, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b], outputs=[gallery])
# .queue()可指定队列相关参数,此处status_update_rate=30为每30秒给客户端发送队列完成状态,用于防止Gradio超时60秒后自动报错并退出,此处inbrowser=True可在Gradio启动后自动打开网页,受AI Studio平台限制,该参数无法打开网页~
demo.queue(status_update_rate=30).launch(inbrowser=True)