【Diffusion实战】基于Stable Diffusion实现文本到图像的生成(Pytorch代码详解)

  来试试强大的Stable Diffusion吧,基于Stable Diffusion的pipeline,进一步了解Stable Diffusion的结构~


1、Stable Diffusion初探:从文本生成图像

  首先,得看看Stable Diffusion用起来是个什么效果。

import torch
import requests
from PIL import Image
from io import BytesIO
from matplotlib import pyplot as plt
from diffusers import StableDiffusionPipeline

# Pick the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pipeline from model_id (a hard-coded path on the author's machine;
# adjust it to wherever the checkpoint lives) and move it to the device.
model_id = "E:/Code/kuosan/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to(device)

# Seed a dedicated random generator so that results are repeatable.
generator = torch.Generator(device=device)
generator.manual_seed(42)

# Run the pipeline for a first text-to-image generation.
pipe_output = pipe(
    prompt="Palette knife painting of an autumn cityscape",  # what should be generated
    negative_prompt="Oversaturated, blurry, low quality",  # what should NOT be generated
    height=1024, width=1024,  # size of the generated image
    guidance_scale=10,  # how strongly the prompt influences the result
    num_inference_steps=50,  # number of inference steps for one generation
    generator=generator,  # seeded random generator
)

# View the generated result.
plt.imshow(pipe_output.images[0])
plt.axis("off")
plt.show()



# Generate a second image from a different prompt at a non-square size.
pipe_output = pipe(
    prompt="A realistic photo of a cute giant panda eating fresh green bamboo",  # what should be generated
    negative_prompt="Oversaturated, blurry, low quality",  # what should NOT be generated
    height=480, width=640,  # size of the generated image
    guidance_scale=10,  # how strongly the prompt influences the result
    num_inference_steps=50,  # number of inference steps for one generation
    generator=generator,  # seeded random generator
)

# View the generated result.
plt.imshow(pipe_output.images[0])
plt.axis("off")
plt.show()



# Compare several guidance (CFG) scales side by side on the same prompt.
cfg_scales = [2, 5, 8, 11, 14]
prompt = "A cute kitten sleeping in a pile of flower petals"
fig, axs = plt.subplots(1, len(cfg_scales), figsize=(16, 5))
for ax, cfg_scale in zip(axs, cfg_scales):
    im = pipe(
        prompt, height=480, width=480,
        guidance_scale=cfg_scale, num_inference_steps=35,
        # Re-seed for every panel so that only the guidance scale varies.
        generator=torch.Generator(device=device).manual_seed(42),
    ).images[0]
    ax.imshow(im)
    ax.set_title(f'CFG Scale {cfg_scale}')
plt.show()



2、Stable Diffusion深入:结构解析

  Stable Diffusion的pipeline中有哪些结构呢?可以通过 print(list(pipe.components.keys())) 打印查看一下:

['vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler', 'safety_checker', 'feature_extractor', 'image_encoder']

Latent Diffusion Models结构图:

2.1 变分自编码器(VAE)

  VAE的作用是对输入图像进行压缩,其编码器完成从像素空间(Pixel space)到隐空间(Latent space)的编码,扩散过程在隐空间的图像特征中完成,VAE解码器实现从隐空间再到像素空间的转换。

# Create a test image batch with values in the range (-1, 1).
images = torch.rand(1, 3, 512, 512).to(device) * 2 - 1
print("Input images shape:", images.shape)

# Latents are multiplied by the VAE scaling factor after encoding and divided
# by it again before decoding. Read the factor from the loaded VAE's config
# rather than repeating the literal 0.18215 in two places.
vae_scale = pipe.vae.config.scaling_factor

# Encode into the latent space.
with torch.no_grad():
    latents = vae_scale * pipe.vae.encode(images).latent_dist.mean
print("Encoded latents shape:", latents.shape)

# Decode from the latent space back to pixels.
with torch.no_grad():
    decoded_images = pipe.vae.decode(latents / vae_scale).sample
print("Decoded images shape:", decoded_images.shape)


Input images shape: torch.Size([1, 3, 512, 512])
Encoded latents shape: torch.Size([1, 4, 64, 64])
Decoded images shape: torch.Size([1, 3, 512, 512])


2.2 分词器与文本编码器



# Tokenize the input text.
input_ids = pipe.tokenizer(["A painting of a flooble"])['input_ids']
print("Input ID -> decoded token")
for input_id in input_ids[0]:
    print(f"{input_id} -> {pipe.tokenizer.decode(input_id)}")

# Feed the token IDs into the CLIP text encoder.
input_ids = torch.tensor(input_ids).to(device)
with torch.no_grad():
    text_embeddings = pipe.text_encoder(input_ids)['last_hidden_state']
print("Text embeddings shape:", text_embeddings.shape)

