当前位置:AIGC资讯 > AIGC > 正文

腾讯HunyuanDit代码解析

注意:本文仅供自己记录学习过程使用。

训练

全参训练过程

输入图像用VAE编码得到输入的x_start(1,4,128,128);文本的两个特征:bert的encoder feature(1,77,1024)和T5 的feature(1,256,2048),和旋转位置编码freqs_cis_img: cos (4096,88),sin (4096,88)。 生成随机的时间步长t;生成随机的噪声(1,4,128,128),给输入的x_start加上噪声得到输出的x_t;
    def q_sample(self, x_start, t, noise=None):
        """
        Draw x_t from the forward diffusion process q(x_t | x_0).

        The clean batch is scaled by ``self.sqrt_alphas_cumprod`` and the noise by
        ``self.sqrt_one_minus_alphas_cumprod``. Both coefficient tables are
        indexed at ``t`` and broadcast to ``x_start.shape`` by
        ``_extract_into_tensor``; the two scaled terms are then summed.

        :param x_start: the initial (clean) data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: optional pre-drawn normal noise with the same shape as
            x_start; drawn with ``th.randn_like(x_start)`` when omitted.
        :return: A noisy version of x_start.
        """
        eps = th.randn_like(x_start) if noise is None else noise
        assert_shape(eps, x_start)

        signal_coef = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coef = _extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_coef * x_start + noise_coef * eps

