当前位置:AIGC资讯 > AIGC > 正文

stable diffusion推理过程代码梳理

最近在看stable diffusion,想梳理一下代码流程,以便之后查阅

从txt2img.py开始看

1.首先是对文本进行编码

(1)调用的是 stable-diffusion/ldm/models/diffusion/ddpm.py的get_learned_conditioning函数

(2) 第555行表示使用CLIP的文本编码器对输入的文本进行编码,调用的是stable-diffusion/ldm/modules/encoders/modules.py中的FrozenCLIPEmbedder类

 2.进行采样操作

 (1)调用plms中的采样操作,在stable-diffusion/ldm/models/diffusion/plms.py中

生成时间步长self.ddim_timesteps= [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701 721 741 761 781 801 821 841 861 881 901 921 941 961 981] 

 (2)调用self.plms_sampling函数 

时间步的循环是从这里开始的

调用self.p_sample_plms函数

 调用stable-diffusion/ldm/models/diffusion/ddpm.py的apply_model函数

调用同文件下的DiffusionWrapper类,key="crossattn",c_crossattn=torch.cat([unconditional_conditioning, c])

 调用了stable-diffusion/ldm/modules/diffusionmodules/openaimodel.py里面的UnetModel类

self.input_blocks的定义为

TimestepEmbedSequential的定义为

其中,TimestepBlock类型的layer为ResBlock,TimestepEmbedSequential的结构图可以表示成下图。

ResBlock的代码如下

SpatialTransforme在stable-diffusion/ldm/modules/attention.py中定义如下

BasicTransformerBlock展示了图像和文本的融合过程

 CrossAttention的定义如下,图像作为Q,文本作为K和V

 UNetModel的模型结构可参考如下Stable Diffusion 原理介绍与源码分析(一) - 知乎 (zhihu.com)

2.关于图像解码部分

得到去噪后的图像特征后进行解码

调用的是ddpm中的decode_first_stage函数 ,调用AutoencoderKL中的解码器

AutoencoderKL的解码器输出的就是最后的图像

想要了解更多扩散模型的知识,推荐这个视频54、Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读_哔哩哔哩_bilibili

更新时间 2024-01-25