# The pipeline's encode_prompt helper goes from the raw prompt to embeddings
# in one call. It returns a tuple (prompt_embeds, negative_prompt_embeds), so
# index 0 is the prompt embedding whose shape is printed below.
text_embeddings = pipe.encode_prompt(
    "A painting of a flooble",
    device=device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=False,
    negative_prompt='',
)
print(text_embeddings[0].shape)


Input ID -> decoded token
49406 -> <|startoftext|>
320 -> a
3086 -> painting
539 -> of
320 -> a
4062 -> floo
1059 -> ble
49407 -> <|endoftext|>
Text embeddings shape: torch.Size([1, 8, 1024])
torch.Size([1, 77, 1024])


2.3 UNet网络

  UNet网络的作用是接收带噪输入并预测噪声,实现去噪。UNet的输入有大小为 $[1, 77, 1024]$ 的文本嵌入、大小为 $[4, 64, 64]$ 的图像隐特征,以及时间步。

# Build stand-in inputs: random latents, random text embeddings, and the
# first timestep currently held by the scheduler.
latents = torch.randn((1, 4, 64, 64)).to(device)
text_embeddings = torch.randn((1, 77, 1024)).to(device)
timestep = pipe.scheduler.timesteps[0]

# Run the UNet prediction without tracking gradients.
with torch.no_grad():
    unet_output = pipe.unet(
        sample=latents,
        timestep=timestep,
        encoder_hidden_states=text_embeddings,
    ).sample
print('UNet output shape:', unet_output.shape)


UNet output shape: torch.Size([1, 4, 64, 64])

2.4 调度器

  前向加噪过程: $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \varepsilon$

# Plot the cumulative alpha product against the timestep index.
plt.plot(pipe.scheduler.alphas_cumprod, label=r'$\bar{\alpha}$')
# In x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, a smaller
# alpha_bar_t means more noise. alphas_cumprod shrinks as the index grows, so
# the x-axis runs from low noise to high noise.
plt.xlabel('Timestep (low noise to high noise ->)')
plt.title('Noise schedule')
# Without legend() the label passed to plot() is never rendered.
plt.legend()
plt.show()


from diffusers import LMSDiscreteScheduler

# Swap in a different scheduler, built from the current scheduler's config.
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

# Print the new scheduler's configuration.
print('Scheduler config:', pipe.scheduler)

# Generate an image with the new scheduler and display it.
scheduler_output = pipe(
    prompt="Beautiful pastoral scenery, beautiful mountains and waters",
    height=480, width=480, num_inference_steps=50,
)
plt.imshow(scheduler_output.images[0])
plt.axis("off")
plt.show()


Scheduler config: LMSDiscreteScheduler {
  "_class_name": "LMSDiscreteScheduler",
  "_diffusers_version": "0.26.3",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "timestep_spacing": "linspace",
  "trained_betas": null,
  "use_karras_sigmas": false
}


2.5 自定义循环采样

  探索了Stable Diffusion的各个组件,就可以自定义循环采样过程,将其组装起来实现文生图:

# Hand-written sampling loop that wires the pipeline components together.
guidance_scale = 8
num_inference_steps = 60
prompt = "A cute little monkey is standing on a tree"
negative_prompt = "zoomed in, blurry, oversaturated, warped"

# Encode the text. encode_prompt returns (prompt_embeds, negative_prompt_embeds);
# concatenate them as [negative, positive] so the order matches the
# (unconditional, text) split of the noise prediction in the loop below.
# This replaces the deprecated _encode_prompt, which returned that same
# concatenated tensor.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt, device, 1, True, negative_prompt
)
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])

# Prepare the scheduler first: init_noise_sigma is derived from the sigmas
# that set_timesteps computes, so it must be read after this call.
pipe.scheduler.set_timesteps(num_inference_steps, device=device)

# Create random noise as the starting point, scaled as the scheduler expects.
latents = torch.randn((1, 4, 64, 64), device=device, generator=generator)
latents *= pipe.scheduler.init_noise_sigma

# Sampling loop; no gradients are needed anywhere in it.
with torch.no_grad():
    for t in pipe.scheduler.timesteps:
        # Duplicate the latents for classifier-free guidance
        # (one copy per half of text_embeddings).
        latent_model_input = torch.cat([latents] * 2)

        # Let the scheduler scale the model input for this timestep.
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise.
        noise_pred = pipe.unet(
            latent_model_input, t, encoder_hidden_states=text_embeddings
        ).sample

        # Apply guidance: move from the unconditional prediction toward the
        # text-conditioned one.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Denoising step: compute the previous sample x_t -> x_t-1.
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Convert the latents back to pixel space. This replaces the deprecated
    # decode_latents: undo the VAE scaling and decode.
    decoded = pipe.vae.decode(
        latents / pipe.vae.config.scaling_factor, return_dict=False
    )[0]

# Post-process to a numpy image batch (the type decode_latents used to return).
image = pipe.image_processor.postprocess(decoded, output_type="np")

# Visualize.
final_image = pipe.numpy_to_pil(image)[0]
plt.imshow(final_image)
plt.axis("off")
plt.show()



