
AIGC Notes -- Stable Diffusion Source Code Walkthrough: UNetModel

1--Preface

        Using the project open-sourced with the paper "High-Resolution Image Synthesis with Latent Diffusion Models" as an example, this post dissects the classic components of Stable Diffusion, to consolidate what was learned and deepen understanding.

2--UNetModel

A small, debuggable demo: SD_UNet

        Taking text-to-image generation as the example, this section dissects the core building blocks of UNetModel.

2-1--Forward overview

In the provided text-to-image demo, only three arguments are actually passed in: x, timesteps, and context, where:

        x is the randomly initialized noise tensor (shape: [B*2, 4, 64, 64]; the *2 comes from Classifier-Free Diffusion Guidance, which runs a conditional and an unconditional copy of each sample).

        timesteps holds the timestep passed in at each denoising step (shape: [B*2]).

        context is the CLIP-encoded text prompt (shape: [B*2, 77, 768]).

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # Create sinusoidal timestep embeddings.
        emb = self.time_embed(t_emb) # MLP

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
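
        A minimal smoke test for this forward pass (a sketch: the constructor arguments follow the SD v1 config from v1-inference.yaml, and the import path is hypothetical; the demo repo may organize the module differently):

import torch

from openaimodel import UNetModel  # hypothetical import path for the demo repo

model = UNetModel(
    image_size=32,                  # required by the signature, unused for txt2img
    in_channels=4, out_channels=4, model_channels=320,
    num_res_blocks=2, attention_resolutions=[4, 2, 1],
    channel_mult=[1, 2, 4, 4], num_heads=8,
    use_spatial_transformer=True, transformer_depth=1,
    context_dim=768, legacy=False,
)

B = 3                                    # latent batch size
x = torch.randn(B * 2, 4, 64, 64)        # noisy latents (x2 for CFG)
t = torch.randint(0, 1000, (B * 2,))     # one timestep per batch element
context = torch.randn(B * 2, 77, 768)    # CLIP text embeddings

eps = model(x, timesteps=t, context=context)
print(eps.shape)                         # torch.Size([6, 4, 64, 64])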

2-2--Generating the timestep embedding

        The function timestep_embedding() and the module self.time_embed() encode the incoming timestep, producing sinusoidal timestep embeddings.

        timestep_embedding() is defined below; self.time_embed() is a small MLP.

def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding

# init
self.time_embed = nn.Sequential(
    linear(model_channels, time_embed_dim),
    nn.SiLU(),
    linear(time_embed_dim, time_embed_dim),
)
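
        A quick shape check for the two-stage embedding (a sketch; model_channels=320 and time_embed_dim = 4 * model_channels = 1280 follow the SD v1 config):

import torch

t = torch.tensor([1, 500, 999])             # three timesteps
t_emb = timestep_embedding(t, dim=320)      # sinusoidal encoding -> [3, 320]
print(t_emb.shape)                          # torch.Size([3, 320])
# self.time_embed then projects 320 -> 1280; this 1280-dim vector is what every
# ResBlock's emb_layers consumes (the Linear(in_features=1280, ...) in the dumps).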

2-3--self.input_blocks downsampling

        In forward(), self.input_blocks progressively downsamples the input noise; the overall shape change is [B*2, 4, 64, 64] → [B*2, 1280, 8, 8].

        The downsampling path consists of 12 modules:

ModuleList(
  (0): TimestepEmbedSequential(
    (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1-2): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=320, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=320, out_features=320, bias=False)
            (to_v): Linear(in_features=320, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=320, out_features=2560, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=1280, out_features=320, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=768, out_features=320, bias=False)
            (to_v): Linear(in_features=768, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (3): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (4): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (5): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (6): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (7): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (8): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (9): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (10-11): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
  )
)
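
        The shape progression can be verified by stepping through the blocks by hand (a sketch; it reuses the model, x, t and context tensors from the smoke test in 2-1):

import torch

with torch.no_grad():
    emb = model.time_embed(timestep_embedding(t, model.model_channels))
    h = x
    for i, module in enumerate(model.input_blocks):
        h = module(h, emb, context)
        print(i, tuple(h.shape))
# Expected progression (B*2 = 6):
#   0-2: (6, 320, 64, 64)    3: (6, 320, 32, 32)
#   4-5: (6, 640, 32, 32)    6: (6, 640, 16, 16)
#   7-8: (6, 1280, 16, 16)   9: (6, 1280, 8, 8)
# 10-11: (6, 1280, 8, 8)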

        All 12 modules are wrapped in the TimestepEmbedSequential class, which routes the noise features x, the timestep embedding emb, and the prompt context to whichever layers can consume them.

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
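
        The dispatch above lets one container mix plain layers, timestep-conditioned ResBlocks, and context-conditioned SpatialTransformers. A toy illustration (a sketch; it assumes the ResBlock and SpatialTransformer classes shown later in this post are in scope):

import torch
import torch.nn as nn

block = TimestepEmbedSequential(
    ResBlock(320, 1280, dropout=0.0),                                # consumes (x, emb)
    SpatialTransformer(320, n_heads=8, d_head=40, context_dim=768),  # consumes (x, context)
    nn.Conv2d(320, 320, 3, padding=1),                               # consumes x only
)
x = torch.randn(2, 320, 8, 8)
emb = torch.randn(2, 1280)
context = torch.randn(2, 77, 768)
out = block(x, emb, context)   # each layer receives only the inputs it supports
print(out.shape)               # torch.Size([2, 320, 8, 8])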

2-3-1--Module 0

        Module 0 is a single 2D convolution that extracts features from the input noise, mapping [B*2, 4, 64, 64] → [B*2, 320, 64, 64].

# init
self.input_blocks = nn.ModuleList(
    [
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1)
        )
    ]
)

# print(self.input_blocks[0])
TimestepEmbedSequential(
  (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

2-3-2--Module 1 and Module 2

        Module 1 and Module 2 share the same structure: one ResBlock followed by one SpatialTransformer.

# init
for _ in range(num_res_blocks):
    layers = [
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            out_channels=mult * model_channels,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
    ]
    ch = mult * model_channels
    if ds in attention_resolutions:
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        layers.append(
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            )
        )
    self.input_blocks.append(TimestepEmbedSequential(*layers))
    self._feature_size += ch
    input_block_chans.append(ch)

# print(self.input_blocks[1])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=320, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
    (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=320, out_features=320, bias=False)
          (to_v): Linear(in_features=320, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=320, out_features=2560, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=1280, out_features=320, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=768, out_features=320, bias=False)
          (to_v): Linear(in_features=768, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  )
)

# print(self.input_blocks[2])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=320, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
    (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=320, out_features=320, bias=False)
          (to_v): Linear(in_features=320, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=320, out_features=2560, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=1280, out_features=320, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=768, out_features=320, bias=False)
          (to_v): Linear(in_features=768, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  )
)

2-3-3--Module 3

        Module 3 is a downsampling 2D convolution (stride 2), which halves the spatial resolution.

# init
if level != len(channel_mult) - 1:
    out_ch = ch
    self.input_blocks.append(
        TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                down=True,
            )
            if resblock_updown
            else Downsample(
                ch, conv_resample, dims=dims, out_channels=out_ch
            )
        )
    )

# print(self.input_blocks[3])
TimestepEmbedSequential(
  (0): Downsample(
    (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
)
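
        The stride-2 convolution is what halves the spatial resolution (a sketch):

import torch
import torch.nn as nn

down = nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)
print(down(torch.randn(1, 320, 64, 64)).shape)  # torch.Size([1, 320, 32, 32])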

2-3-4--Module 4, Module 5, Module 7 and Module 8

        These have the same structure as Module 1 and Module 2 (one ResBlock followed by one SpatialTransformer) and differ only in feature dimensions.

2-3-5--Module 6 and Module 9

        Same structure as Module 3: a downsampling 2D convolution.

2-3-6--Module 10 and Module 11

        Module 10 and Module 11 share the same structure, consisting of a single ResBlock.

# print(self.input_blocks[10])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)

# print(self.input_blocks[11])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)

2-3-7--ResBlock

        ResBlock takes the noise feature map x and the timestep embedding as input, and fuses the timestep embedding into the noise features through convolutions and a residual connection. Core code:

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x) # [6, 320, 64, 64] -> [6, 320, 64, 64]
        emb_out = self.emb_layers(emb).type(h.dtype) # [6, 1280] -> [6, 320]
        while len(emb_out.shape) < len(h.shape): # [6, 320] -> [6, 320, 1, 1]
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out # [6, 320, 64, 64] + [6, 320, 1, 1] -> [6, 320, 64, 64]
            h = self.out_layers(h) # [6, 320, 64, 64]
        return self.skip_connection(x) + h
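
        The key step is broadcasting the projected timestep embedding over the spatial dimensions before the residual add. A standalone sketch of that fusion, mirroring the default (non scale-shift) branch of _forward above:

import torch

h = torch.randn(6, 320, 64, 64)      # feature map after in_layers
emb_out = torch.randn(6, 320)        # emb_layers output: one vector per sample
while emb_out.dim() < h.dim():       # [6, 320] -> [6, 320, 1, 1]
    emb_out = emb_out[..., None]
h = h + emb_out                      # broadcast add over all 64x64 positions
print(h.shape)                       # torch.Size([6, 320, 64, 64])
# With use_scale_shift_norm=True the embedding is instead split into (scale, shift)
# and applied as norm(h) * (1 + scale) + shift.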

2-3-8--SpatialTransformer

        SpatialTransformer takes the noise features x and the text features context as input, and fuses the text features into x through cross-attention, realizing condition-driven text-to-image generation. Core code:

from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from util import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return{el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads # dim_head: 40, heads: 8
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads # 8

        q = self.to_q(x) # [6, 4096, 320] -> [6, 4096, 320]
        context = default(context, x) # return context [6, 77, 768]
        k = self.to_k(context) # [6, 77, 768] -> [6, 77, 320]
        v = self.to_v(context) # [6, 77, 768] -> [6, 77, 320]

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) # [6, 4096, 320] -> [48, 4096, 40] # [6, 77, 320] -> [48, 77, 40]

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # [48, 4096, 40] * [48, 77, 40] -> [48, 4096, 77]

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1) # softmax

        out = einsum('b i j, b j d -> b i d', attn, v) # [48, 4096, 77] * [48, 77, 40] -> [48, 4096, 40] 
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # [48, 4096, 40] -> [6, 4096, 320]
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x # self Attention, [6, 4096, 320] -> [6, 4096, 320]
        x = self.attn2(self.norm2(x), context=context) + x # cross Attention, [6, 4096, 320] -> [6, 4096, 320]
        x = self.ff(self.norm3(x)) + x # FFN, [6, 4096, 320] -> [6, 4096, 320]
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape # [6, 320, 64, 64]
        x_in = x
        x = self.norm(x) # [6, 320, 64, 64]
        x = self.proj_in(x) # [6, 320, 64, 64]
        x = rearrange(x, 'b c h w -> b (h w) c') # [6, 4096, 320]
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in
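
        A quick usage sketch for the block above (toy inputs with SD-like sizes):

import torch

st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
x = torch.randn(6, 320, 64, 64)     # noise features
context = torch.randn(6, 77, 768)   # CLIP text embeddings
out = st(x, context)                # attn1: self-attention over 64*64=4096 tokens;
                                    # attn2: cross-attention against the 77 text tokens
print(out.shape)                    # torch.Size([6, 320, 64, 64])
# Note: proj_out is zero-initialized via zero_module, so a freshly constructed
# block returns out == x exactly; the residual path dominates early in training.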

2-4--self.middle_block

        self.middle_block consists of two ResBlocks with a SpatialTransformer in between; both its input and output have shape [B*2, 1280, 8, 8]:

TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
    (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=1280, out_features=1280, bias=False)
          (to_k): Linear(in_features=1280, out_features=1280, bias=False)
          (to_v): Linear(in_features=1280, out_features=1280, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1280, out_features=1280, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=1280, out_features=10240, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=5120, out_features=1280, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=1280, out_features=1280, bias=False)
          (to_k): Linear(in_features=768, out_features=1280, bias=False)
          (to_v): Linear(in_features=768, out_features=1280, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1280, out_features=1280, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  )
  (2): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)

2-5--self.output_blocks upsampling

        In forward(), self.output_blocks progressively upsamples the noise features; the shape changes from [B*2, 1280, 8, 8] back to [B*2, 320, 64, 64], after which self.out projects the result to the final [B*2, 4, 64, 64].

        The upsampling path also has 12 modules, structured like the downsampling path; a sketch after the dump below explains the enlarged input channel counts:

ModuleList(
  (0-1): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (2): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): Upsample(
      (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (3-4): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (5): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1920, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Upsample(
      (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (6): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1920, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (7): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (8): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 960, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Upsample(
      (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (9): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 960, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=320, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=320, out_features=320, bias=False)
            (to_v): Linear(in_features=320, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=320, out_features=2560, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=1280, out_features=320, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=768, out_features=320, bias=False)
            (to_v): Linear(in_features=768, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (10-11): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=320, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=320, out_features=320, bias=False)
            (to_v): Linear(in_features=320, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=320, out_features=2560, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=1280, out_features=320, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=768, out_features=320, bias=False)
            (to_v): Linear(in_features=768, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)
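
        The enlarged channel counts in the dump come from the skip connections: each decoder block first concatenates the current feature map with the matching activation popped from hs (the th.cat([h, hs.pop()], dim=1) in forward()). A sketch of the arithmetic for output block 0:

import torch

h = torch.randn(6, 1280, 8, 8)      # middle_block output
skip = torch.randn(6, 1280, 8, 8)   # hs.pop(): output of input_blocks[11]
h = torch.cat([h, skip], dim=1)     # -> 2560 channels, matching the
print(h.shape)                      #    GroupNorm32(32, 2560) / Conv2d(2560, 1280)
                                    #    at the head of output_blocks[0]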