对T5 的feature(1,256,2048)用mlp降维到(1,256,1024),然后把它和bert的feature在序列维上cat起来得到text_states (1,333,1024)(77+256=333); 对时间t编码(1,1408),把x_t切成patch(patchify)得到x(1,4096,1408); 对t5 feature进行pooling(基于multihead attention的注意力池化)得到extra_vec(1,1024); 时间t的编码+mlp(extra_vec)=c(1,1408),得到condition; 上述步骤已得到以下参数:x ,c,text_states,freqs_cis_img。开始迭代处理。
# Apply one HunYuanDiT block: image tokens x, condition vector c, text tokens and image RoPE in; updated x out.
x = block(x, c, text_states, freqs_cis_img)
对c做调制(default_modulation,一个mlp)得到shift,加到norm1(x)上作为self-attention的输入;用Wqkv把输入投影并拆分成q/k/v,对q/k做归一化后用旋转位置编码(RoPE)进行编码,得到新的q/k;注意力计算后经out_proj输出,再以残差形式加回x,输出x(1,4096,1408)。简单来说,就是在条件c(时间步+文本全局特征)的调制下,图像token之间做了一次自注意力来提取特征;
    def forward(self, x, freqs_cis_img=None):
        """
        Multi-head self-attention over the image tokens, with optional RoPE.

        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim.
        freqs_cis_img: tuple or torch.Tensor, optional
            RoPE for image, forwarded to ``apply_rotary_emb``: either a
            ``(cos, sin)`` tuple or a complex frequency tensor. Skipped when None.

        Returns
        -------
        tuple
            A 1-tuple ``(out,)``; ``out`` has shape (batch, seqlen, hidden_dim).
        """
        batch, seq_len, width = x.shape

        # Fused projection -> [b, s, 3, h, d], split into q / k / v, each [b, s, h, d].
        fused = self.Wqkv(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = fused.unbind(dim=2)

        # q and k are normalised and cast to fp16; v keeps its own dtype.
        q = self.q_norm(q).half()
        k = self.k_norm(k).half()

        if freqs_cis_img is not None:
            rot_q, rot_k = apply_rotary_emb(q, k, freqs_cis_img)
            assert rot_q.shape == q.shape and rot_k.shape == k.shape, f'qq: {rot_q.shape}, q: {q.shape}, kk: {rot_k.shape}, k: {k.shape}'
            q, k = rot_q, rot_k

        context = self.inner_attn(torch.stack([q, k, v], dim=2))  # input: [b, s, 3, h, d]
        projected = self.out_proj(context.view(batch, seq_len, width))
        return (self.proj_drop(projected),)
def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query (and optionally key) features with rotary position embeddings.

    Two encodings of ``freqs_cis`` are accepted:

    * a ``(cos, sin)`` tuple: the real form ``x * cos + rotate_half(x) * sin``
      is applied;
    * a single (complex) tensor: consecutive feature pairs of ``x`` are viewed
      as complex numbers, multiplied by ``freqs_cis`` and viewed back as real.

    In both cases ``reshape_for_broadcast`` adapts the frequencies to the input
    layout, the arithmetic runs in float32, and each result is cast back to the
    dtype of its input.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (Optional[torch.Tensor]): Key tensor to apply rotary embeddings, [B, S, H, D],
            or None to rotate only the query.
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed rotary frequencies (see above).
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple of the rotated query tensor and the rotated key tensor; the second
        element is None when ``xk`` is None.
    """
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)    # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)

        def rotate(t):
            t32 = t.float()
            return (t32 * cos + rotate_half(t32) * sin).type_as(t)
    else:
        def to_complex(t):
            # Pair up the last dimension as (real, imag): [..., D] -> complex [..., D//2].
            return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

        # [S, D//2] --> [1, S, 1, D//2] (for the default, non head-first layout)
        freqs = reshape_for_broadcast(freqs_cis, to_complex(xq), head_first).to(xq.device)

        def rotate(t):
            return torch.view_as_real(to_complex(t) * freqs).flatten(3).type_as(t)

    rotated_q = rotate(xq)
    rotated_k = rotate(xk) if xk is not None else None
    return rotated_q, rotated_k
上面得到的x,以及文本特征text_states和旋转位置编码freqs_cis_img,作为cross attention block的输入;y是文本特征text_states;x作为q,text_states作为k/v;只对q施加旋转位置编码(k来自文本,不加),然后和k/v作cross attention,得到输出x(1,4096,1408);
    def forward(self, x, y, freqs_cis_img=None):
        """
        Multi-head cross-attention: image tokens ``x`` attend to context tokens ``y``.

        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim), where hidden_dim = num_heads * head_dim.
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2) context (text) tokens.
        freqs_cis_img: tuple or torch.Tensor, optional
            RoPE for image, forwarded to ``apply_rotary_emb``; only the queries
            are rotated. Skipped when None.

        Returns
        -------
        tuple
            A 1-tuple ``(out,)``; ``out`` has shape (batch, seqlen1, hidden_dim).
        """
        batch, len_q, _ = x.shape
        _, len_kv, _ = y.shape

        queries = self.q_proj(x).view(batch, len_q, self.num_heads, self.head_dim)       # [b, s1, h, d]
        fused_kv = self.kv_proj(y).view(batch, len_kv, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        keys, values = fused_kv.unbind(dim=2)                                             # [b, s2, h, d]

        # q and k are normalised and cast to fp16; values keep their own dtype.
        queries = self.q_norm(queries).half()
        keys = self.k_norm(keys).half()

        if freqs_cis_img is not None:
            rot_q, _ = apply_rotary_emb(queries, None, freqs_cis_img)
            assert rot_q.shape == queries.shape, f'qq: {rot_q.shape}, q: {queries.shape}'
            queries = rot_q

        # kv: [b, s2, 2, h, d] -> context: [b, s1, h, d] -> [b, s1, D]
        context = self.inner_attn(queries, torch.stack([keys, values], dim=2))
        projected = self.out_proj(context.view(batch, len_q, -1))
        return (self.proj_drop(projected),)
最后经过FFN(mlp)输出x(1,4096,1408)。共有19个hunyuan block,每个block输出的都是(1,4096,1408);(类似于unet encoder的操作,后续就是解码了,但是它这里“编解码”并没有分辨率的概念)。 开始“解码操作”了,其实就是把当前的x和对应的前面block的输出(skip)在通道维上cat起来,经skip_norm和skip_linear映射回原来的维度,后续步骤(self-attention、cross-attention、FFN)和前面是一样的。
    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        """
        Run one HunYuanDiT block: optional long-skip fusion, self-attention,
        cross-attention, then feed-forward, each added residually to ``x``.

        :param x: image tokens.
        :param c: condition vector; ``self.default_modulation(c)`` is added as a
            per-sample shift to the self-attention input.
        :param text_states: text tokens used as cross-attention context.
        :param freq_cis_img: image RoPE, passed to both attention layers.
        :param skip: long-skip features; only used when ``self.skip_linear`` is set.
        :return: the updated ``x``.
        """
        # Long skip connection: concatenate on the channel axis, normalise, project back.
        if self.skip_linear is not None:
            x = self.skip_linear(self.skip_norm(torch.cat([x, skip], dim=-1)))

        # Self-attention on the shifted, normalised tokens.
        shift = self.default_modulation(c).unsqueeze(dim=1)
        x = x + self.attn1(self.norm1(x) + shift, freq_cis_img)[0]

        # Cross-attention against the text tokens.
        x = x + self.attn2(self.norm3(x), text_states, freq_cis_img)[0]

        # Position-wise feed-forward.
        return x + self.mlp(self.norm2(x))
最后整个网络输出(1,8,128,128); 网络输出的前4个通道(1,4,128,128)是预测值,与训练目标(取决于预测类型,可能是噪声ε、v或纯净的x_start)作MSE loss;后4个通道用于预测方差(learned sigma),用变分下界(VLB)损失训练;至此训练完成;

lora训练过程

训练过程和全参一样,低秩矩阵调用库peft训练的,略;

controlnet训练过程

架构和hunyuandit一致; 前面的步骤(加噪、文本编码、时间步编码、patchify等)和全参训练一样,之后把VAE编码后的control图像(1,4,128,128)作为condition:先经x_embedder做patch embedding,再经before_proj投影后与x相加,得到网络的输入x;其他c,text_states,freqs_cis_img和之前一样;
        # Excerpt from the ControlNet forward pass (the enclosing method is not shown).
        # Embed the VAE-encoded control latent with the same x_embedder used for x
        # (presumably the patch embedder — confirm against the class definition).
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        # One control feature is collected per ControlNet block.
        controls = []
        x = x + self.before_proj(condition) # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x)) # zero linear for output
输出19个block的control feature(每个block的输出都经过after_proj_list中对应的零初始化线性层);在冻结的hunyuandit中,这些control feature被加到长跳连(skip)特征上,再送入“解码层”的block:
        # Excerpt from the main HunYuanDiT forward pass (U-Net-like long skips).
        # NOTE(review): `skips` must already be initialised (e.g. `skips = []`);
        # that line is outside this excerpt.
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                # Second-half block: pop the most recently stored skip (LIFO, so the
                # deepest stored output is consumed first) and, when ControlNet
                # features are given, add the last remaining control feature to it.
                # NOTE(review): an exhausted `controls` list raises IndexError here.
                if controls is not None:
                    skip = skips.pop() + controls.pop()
                else:
                    skip = skips.pop()
                x = block(x, c, text_states, freqs_cis_img, skip)   # (N, L, D)
            else:
                x = block(x, c, text_states, freqs_cis_img)         # (N, L, D)

            # Outputs of layers 0 .. depth//2 - 2 are stored for the long skips above.
            if layer < (self.depth // 2 - 1):
                skips.append(x)
损失和之前一致,训练完毕;

推理

全参推理,基本跟训练一样
a. 准备正向prompt 和负向prompt,cat起来得到prompt_embeds (2,77,1024)。t5还有个文本的embedding:prompt_embeds_t5 (2,256,2048)。
b. 随机生成噪声(1,4,128,128)作为初始latent;
c. 把latent复制成两份(分别对应负向/正向prompt),放到去噪网络(代码里沿用了unet的命名,实际是DiT)中,得到预测的噪声noise_pred_uncond和noise_pred_text,各为(1,4,128,128);然后用classifier-free guidance结合负向提示词:noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
d. 根据采样公式得到新的latent,通过不断的去噪,得到最终的结果,放到VAE解码器中得到输出图片。 lora推理,同上;lora融合权重公式:融合后的权重 = 模型原始权重 + 系数 * lora权重,其中lora权重 = lora_B · lora_A。
def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale, num_layers=None):
    """
    Merge HunyuanDiT LoRA weights into a transformer state dict, in place.

    For every adapted linear layer the dense update ``lora_B @ lora_A`` is
    computed, and ``lora_scale`` times it is added to the matching target weight.
    Fused projections in the LoRA checkpoint are split to match the target's
    separate projections:

    * ``attn1.Wqkv`` -> ``attn1.to_q / to_k / to_v`` (3 equal chunks along dim 0)
    * ``attn2.kv_proj`` -> ``attn2.to_k / to_v`` (2 equal chunks along dim 0)

    Args:
        transformer_state_dict: target state dict; its tensors are updated in place.
        lora_state_dict: LoRA state dict with ``<prefix>.lora_A.weight`` /
            ``<prefix>.lora_B.weight`` entries.
        lora_scale: multiplier applied to every LoRA delta.
        num_layers: number of transformer blocks to merge. When None (default),
            it is inferred as ``1 + max(i)`` over the ``blocks.{i}.`` keys of
            ``lora_state_dict`` (0 if there are none).

    Returns:
        The same ``transformer_state_dict`` object, after merging.

    Raises:
        KeyError: if an expected LoRA or target key is missing.
    """
    if num_layers is None:
        # The original code read an undefined global `num_layers`; infer it from
        # keys such as "blocks.12.attn1.Wqkv.lora_A.weight" instead.
        block_ids = {
            int(key.split(".")[1])
            for key in lora_state_dict
            if key.startswith("blocks.")
        }
        num_layers = max(block_ids) + 1 if block_ids else 0

    def lora_delta(prefix):
        # Dense LoRA update for one linear layer: lora_B @ lora_A.
        return torch.matmul(
            lora_state_dict[f"{prefix}.lora_B.weight"],
            lora_state_dict[f"{prefix}.lora_A.weight"],
        )

    def merge(target_key, delta):
        # Merged weight = original weight + scale * LoRA delta (in place).
        transformer_state_dict[target_key] += lora_scale * delta

    for i in range(num_layers):
        src = f"blocks.{i}"

        # Self-attention: split the fused Wqkv delta into q/k/v rows.
        q, k, v = torch.chunk(lora_delta(f"{src}.attn1.Wqkv"), 3, dim=0)
        merge(f"{src}.attn1.to_q.weight", q)
        merge(f"{src}.attn1.to_k.weight", k)
        merge(f"{src}.attn1.to_v.weight", v)
        merge(f"{src}.attn1.to_out.0.weight", lora_delta(f"{src}.attn1.out_proj"))

        # Cross-attention: separate q projection, fused kv projection.
        merge(f"{src}.attn2.to_q.weight", lora_delta(f"{src}.attn2.q_proj"))
        k, v = torch.chunk(lora_delta(f"{src}.attn2.kv_proj"), 2, dim=0)
        merge(f"{src}.attn2.to_k.weight", k)
        merge(f"{src}.attn2.to_v.weight", v)
        merge(f"{src}.attn2.to_out.0.weight", lora_delta(f"{src}.attn2.out_proj"))

    # Query projection of the text pooler, which lives outside the block stack.
    merge("time_extra_emb.pooler.q_proj.weight", lora_delta("pooler.q_proj"))

    return transformer_state_dict
controlnet推理,和训练基本一样,不一样的在于controlnet输出的每层特征可以先乘上权重(control_weight,可以是一个系数,也可以是逐层的列表),再加到原始去噪网络(代码里命名为unet)的skip特征上。
                    # Excerpt from the sampling loop (enclosing code not shown).
                    # Run the ControlNet branch on the main network's inputs plus the
                    # encoded control `condition`; it returns a sequence of control features.
                    controls = self.controlnet(
                        latent_model_input,
                        t_expand,
                        condition,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                    )
                    # Scale the control features: one weight per feature (list, lengths
                    # must match) or a single scalar applied to all of them.
                    # NOTE(review): `assert` is stripped under `python -O`; raise an
                    # exception instead if this check must always run.
                    if isinstance(control_weight, list):
                        assert len(control_weight) == len(controls)
                        controls = [control * weight for control, weight in zip(controls, control_weight)] # multiply each layer's features by its own weight
                    else:
                        controls = [control * control_weight for control in controls]
                    # Main denoiser pass; the scaled control features are injected via `controls=`.
                    noise_pred = self.unet(
                        latent_model_input,
                        t_expand,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                        controls=controls
                    )

总结

注意:本文仅供自己记录学习过程使用。



训练



全参训练过程


输入图像用VAE编码得到输入的x_start(1,4,128,128);文本的两个特征:bert的encoder feature(1,77,1024)和T5 的feature(1,256,2048),和旋转位置编码freqs_cis_img: cos (4096,88),sin (4096,88)。
生成随机的时间步长t;生成随机的噪声(1,4,128,128),给输入的x_start加上噪声得到输出的x_t;
    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0): x_start is scaled by
        ``self.sqrt_alphas_cumprod`` and the noise by
        ``self.sqrt_one_minus_alphas_cumprod``. Both tables are indexed at ``t``
        and broadcast to ``x_start.shape`` via ``_extract_into_tensor``.

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise; otherwise it is
            drawn with ``th.randn_like(x_start)``.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert_shape(noise, x_start)
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

对T5 的feature(1,256,2048)用mlp降维到(1,256,1024),然后把它和bert的feature在序列维上cat起来得到text_states (1,333,1024)(77+256=333);
对时间t编码(1,1408),把x_t切成patch(patchify)得到x(1,4096,1408);
对t5 feature进行pooling(基于multihead attention的注意力池化)得到extra_vec(1,1024);
时间t的编码+mlp(extra_vec)=c(1,1408),得到condition;
上述步骤已得到以下参数:x ,c,text_states,freqs_cis_img。开始迭代处理。
# Apply one HunYuanDiT block: image tokens x, condition vector c, text tokens and image RoPE in; updated x out.
x = block(x, c, text_states, freqs_cis_img)

对c做调制(default_modulation,一个mlp)得到shift,加到norm1(x)上作为self-attention的输入;用Wqkv把输入投影并拆分成q/k/v,对q/k做归一化后用旋转位置编码(RoPE)进行编码,得到新的q/k;注意力计算后经out_proj输出,再以残差形式加回x,输出x(1,4096,1408)。简单来说,就是在条件c(时间步+文本全局特征)的调制下,图像token之间做了一次自注意力来提取特征;
    def forward(self, x, freqs_cis_img=None):
        """
        Multi-head self-attention over the image tokens, with optional RoPE.

        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
        freqs_cis_img: tuple or torch.Tensor, optional
            RoPE for image, forwarded to ``apply_rotary_emb``: either a
            ``(cos, sin)`` tuple or a complex frequency tensor. Skipped when None.

        Returns
        -------
        tuple
            A 1-tuple ``(out,)``; ``out`` has shape (batch, seqlen, hidden_dim).
        """
        b, s, d = x.shape

        qkv = self.Wqkv(x)
        qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim)  # [b, s, 3, h, d]
        q, k, v = qkv.unbind(dim=2)  # [b, s, h, d]
        # q and k are normalised and cast to fp16; v keeps its own dtype.
        q = self.q_norm(q).half()  # [b, s, h, d]
        k = self.k_norm(k).half()

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
            assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        qkv = torch.stack([q, k, v], dim=2)  # [b, s, 3, h, d]
        context = self.inner_attn(qkv)
        out = self.out_proj(context.view(b, s, d))
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple

def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    Two encodings of ``freqs_cis`` are accepted:

    * a ``(cos, sin)`` tuple: ``x * cos + rotate_half(x) * sin`` is applied;
    * a single (complex) tensor: consecutive feature pairs of ``x`` are viewed
      as complex numbers, multiplied by ``freqs_cis`` and viewed back as real.

    In both cases ``reshape_for_broadcast`` adapts the frequencies to the input
    layout, the arithmetic runs in float32, and results are cast back to the
    input dtype.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (Optional[torch.Tensor]): Key tensor to apply rotary embeddings, [B, S, H, D], or None.
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed rotary frequencies (see above).
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple of the rotated query tensor and the rotated key tensor
        (the second element is None when ``xk`` is None).
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
    return xq_out, xk_out

上面得到的x,以及文本特征text_states和旋转位置编码freqs_cis_img,作为cross attention block的输入;y是文本特征text_states;x作为q,text_states作为k/v;只对q施加旋转位置编码(k来自文本,不加),然后和k/v作cross attention,得到输出x(1,4096,1408);
    def forward(self, x, y, freqs_cis_img=None):
        """
        Multi-head cross-attention: image tokens ``x`` attend to context tokens ``y``.

        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: tuple or torch.Tensor, optional
            RoPE for image, forwarded to ``apply_rotary_emb``; only the queries
            are rotated. Skipped when None.

        Returns
        -------
        tuple
            A 1-tuple ``(out,)``; ``out`` has shape (batch, seqlen1, hidden_dim).
        """
        b, s1, _ = x.shape  # [b, s1, D]
        _, s2, _ = y.shape  # [b, s2, hidden_dim2]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)  # [b, s2, h, d]
        # q and k are normalised and cast to fp16; v keeps its own dtype.
        q = self.q_norm(q).half()  # [b, s1, h, d]
        k = self.k_norm(k).half()  # [b, s2, h, d]

        # Apply RoPE if needed (queries only; keys come from the text tokens)
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq  # [b, s1, h, d]

        kv = torch.stack([k, v], dim=2)  # [b, s2, 2, h, d]
        context = self.inner_attn(q, kv)  # [b, s1, h, d]
        context = context.view(b, s1, -1)  # [b, s1, D]

        out = self.out_proj(context)
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple

最后经过FFN(mlp)输出x(1,4096,1408)。共有19个hunyuan block,每个block输出的都是(1,4096,1408);(类似于unet encoder的操作,后续就是解码了,但是它这里“编解码”并没有分辨率的概念)。
开始“解码操作”了,其实就是把当前的x和对应的前面block的输出(skip)在通道维上cat起来,经skip_norm和skip_linear映射回原来的维度,后续步骤(self-attention、cross-attention、FFN)和前面是一样的。
    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        """
        Run one HunYuanDiT block: optional long-skip fusion, self-attention,
        cross-attention and feed-forward, each added residually to ``x``.

        :param x: image tokens.
        :param c: condition vector; ``self.default_modulation(c)`` is added as a
            per-sample shift to the self-attention input.
        :param text_states: text tokens used as cross-attention context.
        :param freq_cis_img: image RoPE, passed to both attention layers.
        :param skip: long-skip features; only used when ``self.skip_linear`` is set.
        :return: the updated ``x``.
        """
        # Long Skip Connection: concat on the channel axis, normalise, project back.
        if self.skip_linear is not None:
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        # Self-Attention
        shift_msa = self.default_modulation(c).unsqueeze(dim=1)
        attn_inputs = (
            self.norm1(x) + shift_msa, freq_cis_img,
        )
        x = x + self.attn1(*attn_inputs)[0]

        # Cross-Attention
        cross_inputs = (
            self.norm3(x), text_states, freq_cis_img
        )
        x = x + self.attn2(*cross_inputs)[0]

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x

最后整个网络输出(1,8,128,128);
网络输出的前4个通道(1,4,128,128)是预测值,与训练目标(取决于预测类型,可能是噪声ε、v或纯净的x_start)作MSE loss;后4个通道用于预测方差(learned sigma),用变分下界(VLB)损失训练;至此训练完成;


lora训练过程


训练过程和全参一样,低秩矩阵调用库peft训练的,略;



controlnet训练过程


架构和hunyuandit一致;
前面的步骤(加噪、文本编码、时间步编码、patchify等)和全参训练一样,之后把VAE编码后的control图像(1,4,128,128)作为condition:先经x_embedder做patch embedding,再经before_proj投影后与x相加,得到网络的输入x;其他c,text_states,freqs_cis_img和之前一样;
        # Excerpt from the ControlNet forward pass (the enclosing method is not shown).
        # Embed the VAE-encoded control latent with the same x_embedder used for x.
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        # One control feature is collected per ControlNet block.
        controls = []
        x = x + self.before_proj(condition)  # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x))  # zero linear for output

输出19个block的control feature(每个block的输出都经过after_proj_list中对应的零初始化线性层);在冻结的hunyuandit中,这些control feature被加到长跳连(skip)特征上,再送入“解码层”的block:
        # Excerpt from the main HunYuanDiT forward pass (U-Net-like long skips).
        # NOTE(review): `skips` must already be initialised (e.g. `skips = []`)
        # outside this excerpt.
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                # Second-half block: pop the most recent skip (LIFO) and, when
                # ControlNet features are given, add the last remaining one to it.
                if controls is not None:
                    skip = skips.pop() + controls.pop()
                else:
                    skip = skips.pop()
                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)
            else:
                x = block(x, c, text_states, freqs_cis_img)  # (N, L, D)

            # Outputs of layers 0 .. depth//2 - 2 are stored for the long skips above.
            if layer < (self.depth // 2 - 1):
                skips.append(x)

损失和之前一致,训练完毕;


推理


全参推理,基本跟训练一样
a. 准备正向prompt 和负向prompt,cat起来得到prompt_embeds (2,77,1024)。t5还有个文本的embedding:prompt_embeds_t5 (2,256,2048)。
b. 随机生成噪声(1,4,128,128)作为初始latent;
c. 把latent复制成两份(分别对应负向/正向prompt),放到去噪网络(代码里沿用了unet的命名,实际是DiT)中,得到预测的噪声noise_pred_uncond和noise_pred_text,各为(1,4,128,128);然后用classifier-free guidance结合负向提示词:noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
d. 根据采样公式得到新的latent,通过不断的去噪,得到最终的结果,放到VAE解码器中得到输出图片。
lora推理,同上;lora融合权重公式:融合后的权重 = 模型原始权重 + 系数 * lora权重,其中lora权重 = lora_B · lora_A。
def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale, num_layers=None):
    """
    Merge HunyuanDiT LoRA weights into a transformer state dict, in place.

    For every adapted linear layer the dense update ``lora_B @ lora_A`` is
    computed, and ``lora_scale`` times it is added to the matching target weight.
    Fused projections in the LoRA checkpoint are split to match the target's
    separate projections:

    * ``attn1.Wqkv`` -> ``attn1.to_q / to_k / to_v`` (3 equal chunks along dim 0)
    * ``attn2.kv_proj`` -> ``attn2.to_k / to_v`` (2 equal chunks along dim 0)

    Args:
        transformer_state_dict: target state dict; its tensors are updated in place.
        lora_state_dict: LoRA state dict with ``<prefix>.lora_A.weight`` /
            ``<prefix>.lora_B.weight`` entries.
        lora_scale: multiplier applied to every LoRA delta.
        num_layers: number of transformer blocks to merge. When None (default),
            it is inferred as ``1 + max(i)`` over the ``blocks.{i}.`` keys of
            ``lora_state_dict`` (0 if there are none).

    Returns:
        The same ``transformer_state_dict`` object, after merging.

    Raises:
        KeyError: if an expected LoRA or target key is missing.
    """
    if num_layers is None:
        # The original code read an undefined global `num_layers`; infer it from
        # keys such as "blocks.12.attn1.Wqkv.lora_A.weight" instead.
        block_ids = {
            int(key.split(".")[1])
            for key in lora_state_dict
            if key.startswith("blocks.")
        }
        num_layers = max(block_ids) + 1 if block_ids else 0

    def lora_delta(prefix):
        # Dense LoRA update for one linear layer: lora_B @ lora_A.
        return torch.matmul(
            lora_state_dict[f"{prefix}.lora_B.weight"],
            lora_state_dict[f"{prefix}.lora_A.weight"],
        )

    def merge(target_key, delta):
        # Merged weight = original weight + scale * LoRA delta (in place).
        transformer_state_dict[target_key] += lora_scale * delta

    for i in range(num_layers):
        src = f"blocks.{i}"

        # Self-attention: split the fused Wqkv delta into q/k/v rows.
        q, k, v = torch.chunk(lora_delta(f"{src}.attn1.Wqkv"), 3, dim=0)
        merge(f"{src}.attn1.to_q.weight", q)
        merge(f"{src}.attn1.to_k.weight", k)
        merge(f"{src}.attn1.to_v.weight", v)
        merge(f"{src}.attn1.to_out.0.weight", lora_delta(f"{src}.attn1.out_proj"))

        # Cross-attention: separate q projection, fused kv projection.
        merge(f"{src}.attn2.to_q.weight", lora_delta(f"{src}.attn2.q_proj"))
        k, v = torch.chunk(lora_delta(f"{src}.attn2.kv_proj"), 2, dim=0)
        merge(f"{src}.attn2.to_k.weight", k)
        merge(f"{src}.attn2.to_v.weight", v)
        merge(f"{src}.attn2.to_out.0.weight", lora_delta(f"{src}.attn2.out_proj"))

    # Query projection of the text pooler, which lives outside the block stack.
    merge("time_extra_emb.pooler.q_proj.weight", lora_delta("pooler.q_proj"))

    return transformer_state_dict

controlnet推理,和训练基本一样,不一样的在于controlnet输出的每层特征可以先乘上权重(control_weight,可以是一个系数,也可以是逐层的列表),再加到原始去噪网络(代码里命名为unet)的skip特征上。
                    # Excerpt from the sampling loop (enclosing code not shown).
                    # Run the ControlNet branch on the main network's inputs plus the
                    # encoded control `condition`; it returns a sequence of control features.
                    controls = self.controlnet(
                        latent_model_input,
                        t_expand,
                        condition,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                    )
                    # Scale the control features: one weight per feature (list, lengths
                    # must match) or a single scalar applied to all of them.
                    # NOTE(review): `assert` is stripped under `python -O`.
                    if isinstance(control_weight, list):
                        assert len(control_weight) == len(controls)
                        controls = [control * weight for control, weight in zip(controls, control_weight)]  # multiply each layer's features by its own weight
                    else:
                        controls = [control * control_weight for control in controls]
                    # Main denoiser pass; the scaled control features are injected via `controls=`.
                    noise_pred = self.unet(
                        latent_model_input,
                        t_expand,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                        controls=controls
                    )

更新时间 2024-09-30