
AIGC: A Detailed Walkthrough of the ControlNet Code

Official diffusers code from the Hugging Face community:
stable_diffusion/controlnet
controlnet.ipynb
For a walkthrough of the original (non-diffusers) ControlNet code, see this post:
万字长文解读Stable Diffusion的核心插件—ControlNet (a long-form read on ControlNet, Stable Diffusion's core plugin)

Part of the commentary below is quoted from another ControlNet code walkthrough.

This walkthrough covers the ControlNet code shipped with diffusers v0.16.0; the same notes also appear as inline comments in the code.
For the principles behind ControlNet, see this post:
https://blog.csdn.net/hwjokcq/article/details/138259623?spm=1001.2014.3001.5501

controlnet.py

Path: diffusers\src\diffusers\models\controlnet.py

class ControlNetOutput(BaseOutput):

This dataclass is the container for the outputs of ControlNetModel. It holds the down-sampling activations at each resolution and the mid-block activation, so they can later be used to condition the UNet during generation (a usage sketch follows the class definition).

@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    # can be used to condition the original UNet's down-sampling activations
    mid_block_res_sample: torch.Tensor
    # can be used to condition the original UNet's mid-block activation
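In a denoising loop these residuals are handed straight to the UNet. A minimal sketch of that hand-off in the style of the diffusers v0.16 pipeline (`latent_model_input`, `prompt_embeds`, `condition_image` are placeholder names):

# one denoising step: run the ControlNet, then feed its residuals to the UNet
down_block_res_samples, mid_block_res_sample = controlnet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=condition_image,
    return_dict=False,
)

noise_pred = unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,
).sample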

class ControlNetConditioningEmbedding(nn.Module):

This module encodes the input conditioning image into a feature map that matches the latent-space sample used by Stable Diffusion. The ControlNet paper describes a tiny network E(·) of four convolution layers with 4x4 kernels and 2x2 strides (ReLU activations, channels 16, 32, 64, 128, Gaussian-initialized, trained jointly with the full model) for this purpose; the diffusers implementation achieves the same downsampling with 3x3 convolutions (stride 2 wherever the resolution is halved) and default channels (16, 32, 96, 256). The final layer is a zero convolution.

The module is structured as follows:

- Input layer: a 3x3 convolution with padding 1 that maps the conditioning input to a feature map.
- Convolution blocks: the feature map is encoded by a stack of convolution blocks; each block consists of two convolutions. The first keeps the channel count unchanged (3x3 kernel, padding 1); the second changes the channel count and uses stride 2 (3x3 kernel, padding 1) to halve the spatial size.
- Output layer: a 3x3 convolution with padding 1 that maps the encoded feature map to the requested number of embedding channels, wrapped in a zero convolution.

Input/output: the conditioning image is typically a tensor of shape (batch_size, conditioning_channels, height, width), where conditioning_channels defaults to 3 (an RGB image). A shape check follows the class definition below.
class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,  # number of channels of the output embedding (feature map)
        conditioning_channels: int = 3,        # number of input channels, 3 by default (an RGB image)
        block_out_channels: Tuple[int] = (16, 32, 96, 256),  # output channels of each convolution stage
    ):
        super().__init__()

        # input layer: 3x3 convolution with padding 1, so the spatial size is preserved
        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        # module list that holds the subsequent convolution blocks
        self.blocks = nn.ModuleList([])

        # build the convolution stages; each stage has a channel-preserving conv and a channel-changing conv
        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]  # input channels of this stage
            channel_out = block_out_channels[i + 1]  # output channels of this stage
            # channel-preserving convolution
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            # channel-changing convolution with stride 2, halving the spatial size
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        # output layer: 3x3 convolution with padding 1, wrapped in a zero convolution
        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    # forward pass
    def forward(self, conditioning):
        # conditioning is the preprocessed conditioning image (edge map, depth map, ...)
        embedding = self.conv_in(conditioning)
        # SiLU activation
        embedding = F.silu(embedding)

        # pass through every convolution block in turn
        for block in self.blocks:
            # apply the convolution
            embedding = block(embedding)
            # SiLU activation
            embedding = F.silu(embedding)

        # the output layer produces the final embedding
        embedding = self.conv_out(embedding)

        # return the embedding
        return embedding
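A quick shape check (a sketch; 320 is an assumption matching block_out_channels[0] of the SD 1.5 UNet, and the three stride-2 convolutions take a 512x512 condition image down to the 64x64 latent resolution):

import torch

cond_embed = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=320,
    conditioning_channels=3,
    block_out_channels=(16, 32, 96, 256),
)
cond_image = torch.randn(1, 3, 512, 512)  # RGB condition image in pixel space
feat = cond_embed(cond_image)
print(feat.shape)  # torch.Size([1, 320, 64, 64]) -- matches the 64x64 latent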

class ControlNetModel(ModelMixin, ConfigMixin):

Parameters

See also https://huggingface.co/docs/diffusers/v0.16.0/en/api/pipelines/stable_diffusion/controlnet#diffusers.StableDiffusionControlNetPipeline

    """
    A ControlNet model.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers is skipped
            in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        use_linear_projection (`bool`, defaults to `False`):
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
        global_pool_conditions (`bool`, defaults to `False`):
    """

    _supports_gradient_checkpointing = True

Initialization: @register_to_config

Builds a UNet-style encoder with the ControlNet additions.

- Input convolution: maps the input latent into a form the network can process.
- Time embedding: encodes the timestep into an embedding vector; this is how diffusion models inject time dynamics.
- Class embedding: depending on class_embed_type, provides the model with extra class information.
- ControlNet conditioning embedding: encodes the conditioning image into an embedding vector for conditional generation.
- Down blocks and ControlNet down blocks: the multi-layer body of the model; each block may contain plain convolution layers, cross-attention layers, etc. ControlNet adds a zero convolution (zero_module) after every down-block layer output, and after each non-final block's downsampler, to gate the information flow.
- Mid block and ControlNet mid block: the mid block sits at the bottom of the UNet and processes the deepest features; the ControlNet mid block likewise ends in a zero convolution.

A quick structural sanity check follows the __init__ code below.
    @register_to_config
    # constructor
    def __init__(
        self,
        in_channels: int = 4,  # number of channels in the input sample
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (  # down-sampling block types, four blocks by default
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),  # output channels of each block
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )
        # block_out_channels and down_block_types must have the same length, otherwise raise ValueError

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )
        # if only_cross_attention is not a bool, its length must match the number of down block types, otherwise raise ValueError

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )
        # if num_attention_heads is not an int, its length must match the number of down block types, otherwise raise ValueError

Input convolution layer:

#----------------------------------------------- input convolution layer
        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

Timestep embedding layer

#----------------------------------------------- timestep embedding layer

        # time
        time_embed_dim = block_out_channels[0] * 4
        # the time embedding dimension is 4x the output channel count of the first down block

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        # Timesteps layer that projects the timestep into a sinusoidal feature of size block_out_channels[0]

        timestep_input_dim = block_out_channels[0]
        # input dimension of the timestep embedding layer; it matches the output dimension of time_proj

        self.time_embedding = TimestepEmbedding(
            # TimestepEmbedding maps the projected timestep to a continuous embedding vector.
            # It is the standard way diffusion models encode time-dependent information.
            timestep_input_dim,  # input dimension, i.e. the dimension produced by time_proj
            time_embed_dim,      # output dimension, i.e. the size of the timestep embedding vector
            act_fn=act_fn,       # activation function that adds non-linearity inside the embedding MLP
        )


Class embedding layer

#----------------------------------------------- class embedding layer

        # class embedding
        # Depending on class_embed_type, pick the class-embedding module:
        # - class_embed_type is None and num_class_embeds is not None: use nn.Embedding
        # - class_embed_type == "timestep": use TimestepEmbedding
        # - class_embed_type == "identity": use nn.Identity
        # - class_embed_type == "projection": require projection_class_embeddings_input_dim (raise ValueError otherwise) and use TimestepEmbedding
        # - anything else: set the class embedding to None
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            # Sinusoidal embedding: unlike the "timestep" type, where class_labels are first converted into sinusoidal
            # embeddings to capture periodic patterns, the "projection" type accepts arbitrary input vectors directly.
            # Input dimension: it projects from an arbitrary input dimension (projection_class_embeddings_input_dim),
            # allowing more flexibility depending on the nature of the input data.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        

ControlNet conditioning embedding layer

#----------------------------------------------- ControlNet conditioning embedding layer
        # control net conditioning embedding
        # encodes the conditioning information (e.g. an image) into an embedding that matches the latent space,
        # so it can guide the generation process as a condition
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            # the conditioning embedding is produced with the same channel count as the first down block's output
            block_out_channels=conditioning_embedding_out_channels,
            # output channels of each convolution stage
            conditioning_channels=conditioning_channels,
            # number of input channels, 3 by default (RGB image)
        )

SD down-block list and ControlNet down-block list

#----------------------------------------------- SD down-block list and ControlNet down-block list

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])


        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)
            # if only_cross_attention is a bool, repeat it len(down_block_types) times so every down block gets its own setting

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
            # if attention_head_dim is an int, repeat it so every down block gets its own setting

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)
            # if num_attention_heads is an int, repeat it so every down block gets its own setting



        # down
        '''Down blocks copied from SD; as each block runs, the corresponding samples (latents) are kept in down_block_res_samples.'''
        output_channel = block_out_channels[0]  # output channels of the first down block

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)  # zero-initialize this ControlNet block with zero_module
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel  # the previous block's output channels become this block's input channels
            output_channel = block_out_channels[i]  # output channels of the current down block
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(  # build the current down block via get_down_block
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,  # add a downsampling op unless this is the final block
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)  # same down blocks as Stable Diffusion; the ControlNet side pairs them with zero convolutions

            for _ in range(layers_per_block):  # one per resnet layer, mirroring Stable Diffusion
                # ControlNet places a zero convolution after every down-block layer output
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

#---------------------------------- UNet mid block and the corresponding ControlNet mid block

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,  # channel count of the time embedding, i.e. time_embed_dim
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )
#-----------------------
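To see what the constructor above builds, here is a quick structural sanity check (a sketch assuming the default config, i.e. four down blocks with layers_per_block=2 and diffusers v0.16.x installed):

from diffusers import ControlNetModel

controlnet = ControlNetModel()  # default config: block_out_channels=(320, 640, 1280, 1280)

# one zero conv for conv_in, one per resnet layer (2 x 4 down blocks), one per downsampler (3 non-final blocks)
print(len(controlnet.controlnet_down_blocks))  # 12
print(controlnet.controlnet_mid_block)         # Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))

# every zero convolution starts at exactly zero, so the ControlNet branch initially contributes nothing
print(all((p == 0).all() for p in controlnet.controlnet_down_blocks[0].parameters()))  # True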

Class method @classmethod

This class method lets you initialize a new ControlNetModel from an existing UNet2DConditionModel, copying its configuration and, optionally, its weights. The new model therefore starts out matching the original model's behavior and can then be customized and fine-tuned further.
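A minimal usage sketch (the checkpoint id is only an illustrative assumption):

from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)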

@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),  # output channels of each stage of the conditioning embedding
    load_weights_from_unet: bool = True,
):
    # class method that creates a ControlNetModel instance from a UNet2DConditionModel instance
    r"""
    Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
            where applicable.
    """
    controlnet = cls(
        # create a new ControlNetModel instance, copying the unet's configuration parameters
        in_channels=unet.config.in_channels,
        flip_sin_to_cos=unet.config.flip_sin_to_cos,
        freq_shift=unet.config.freq_shift,
        down_block_types=unet.config.down_block_types,
        only_cross_attention=unet.config.only_cross_attention,
        block_out_channels=unet.config.block_out_channels,
        layers_per_block=unet.config.layers_per_block,  # default: 2
        downsample_padding=unet.config.downsample_padding,
        mid_block_scale_factor=unet.config.mid_block_scale_factor,
        act_fn=unet.config.act_fn,
        norm_num_groups=unet.config.norm_num_groups,
        norm_eps=unet.config.norm_eps,
        cross_attention_dim=unet.config.cross_attention_dim,
        attention_head_dim=unet.config.attention_head_dim,
        num_attention_heads=unet.config.num_attention_heads,
        use_linear_projection=unet.config.use_linear_projection,
        class_embed_type=unet.config.class_embed_type,
        num_class_embeds=unet.config.num_class_embeds,
        upcast_attention=unet.config.upcast_attention,
        resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
        projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
        controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
        conditioning_embedding_out_channels=conditioning_embedding_out_channels,
    )

    if load_weights_from_unet:
        # if requested, copy the weights from the unet
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        if controlnet.class_embedding:
            # if the ControlNetModel has a class embedding, copy its weights from the unet
            controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

    return controlnet
    # return the ControlNetModel initialized with the copied config and weights
    

ControlNet method definitions

Attention-related methods
These methods manage the attention processors and attention slicing in the model, trading memory for speed where needed. By recursively walking the module hierarchy, every attention layer can be configured uniformly.

- attn_processors: returns a dict of all attention processors used in the model, collected recursively and keyed by processor (weight) name.
- set_attn_processor: sets the processor used to compute attention. It accepts either a single AttentionProcessor or a dict of them; if a dict is passed, its keys must spell out the path to the corresponding cross-attention processor. The processors are assigned recursively to every attention layer.
- set_default_attn_processor: disables custom attention processors and restores the default attention implementation by calling set_attn_processor(AttnProcessor()).
- set_attention_slice: enables sliced attention computation, as in UNet2DConditionModel. When enabled, the attention modules split the input tensor into slices and compute attention in several steps, saving memory at a small cost in speed. It takes a slice_size argument, recursively collects every sliceable attention head dimension, resolves the actual slice sizes from slice_size, and then applies a slice size to each attention layer.
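A small usage sketch of these methods (the standalone default-config model is just for illustration):

from diffusers import ControlNetModel

controlnet = ControlNetModel()

# inspect every attention processor, keyed by its module path
print(list(controlnet.attn_processors.keys())[:3])

# enable sliced attention ("auto" halves each attention head dimension)
controlnet.set_attention_slice("auto")

# restore the default processor everywhere
controlnet.set_default_attn_processor()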
    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        # returns a dict of all attention processors used in the model
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # walk the module hierarchy and collect every attention processor into a dict keyed by its name
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                # recurse into every child module so the whole hierarchy is covered
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
#-------------------------------------------
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
#---------------------------------------------------
    def set_attention_slice(self, slice_size):
        # enables sliced attention computation, as in UNet2DConditionModel
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []
        # inner helper that recursively collects every sliceable attention head dimension in the model
        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve the number of attention layers
        # walk every child module, calling fn_recursive_retrieve_sliceable_dims to collect the sliceable head dims

        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)
        # number of sliceable attention layers, i.e. the length of sliceable_head_dims
        num_sliceable_layers = len(sliceable_head_dims)
        # resolve the actual slice sizes from the value of slice_size
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]
        # if slice_size is a single integer, repeat it num_sliceable_layers times to form a list
        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        # inner helper that recursively applies the attention slice sizes throughout the model
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

ControlNet forward pass

This implements the forward pass of the ControlNet model: pre-processing the input sample, embedding the timestep, running the down blocks, applying the ControlNet (zero-convolution) blocks, and finally scaling and optionally pooling the residual samples. Through these steps the ControlNet turns the conditioning input and the encoder hidden states into control signals that guide image generation.

Arguments
sample: the noisy input tensor.
timestep: the denoising timestep.
encoder_hidden_states: the encoder (text) hidden states.
controlnet_cond: the conditional input tensor of the ControlNet.
conditioning_scale: the scale factor applied to the ControlNet outputs.
class_labels: optional class labels for conditioning.
timestep_cond: an optional condition tied to the timestep.
attention_mask: the attention mask.
cross_attention_kwargs: extra kwargs passed to the attention processors.
guess_mode: a mode in which the model tries to recognize the input content even without a prompt.
return_dict: whether to return a ControlNetOutput object.

Method body

Channel-order check: determine the channel order of the conditioning tensor from the config and flip the color channels if it is "bgr".

Attention-mask preparation: if an attention mask is provided, convert it into the additive form the model expects.

Timestep handling: make sure the timestep is a tensor on the same device as the input sample and broadcast it to the batch size.

Time embedding: self.time_proj maps the timestep tensor into an embedding space, producing the timestep embedding t_emb; self.time_embedding then combines t_emb with timestep_cond into the final time embedding emb.

Class embedding: if a class embedding is configured, the class labels are converted into a class embedding and added to the time embedding.

Conditioning embedding: self.controlnet_cond_embedding converts the ControlNet's conditioning input (e.g. an edge or depth map) into a feature-space tensor.

Pre-processing: self.conv_in applies an initial convolution to the input sample, and the conditioning embedding from the previous step is added to the result.

Down-sampling (self.down_blocks): each down block is called with or without cross-attention depending on its type, and the outputs of every down block are collected in down_block_res_samples.

Mid block: if a mid block exists, it is applied to the sample.

ControlNet blocks: each down-block output is passed through its corresponding ControlNet block (a zero convolution); the processed results are collected in controlnet_down_block_res_samples, and down_block_res_samples is updated to point to them.

Scaling: depending on guess mode and the global-pooling setting, the down-block results and the mid-block result are scaled.

Global pooling: if configured, the results are globally average-pooled over their spatial dimensions.

Output: depending on return_dict, either a ControlNetOutput or a plain tuple is returned. A standalone sketch of a forward call follows this list.
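A standalone sketch of a single forward call with dummy tensors (a sketch only; the shapes assume the default config, whose cross_attention_dim is 1280, and the timestep/conditioning values are arbitrary):

import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel()

latents = torch.randn(1, 4, 64, 64)    # noisy latent sample
cond = torch.randn(1, 3, 512, 512)     # condition image in pixel space (e.g. canny edges)
text_emb = torch.randn(1, 77, 1280)    # encoder hidden states matching cross_attention_dim=1280

down_res, mid_res = controlnet(
    sample=latents,
    timestep=torch.tensor(10),
    encoder_hidden_states=text_emb,
    controlnet_cond=cond,
    conditioning_scale=1.0,
    return_dict=False,
)
print(len(down_res))   # 12 residuals, one per zero convolution in controlnet_down_blocks
print(mid_res.shape)   # torch.Size([1, 1280, 8, 8]) -- mid block at the lowest resolution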


    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        The [`ControlNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.FloatTensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """
        # check channel order
        # determine the channel order of the conditioning input tensor
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
            # controlnet_cond is the conditioning input tensor; it supplies the extra information that guides generation
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0  # cast to the sample's dtype and convert to an additive mask
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        # make sure the timestep has the right form
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        # broadcast along the batch dimension without adding a new dimension

        t_emb = self.time_proj(timesteps)
        # self.time_proj maps the timestep tensor into an embedding space, producing the timestep embedding t_emb

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                # the class embedding exists, so class labels must be provided
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)  # initial convolution over the input sample
        # reminder: self.conv_in = nn.Conv2d(
        #     in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        # )
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        # the image-based condition (e.g. an edge map or depth map) is converted into a feature-space tensor

        sample = sample + controlnet_cond

        # 3. down
        # the feature map is progressively downsampled while the channel count grows, capturing higher-level features
        down_block_res_samples = (sample,)  # collects the result after each down-sampling stage
        for downsample_block in self.down_blocks:
            # does this block use cross-attention?
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. Control net blocks
        # The outputs of the regular down blocks are passed through the corresponding ControlNet blocks (zero
        # convolutions) and collected, so that later parts of the model can use them. This is how ControlNet
        # injects the learned conditional control into the generation process.

        controlnet_down_block_res_samples = ()
        # Each feature is processed by the zero conv of its controlnet_block: the outputs of the ControlNet's
        # down blocks, after passing through the zero convolutions, are stored in controlnet_down_block_res_samples
        # and then assigned back to down_block_res_samples. The mid-block output, after its zero convolution,
        # is stored in mid_block_res_sample.
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            # iterate in lockstep over the regular down-block results and the corresponding ControlNet zero-conv blocks

            down_block_res_sample = controlnet_block(down_block_res_sample)
            # each down-block result goes through its ControlNet block, injecting the conditional control into the data flow

            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples
        # down_block_res_samples now points at the ControlNet-processed results



        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        # The scaling factors let the model weight different levels differently, while global pooling shrinks the
        # spatial size of the features so the model focuses on global structure rather than local detail.
        # For MultiControlNetModel inference with several condition images, down_block_res_samples and
        # mid_block_res_sample are the sums of the outputs over all condition images.
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            # scale the residual samples adaptively, shallow to deep
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )

zero_module(module)

Zero convolution:
all parameters (both weight and bias) are initialized to zero.

def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
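Because every parameter starts at zero, a freshly created zero convolution maps any input to zero, so at the start of training the ControlNet branch adds nothing to the frozen UNet. A tiny sketch:

import torch
import torch.nn as nn

def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

zero_conv = zero_module(nn.Conv2d(320, 320, kernel_size=1))
x = torch.randn(2, 320, 64, 64)
print(zero_conv(x).abs().max())  # tensor(0.) -- the branch stays silent until training moves the weights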

pipeline_controlnet.py

Relative path: \diffusers\src\diffusers\pipelines\controlnet\pipeline_controlnet.py
Generates a new image guided by the prompt, the image (condition image) and, where used, an ip_adapter_image.

Inputs: prompt, ip_adapter_image (reference image), image (condition image). Output: output_image.
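Before diving into the code, here is a minimal end-to-end usage sketch (the checkpoint ids and the condition-image URL are assumptions purely for illustration):

import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

# hypothetical checkpoints; any SD 1.5 base plus a matching ControlNet works the same way
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

condition_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
output_image = pipe("a futuristic living room", image=condition_image, num_inference_steps=30).images[0]
output_image.save("controlnet_out.png")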

Initialization functions and memory-saving helpers

Class initialization

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

The same helpers as in diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline (enable/disable_vae_slicing, enable/disable_vae_tiling, and the CPU-offload methods): they enable or disable sliced VAE decoding and tiled VAE decoding, and offload models to the CPU.
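Typical usage of these memory-saving switches (a sketch, assuming `pipe` is the pipeline built earlier):

# trade a little speed for lower peak memory
pipe.enable_vae_slicing()      # decode the batch latent-by-latent
pipe.enable_vae_tiling()       # decode/encode large images tile-by-tile

# offloading (requires `accelerate`); pick one of the two
pipe.enable_model_cpu_offload()        # whole sub-models move to the GPU only when called
# pipe.enable_sequential_cpu_offload() # finer-grained, saves more memory but is slower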

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
        """
        self.vae.enable_tiling()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            # the safety checker can offload the vae again
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # control net hook has be manually offloaded as it alternates with unet
        cpu_offload_with_hook(self.controlnet, device)

        # We'll offload the last model manually.
        self.final_offload_hook = hook



Functions that process prompts and input data

This code turns the text prompt into embeddings the model can understand; those embeddings then drive the image-generation model to produce images that match the text. It also handles the unconditional text embeddings used for classifier-free guidance (which improves quality and controllability), checks that prompts and embeddings are consistent, and truncates over-long text.

- _encode_prompt: encodes the text prompt into the text encoder's hidden states, handles classifier-free guidance, and generates unconditional embeddings when needed.
- run_safety_checker: checks whether the generated images contain inappropriate content.
- decode_latents: decodes the latent representation back into an image; deprecated in favor of VaeImageProcessor.
- prepare_extra_step_kwargs: prepares extra keyword arguments for the scheduler step, such as eta and generator.
- check_inputs: validates the inputs, including prompt, image, callback steps, negative prompt, embedding tensors, controlnet_conditioning_scale, control_guidance_start and control_guidance_end.
- check_image: validates the image input (type and batch size).
- prepare_image: preprocesses and resizes the input image for the generation process.
- prepare_latents: prepares the latent tensors used by the generation model.

A small sketch of the classifier-free-guidance embedding layout produced by _encode_prompt follows.
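The key trick in _encode_prompt is that, under classifier-free guidance, the unconditional and conditional embeddings are concatenated along the batch dimension so both branches run in a single forward pass. A minimal sketch with dummy tensors (shapes assume CLIP ViT-L/14, i.e. 77 tokens and 768 dimensions):

import torch

batch_size, num_images_per_prompt, seq_len, dim = 2, 1, 77, 768

prompt_embeds = torch.randn(batch_size * num_images_per_prompt, seq_len, dim)           # conditional
negative_prompt_embeds = torch.randn(batch_size * num_images_per_prompt, seq_len, dim)  # unconditional ("" or a negative prompt)

# classifier-free guidance: stack [uncond, cond] along the batch axis
cfg_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
print(cfg_embeds.shape)  # torch.Size([4, 77, 768]) == (2 * batch_size * num_images_per_prompt, seq_len, dim)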
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# _encode_prompt encodes the text prompt into the text encoder's hidden states.
def _encode_prompt(
    self,
    prompt,  # the input text prompt, a single string or a list of strings
    device,  # the device on which to run the computation, e.g. CPU or GPU
    num_images_per_prompt,  # number of images to generate per prompt
    do_classifier_free_guidance,  # whether to use classifier-free guidance
    negative_prompt=None,  # prompt(s) that should NOT guide image generation
    prompt_embeds: Optional[torch.FloatTensor] = None,  # pre-generated text embeddings
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,  # pre-generated negative text embeddings
    lora_scale: Optional[float] = None,  # scale factor for the LoRA layers
):

    # if lora_scale is provided and this instance is a LoraLoaderMixin, set the LoRA scale factor
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    # determine the batch size from the type of prompt
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # if no pre-generated text embeddings were provided, create them with the tokenizer and the text encoder
    if prompt_embeds is None:
        # if this is a TextualInversionLoaderMixin instance, the prompt may need to be converted first
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        # tokenize the text prompt
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",  # 填充策略。
            max_length=self.tokenizer.model_max_length,  # 分词器模型的最大长度。
            truncation=True,  # 截断超出最大长度的文本。
            return_tensors="pt",  # 返回 PyTorch 张量。
        )
        # 获取处理后的输入 ID。
        text_input_ids = text_inputs.input_ids#经过分词器编码后的文本的token ID序列

        # if the original text had to be truncated, log a warning
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to "
                f"{self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # decide from the config whether to use an attention mask
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        # generate the text embeddings with the text encoder
        prompt_embeds = self.text_encoder(
            # encode text_input_ids into the corresponding text embedding vectors
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]  # the first element is the encoded text embedding

    # make sure the embeddings match the text encoder's dtype and device
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    # embedding dimensions: batch size, sequence length
    bs_embed, seq_len, _ = prompt_embeds.shape

    # repeat the text embeddings once per image generated for each prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    # (bs_embed, seq_len, embed_dim) is repeated num_images_per_prompt times along the seq_len axis; this ensures
    # that every prompt can yield num_images_per_prompt images, each conditioned on the same text embedding

    # reshape the tensor to the shape the model expects
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # if classifier-free guidance is requested and no negative embeddings were provided, create them
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        # build the unconditional token list according to the type of negative_prompt
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size  # same length as the batch size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
            # if negative_prompt is a string, wrap it in a single-element uncond_tokens list
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # if this is a TextualInversionLoaderMixin instance, the unconditional tokens may need converting
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # tokenize the unconditional tokens
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        # decide from the config whether to use an attention mask
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        # generate the unconditional text embeddings with the text encoder
        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    # if classifier-free guidance is enabled, repeat the unconditional embeddings and reshape them to match
    if do_classifier_free_guidance:
        seq_len = negative_prompt_embeds.shape[1]
        # sequence length of the unconditional embeddings (negative_prompt_embeds)

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
        # make sure the unconditional embeddings match the text encoder's dtype and device
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # repeat them so their number matches the number of images generated per prompt

        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # concatenate the unconditional and conditional embeddings for classifier-free guidance
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        # shape: (2 * batch_size * num_images_per_prompt, seq_len, embedding_dim)

    # return the final text embeddings
    return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
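
The post-processing in decode_latents is just a linear de-normalisation from the VAE's roughly [-1, 1] output range to [0, 1], followed by an NCHW-to-NHWC transpose. A tiny sketch with a dummy tensor (no VAE involved):

import torch

# pretend this is the VAE decoder output: values roughly in [-1, 1], NCHW layout
decoded = torch.tanh(torch.randn(1, 3, 64, 64))

image = (decoded / 2 + 0.5).clamp(0, 1)                  # map [-1, 1] -> [0, 1]
image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC numpy array
print(image.shape, image.min() >= 0.0, image.max() <= 1.0)  # (1, 64, 64, 3) True True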

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
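
The same signature-inspection pattern can be reproduced outside the pipeline. A minimal sketch with two toy step functions (the names are invented for illustration) shows why eta is forwarded to one scheduler but silently dropped for the other:

import inspect

def ddim_like_step(model_output, timestep, sample, eta=0.0, generator=None):
    return sample  # placeholder

def euler_like_step(model_output, timestep, sample, generator=None):
    return sample  # placeholder

def build_extra_kwargs(step_fn, eta, generator):
    # only forward the kwargs that the step function actually accepts
    params = set(inspect.signature(step_fn).parameters.keys())
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params:
        extra["generator"] = generator
    return extra

print(build_extra_kwargs(ddim_like_step, eta=0.5, generator=None))   # {'eta': 0.5, 'generator': None}
print(build_extra_kwargs(euler_like_step, eta=0.5, generator=None))  # {'generator': None}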

def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    # 检查 callback_steps 是否为正整数
    if (callback_steps is None) or (
        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
    ):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # 检查 prompt 和 prompt_embeds 是否同时提供
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    # 检查 prompt 和 prompt_embeds 是否都为 None
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    # 检查 prompt 的类型是否为字符串或列表
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # 检查 negative_prompt 和 negative_prompt_embeds 是否同时提供
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # 检查 prompt_embeds 和 negative_prompt_embeds 的形状是否一致
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # 当有多个 ControlNet 时,检查 prompt 的处理
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # 检查 image 的类型和数量
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        self.check_image(image, prompt, prompt_embeds)
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        assert False

    # 检查 controlnet_conditioning_scale 的类型
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
            self.controlnet.nets
        ):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )
    else:
        assert False

    # 检查 control_guidance_start 和 control_guidance_end 的长度是否一致
    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    # 检查 control_guidance_start 的长度是否与 ControlNet 的数量一致
    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    # 检查 control_guidance_start 和 control_guidance_end 的取值范围
    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    def check_image(self, image, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
        image_is_np = isinstance(image, np.ndarray)
        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

        if (
            not image_is_pil
            and not image_is_tensor
            and not image_is_np
            and not image_is_pil_list
            and not image_is_tensor_list
            and not image_is_np_list
        ):
            raise TypeError(
                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
            )

        if image_is_pil:
            image_batch_size = 1
        else:
            image_batch_size = len(image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
        #对输入的图像进行预处理,调整为指定的高度和宽度

        image_batch_size = image.shape[0]
        #获取预处理后的图像的批次大小。

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)  # repeat the image along the batch dimension

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image
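
Ignoring the actual preprocessing, the batching logic of prepare_image can be reproduced with a dummy tensor. This sketch (made-up sizes) shows how a single control image is expanded to the effective batch and then doubled for classifier-free guidance:

import torch

# effective batch as the pipeline passes it in: batch_size * num_images_per_prompt
effective_batch = 2 * 3
image = torch.rand(1, 3, 512, 512)  # a single preprocessed control image

# a single conditioning image is repeated for the whole effective batch;
# otherwise each image is repeated num_images_per_prompt times
repeat_by = effective_batch if image.shape[0] == 1 else 3
image = image.repeat_interleave(repeat_by, dim=0)

# with classifier-free guidance (and guess_mode=False) the same conditioning is fed to
# both the unconditional and the conditional half of the doubled batch
do_classifier_free_guidance, guess_mode = True, False
if do_classifier_free_guidance and not guess_mode:
    image = torch.cat([image] * 2)

print(image.shape)  # torch.Size([12, 3, 512, 512])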

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # Prepare the latent tensor that the diffusion process starts from.
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        #空间缩小
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:#如果提供,则直接使用该张量而不生成新的噪声。
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        #对初始噪声进行缩放
        return latents
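
A quick sanity check of the latent shape and the initial noise scaling, in plain torch (the concrete numbers below are the usual SD 1.x values and are stated here as assumptions):

import torch

batch_size, num_channels_latents = 4, 4          # the SD UNet works on 4 latent channels
height, width, vae_scale_factor = 512, 512, 8    # the SD VAE downsamples spatially by a factor of 8
init_noise_sigma = 1.0                           # DDIM/PNDM use 1.0; Euler-style schedulers use a larger value

shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
generator = torch.Generator().manual_seed(0)
latents = torch.randn(shape, generator=generator) * init_noise_sigma

print(latents.shape)  # torch.Size([4, 4, 64, 64])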

The core method: __call__()

The main steps of the method are:

1. Check the input arguments.
2. Define the call parameters, such as device and batch size.
3. Encode the input prompt.
4. Prepare the conditioning image.
5. Prepare the latent variables.
6. Prepare the extra scheduler step kwargs.
7. Run the denoising loop to generate the image.
8. Run the safety checker.
9. Post-process the image and return the result.

A minimal usage sketch of how this method is typically invoked follows right after this list.
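
As a point of reference for the walkthrough below, this is roughly how the pipeline's __call__ is invoked from user code. It is a minimal sketch following the pattern of the official Canny ControlNet example; the checkpoints and the image URL are assumptions, not requirements:

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Build a Canny edge map as the ControlNet condition (the URL is just an example image).
image = np.array(load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
))
edges = cv2.Canny(image, 100, 200)
canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

# Assumed checkpoints: any SD-1.5-compatible base model plus a matching ControlNet.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

result = pipe(
    "a painting of a woman wearing a pearl earring, best quality",
    image=canny_image,                  # ControlNet conditioning image
    num_inference_steps=20,
    guidance_scale=7.5,
    controlnet_conditioning_scale=1.0,  # scale applied to the ControlNet residuals
    control_guidance_start=0.0,         # ControlNet active over the whole denoising run
    control_guidance_end=1.0,
).images[0]
result.save("controlnet_out.png")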

Parameters

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,#文本提示
        image: Union[# ControlNet输入条件
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,#DDIM算法中的参数
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,#用于缩放 ControlNet 的条件特征图的强度
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,#用于指定控制引导强度的起始值
        control_guidance_end: Union[float, List[float]] = 1.0,#这个参数用于指定控制引导强度的结束值
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
                specified in init, images must be passed as a list such that each element of the list can be correctly
                batched for input to a single controlnet.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the controlnet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the controlnet stops applying.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
        )
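
The alignment block right before this call simply broadcasts scalars so that control_guidance_start/end always become lists with one entry per ControlNet. A standalone sketch of the same logic (n_controlnets is an invented parameter for illustration):

def align_control_guidance(control_guidance_start, control_guidance_end, n_controlnets):
    # broadcast a scalar start against a list end (and vice versa); if both are
    # scalars, replicate them once per ControlNet
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        control_guidance_start = n_controlnets * [control_guidance_start]
        control_guidance_end = n_controlnets * [control_guidance_end]
    return control_guidance_start, control_guidance_end

print(align_control_guidance(0.0, 1.0, 2))         # ([0.0, 0.0], [1.0, 1.0])
print(align_control_guidance(0.0, [0.5, 1.0], 2))  # ([0.0, 0.0], [0.5, 1.0])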

       


Defining the call parameters; preparing the image, timesteps, latents and extra step kwargs

Defining the call parameters:

Determine the batch size batch_size from the type of the input prompt.

Decide whether classifier-free guidance is needed, based on whether guidance_scale is greater than 1.0.

Decide whether guess mode is active, based on the global-pool-conditions config and the guess_mode argument.

Encoding the input prompt:

If cross_attention_kwargs is provided, read its scale entry and assign it to text_encoder_lora_scale.

Call _encode_prompt to encode the input prompt into prompt_embeds.

Preparing the image, timesteps, latents and extra step kwargs:

Prepare the input image data according to the type of controlnet (single ControlNetModel or MultiControlNetModel).

Use the scheduler to set num_inference_steps and obtain the timesteps.

Prepare the extra step kwargs, including the random generator and the eta parameter.

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        # If cross_attention_kwargs is not None, read its "scale" entry and use it as the LoRA scale
        # for the text encoder; otherwise it stays None.
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )

        prompt_embeds = self._encode_prompt(
            #进行编码
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            assert False

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

       

The denoising loop: the model iteratively refines the latents and finally decodes the image

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []  # per-timestep flags for which ControlNets are active
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            # For each (start, end) pair, keeps is 1.0 when the current step index i lies inside the
            # [start, end] window (expressed as fractions of the total number of timesteps), and 0.0 otherwise.
            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
            # controlnet_keep therefore stores one 0/1 flag per timestep (or a list of flags when several
            # ControlNets are used); it is multiplied into controlnet_conditioning_scale inside the loop below.

        # 8. Denoising loop
            
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        # "Warmup" steps: the difference between the scheduler's timestep count and
        # num_inference_steps * scheduler.order; in the loop below it only decides when the
        # progress bar is updated and the callback is invoked.

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # 创建进度条,总步数为num_inference_steps

            for i, t in enumerate(timesteps):
                # 遍历时间步长timesteps中的索引和值

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                # 如果进行无分类器引导,则扩展潜变量;否则使用原始潜变量

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                # 对潜变量进行缩放处理,根据时间步t

                # controlnet(s) inference
                # ControlNet的推理过程
                if guess_mode and do_classifier_free_guidance:
                    # 如果是猜测模式且进行无分类器引导

                    # Infer ControlNet only for the conditional batch,只有在条件批次中才推断ControlNet
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                    #prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
                    # 仅为条件批次推理ControlNet,设置控制模型输入和控制网络提示嵌入

                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds
                    # 否则,设置控制模型输入为潜变量输入,控制网络提示嵌入为提示嵌入

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
                # 根据controlnet_keep的类型,设置条件缩放

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
                # 使用controlnet进行推理,得到下采样块样本和中间块样本

                if guess_mode and do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    # prepend an all-zeros tensor for the unconditional half so the ControlNet residuals leave it unchanged
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if not output_type == "latent":
            #解码
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
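
To close, two scalar computations inside the denoising loop are easy to verify in isolation: the per-step controlnet_keep flags derived from control_guidance_start/end, and the classifier-free-guidance combination of the two noise predictions. A standalone sketch with invented numbers:

import torch

# controlnet_keep for a toy schedule: 10 timesteps, ControlNet active only for fractions [0.2, 0.8]
timesteps = list(range(10))
control_guidance_start, control_guidance_end = [0.2], [0.8]
controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
print(controlnet_keep)  # [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]

# classifier-free guidance: the doubled batch is split into (uncond, text) halves and recombined
guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)  # stand-in for the UNet output on the doubled batch
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])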
