系列文章目录
【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch) 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。 【扩散模型(六)】Stable Diffusion 3 diffusers 源码详解1-推理代码-文本处理部分 【扩散模型(七)】Stable Diffusion 3 diffusers 源码详解2 - DiT 与 MMDiT 相关代码(上)介绍了 DiT ,本文则介绍 MMDiT 与 DiT 的区别以及核心代码实现。
文章目录
系列文章目录 MMDiT 四层代码结构 第一层 第二和第三层 第四层MMDiT
四层代码结构
上图中的 (a) 为第一层,(b) 为第二层和三层,而 (b) 中 Attention 的实现是在另外一个代码文件(第四层)中。 文本和图像的融合部分是在第四层 (JointAttnProcessor2_0) 中。 第四层的完整结构如下所示,重点放在了 Joint Attention 的具体实现上。第一层
图(a)对应的代码在 /path/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
, 是从 noise_pred = self.transformer
进入到整个 transformer ( MM-DiT 1 至 d )中
pipeline_stable_diffusion_3.py 中的 call 函数中的以下片段
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
第二和第三层
(b)对应的代码在进入 transformer( /path/lib/python3.12/site-packages/diffusers/models/transformers/transformer_sd3.py
)后,的 for 循环中,依次进入每个 MM-DiT block(JointTransformerBlock)
第二层: transformer_sd3.py 中的 forward 函数中以下片段进入 for 循环,如果不训练 backbone的话,那么就是从 else 分支进入 block 中。
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
第三层: block 的实现是在 /path/lib/python3.12/site-packages/diffusers/models/attention.py
中的 JointTransformerBlock 类,其中 hidden_states (noisy latent)和 encoder_hidden_states (text prompt) 分别通过 norm1 和 norm1_context 后,进入了第四层 self.attn
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
第四层
从 self.attn 的 init 中,我们可以找到实际代码在 JointAttnProcessor2_0() 类,即 /path/lib/python3.12/site-packages/diffusers/models/attention_processor.py
中
下方为 self.attn 的 init 初始化
if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim // num_attention_heads,
heads=num_attention_heads,
out_dim=attention_head_dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
)
下方画出的图片和对应代码即为文图融合的核心关键,在原论文中1对这部分结构的解释是 “等价于两个针对文/图模态的独立的 transformers,但在 attention 操作中两种模态联合(joining)在了一起”,贴出原文描述来更好理解
class JointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
Scaling Rectified Flow Transformers for High-Resolution Image Synthesis ↩︎
总结
### 文章内容总结本系列文章详细解析了Stable Diffusion及其在IP-Adapter和MMDiT变种中的应用。文章通过七部分,从理论介绍到源码实现逐步深入,为读者全面展示了图像生成技术中的关键技术和代码实现细节。
#### 1. 扩散模型基础
**【扩散模型(一)】** 介绍了Stable Diffusion模型的基本结构,包括重建分支和条件分支,为读者奠定了理论基础。
#### 2. IP-Adapter理论研究
**【扩散模型(二)】** 从条件分支的视角,介绍了IP-Adapter的相关可控生成研究,展示了如何通过调节条件分支来实现图像的精确控制。
#### 3. IP-Adapter源码详解-训练输入
**【扩散模型(三)】** 详细解读了IP-Adapter训练代码中的image prompt输入部分,即img projection模块的具体实现。
#### 4. IP-Adapter源码详解-核心训练
**【扩散模型(四)】** 聚焦于IP-Adapter的训练核心部分,详细介绍了在Unet中引入的针对Image prompt的cross-attention模块。
#### 5. IP-Adapter源码详解-推理过程
**【扩散模型(五)】** 解析了IP-Adapter的推理过程代码,展示了模型如何用于实际生成图像。
#### 6. Stable Diffusion 3 diffusers源码解析-文本处理
**【扩散模型(六)】** 分析了Stable Diffusion 3 diffusers的推理代码中的文本处理部分,揭示了文本转换为模型可接受输入的过程。
#### 7. MMDiT源码详解
**本文重点**:
- **四层代码结构**:文章详细阐述了MMDiT的四层代码结构,从顶层的pipeline_stable_diffusion_3.py到最底层用于文本图像联合注意力处理的JointAttnProcessor2_0类。
- **第一层**:从`noise_pred = self.transformer(...)`进入transformer的MM-DiT blocks。
- **第二和第三层**:通过transformer_sd3.py中的for循环和attention.py中的JointTransformerBlock类,处理图像和文本模态的归一化和注意力前的准备工作。
- **第四层**:进入attention_processor.py中的JointAttnProcessor2_0类,实现文本和图像的联合注意力计算的核心逻辑。代码展示了如何通过拼接查询向量、键向量和值向量,使用scaled_dot_product_attention进行联合注意力计算,并分割处理后的隐藏状态输出。
### 关键支持与结论
- **MMDiT的特殊设计**:MMDiT的实现等效于两个独立的transformers,但在attention操作中实现了文本和图像模态的联合,这一设计大大增强了模型在图像生成时对文本指令的理解和执行能力。
- **代码实现细节**:通过详细的源码解读,展示了MMDiT如何在Stable Diffusion 3的框架下实现文图融合的具体过程,包括数据的维度变换、联合注意力的计算以及最终结果的整理与输出。
- **研究意义**:本文为研究和应用Stable Diffusion变种技术提供了宝贵的参考,特别是MMDiT的实现为生成高解析度、高保真度的图像提供了强有力的技术支持。