Table of Contents
Overall code · UNet analysis · self.input_blocks · self.middle_block · self.output_blocks

A beginner-friendly walkthrough of Stable Diffusion:
https://mp.weixin.qq.com/s?__biz=Mzk0MzIzODM5MA==&mid=2247486486&idx=1&sn=aff9ed60bba2cbf9efd32aa68557c93b&chksm=c337b18ff4403899d24ac32a60dbfd0402aab7309e8442dabdcb14cd61cfb55ad6cc1f977b3b#rd
Overall code
# 1. Encode the prompt into tokens. The encoder is FrozenCLIPEmbedder (a 1-layer CLIPTextEmbeddings plus a 12-layer self-attention encoder).
c = self.cond_stage_model.encode(c) # c is the input prompt, repeated twice; output: (2,77,768)
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
# self.tokenizer is the pretrained CLIPTokenizer from the transformers package
tokens = batch_encoding["input_ids"].to(self.device) # (2,77): each prompt becomes 77 token ids
outputs = self.transformer(input_ids=tokens).last_hidden_state # 12 self-attention layers; result: (2,77,768)
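To see this step in isolation, here is a minimal self-contained sketch that calls Hugging Face transformers directly (assuming the openai/clip-vit-large-patch14 checkpoint that SD v1 uses; the variable names are illustrative):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

prompts = ["a photograph of an astronaut riding a horse"] * 2  # batch of 2, as in the trace above
batch = tokenizer(prompts, truncation=True, max_length=77,
                  padding="max_length", return_tensors="pt")
with torch.no_grad():
    context = text_encoder(input_ids=batch["input_ids"]).last_hidden_state
print(context.shape)  # torch.Size([2, 77, 768])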
# 2. Sample latents with the PLMS sampler
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
# 01. Register the sampler schedule
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # S=50
# This pre-registers the DDIM hyperparameters, e.g. the cumulative products of the alphas
# Data shape for PLMS sampling is (2, 4, 32, 32)
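What make_schedule registers can be reconstructed in a few lines. A minimal sketch, assuming the scaled-linear beta schedule of SD v1 (linear_start=0.00085, linear_end=0.012, 1000 training steps); the array names follow the ldm convention but are illustrative here:

import numpy as np

T, S, eta = 1000, 50, 0.0
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, T) ** 2   # "scaled linear" schedule
alphas_cumprod = np.cumprod(1.0 - betas)                    # running product of the alphas

ddim_timesteps = np.arange(0, T, T // S) + 1                # 1, 21, ..., 981
ddim_alphas = alphas_cumprod[ddim_timesteps]
ddim_alphas_prev = np.concatenate([[alphas_cumprod[0]], ddim_alphas[:-1]])
# eta scales the stochastic term of each step; eta=0 makes sampling deterministic
ddim_sigmas = eta * np.sqrt((1 - ddim_alphas_prev) / (1 - ddim_alphas)
                            * (1 - ddim_alphas / ddim_alphas_prev))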
# 02. Run the PLMS sampling loop
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T )
img = torch.randn(shape, device=device) # (2,4,32,32)
for i, step in enumerate(iterator):
index = total_steps - i - 1 # index=50-i-1, step=981
ts = torch.full((b,), step, device=device, dtype=torch.long) # [981,981]
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
c_in = torch.cat([unconditional_conditioning, c]) # concatenate the empty-prompt (unconditional) embedding with the prompt embedding
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # timesteps: [981,981,981,981] -> (4,320)
emb = self.time_embed(t_emb) # two Linear layers: (4,320) -> (4,1280)
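timestep_embedding is the standard sinusoidal embedding. A minimal runnable sketch (a reconstruction of the usual implementation, not the verbatim source):

import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    # (B,) integer timesteps -> (B, dim) sinusoidal features
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) *
                      torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

print(timestep_embedding(torch.tensor([981, 981, 981, 981]), 320).shape)  # torch.Size([4, 320])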
# the UNet then consumes the timestep embedding and the prompt context; see the source for details
for module in self.input_blocks:
h = module(h, emb, context) # inputs: (4,4,32,32), (4,1280), (4,77,768)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1) # (4,1280,4,4) -> (4,2560,4,4)
h = module(h, emb, context)
return self.out(h) # conv from (4,320,32,32) to (4,4,32,32)
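The loop above works because TimestepEmbedSequential routes each extra input to the children that need it: ResBlocks receive emb, SpatialTransformers receive context, and plain convolutions receive neither. A standalone sketch of that dispatch (the two stub classes are placeholders for the real ldm modules so the snippet runs on its own):

import torch.nn as nn

class TimestepBlock(nn.Module):
    """Marker base class: modules whose forward takes (x, emb), e.g. ResBlock."""

class SpatialTransformer(nn.Module):
    """Stub for the cross-attention block, whose forward takes (x, context)."""

class TimestepEmbedSequential(nn.Sequential):
    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)        # ResBlock: inject the timestep embedding
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)    # cross-attention against the prompt
            else:
                x = layer(x)             # plain Conv2d / Downsample / Upsample
        return x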
# 3. Classifier-free guidance and the DDIM update
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) # split the stacked result into two (2,4,32,32) halves
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) # scale the cond/uncond gap by 7.5 and add it back to the unconditional prediction
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) # DDIM update: e_t (2,4,32,32), index 49 -> (2,4,32,32)
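get_x_prev_and_pred_x0 is the classic DDIM step. A minimal sketch, assuming the schedule arrays registered by make_schedule (in the source it is a closure over x; here x is passed explicitly):

import torch

def get_x_prev_and_pred_x0(x, e_t, index, alphas, alphas_prev, sigmas, sqrt_one_minus_alphas):
    b = x.shape[0]
    a_t = torch.full((b, 1, 1, 1), alphas[index])
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index])
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index])
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index])

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()   # current estimate of x_0
    dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t     # direction pointing back to x_t
    noise = sigma_t * torch.randn_like(x)                  # vanishes when eta=0
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0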
# 4. Decode the latents with the first-stage (VAE) decoder
x_samples_ddim = model.decode_first_stage(samples_ddim) # (2,4,32,32)
h = self.conv_in(z) # conv: 4 -> 512 channels
x = torch.nn.functional.interpolate(h, scale_factor=2.0, mode="nearest") # (2,512,64,64)
h = self.up[i_level].block[i_block](h) # several ResBlocks interleaved with upsampling
h = self.norm_out(h) # (2,128,256,256)
h = nonlinearity(h) # x*torch.sigmoid(x)
h = self.conv_out(h) # conv(128, 3) -> (2,3,256,256)
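One detail hidden inside decode_first_stage: the latent is rescaled before the decoder above runs. A one-line sketch (assuming the SD v1 scale_factor of 0.18215; variable names illustrative):

z = 1. / 0.18215 * samples_ddim               # undo the latent scaling applied at training time
x_samples_ddim = first_stage_model.decode(z)  # the conv/upsample stack shown above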
# 5. Post-processing
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
UNet analysis
The UNet invoked during DDIM sampling consists of three parts: the input blocks, the middle block, and the output blocks.
self.input_blocks
This contains 12 TimestepEmbedSequential modules of several kinds; three representative ones are shown below:
# 1、self.input_blocks
ModuleList(
(0): TimestepEmbedSequential(
(0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(1): TimestepEmbedSequential(
(0): ResBlock(
(in_layers): Sequential(
(0): GroupNorm32(32, 320, eps=1e-05, affine=True)
(1): SiLU()
(2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(h_upd): Identity()
(x_upd): Identity()
(emb_layers): Sequential(
(0): SiLU()
(1): Linear(in_features=1280, out_features=320, bias=True)
)
(out_layers): Sequential(
(0): GroupNorm32(32, 320, eps=1e-05, affine=True)
(1): SiLU()
(2): Dropout(p=0, inplace=False)
(3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(skip_connection): Identity()
)
(1): SpatialTransformer(
(norm): GroupNorm(32, 320, eps=1e-06, affine=True)
(proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(attn1): CrossAttention(
(to_q): Linear(in_features=320, out_features=320, bias=False)
(to_k): Linear(in_features=320, out_features=320, bias=False)
(to_v): Linear(in_features=320, out_features=320, bias=False)
(to_out): Sequential(
(0): Linear(in_features=320, out_features=320, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(ff): FeedForward(
(net): Sequential(
(0): GEGLU(
(proj): Linear(in_features=320, out_features=2560, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=1280, out_features=320, bias=True)
)
)
(attn2): CrossAttention(
(to_q): Linear(in_features=320, out_features=320, bias=False)
(to_k): Linear(in_features=768, out_features=320, bias=False)
(to_v): Linear(in_features=768, out_features=320, bias=False)
(to_out): Sequential(
(0): Linear(in_features=320, out_features=320, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
)
(proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
)
)
(6): TimestepEmbedSequential(
(0): Downsample(
(op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
Forward pass:
Each block adds the timestep embedding to h and applies cross-attention against the prompt; this happens repeatedly:
emb_out = self.emb_layers(emb) # SiLU + Linear: (4,1280) -> (4,320)
h = h + emb_out # (4,320,32,32) + (4,320,1,1)
x = self.attn1(self.norm1(x)) + x # self-attention: x (4,1024,320) projected to q, k, v, each 320-dim
x = self.attn2(self.norm2(x), context=context) + x # cross-attention: context (4,77,768) projected to 320-dim k and v
x = self.ff(self.norm3(x)) + x
Through the input blocks, the noisy latent h (4,4,32,32) changes shape as: (4,320,32,32) -> (4,320,16,16) -> (4,640,16,16) -> (4,1280,8,8) -> (4,1280,4,4).
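attn1 and attn2 are the same CrossAttention module used two ways: self-attention is simply cross-attention with context defaulting to x. A minimal single-head sketch matching the printed shapes (the real block uses 8 heads and a per-head scale, so this is a simplification):

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, query_dim=320, context_dim=768):
        super().__init__()
        self.to_q = nn.Linear(query_dim, query_dim, bias=False)
        self.to_k = nn.Linear(context_dim, query_dim, bias=False)
        self.to_v = nn.Linear(context_dim, query_dim, bias=False)
        self.to_out = nn.Linear(query_dim, query_dim)

    def forward(self, x, context=None):
        context = x if context is None else context   # no context -> self-attention
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.to_out(attn @ v)

attn1 = CrossAttention(320, 320)   # self-attention over the 32*32=1024 spatial tokens
attn2 = CrossAttention(320, 768)   # cross-attention against the prompt context
x, ctx = torch.randn(4, 1024, 320), torch.randn(4, 77, 768)
print(attn2(attn1(x), ctx).shape)  # torch.Size([4, 1024, 320])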
self.middle_block
TimestepEmbedSequential(
(0): ResBlock(
(in_layers): Sequential(
(0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
(1): SiLU()
(2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(h_upd): Identity()
(x_upd): Identity()
(emb_layers): Sequential(
(0): SiLU()
(1): Linear(in_features=1280, out_features=1280, bias=True)
)
(out_layers): Sequential(
(0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
(1): SiLU()
(2): Dropout(p=0, inplace=False)
(3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(skip_connection): Identity()
)
(1): SpatialTransformer(
(norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
(proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(attn1): CrossAttention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=1280, out_features=1280, bias=False)
(to_v): Linear(in_features=1280, out_features=1280, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(ff): FeedForward(
(net): Sequential(
(0): GEGLU(
(proj): Linear(in_features=1280, out_features=10240, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=5120, out_features=1280, bias=True)
)
)
(attn2): CrossAttention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=768, out_features=1280, bias=False)
(to_v): Linear(in_features=768, out_features=1280, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
)
(proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResBlock(
(in_layers): Sequential(
(0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
(1): SiLU()
(2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(h_upd): Identity()
(x_upd): Identity()
(emb_layers): Sequential(
(0): SiLU()
(1): Linear(in_features=1280, out_features=1280, bias=True)
)
(out_layers): Sequential(
(0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
(1): SiLU()
(2): Dropout(p=0, inplace=False)
(3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(skip_connection): Identity()
)
self.output_blocks
Like the input blocks, this contains 12 TimestepEmbedSequential modules, arranged in reverse order; each one first concatenates the matching skip connection popped from hs (see the forward loop above) before processing.
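The skip-connection arithmetic from the forward loop (h = th.cat([h, hs.pop()], dim=1)) can be checked with a toy example using shapes from the trace above:

import torch

h = torch.randn(4, 1280, 4, 4)     # output of the middle block
skip = torch.randn(4, 1280, 4, 4)  # hs.pop(): matching feature saved by the input blocks
h = torch.cat([h, skip], dim=1)
print(h.shape)                     # torch.Size([4, 2560, 4, 4]); the block's ResBlock then reduces channels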