当前位置:AIGC资讯 > AIGC > 正文

pipeline_stable_diffusion.py 文件逐行解释

本文对 diffusers 中 Stable Diffusion 的 pipeline 文件代码进行逐行解释。

60-71行

该函数对经过 CFG（classifier-free guidance）组合得到的 noise_pred 按标准差重新缩放，再与原结果按 guidance_rescale 加权混合。

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The classifier-free-guidance output is rescaled so that its per-sample standard deviation (taken over every
    axis except the leading one) matches that of the text-conditioned prediction. The result is then blended with
    the un-rescaled CFG output.

    Args:
        noise_cfg: the CFG combination of the UNet prediction, i.e.
            `noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)`.
        noise_pred_text: the text-conditioned half of the UNet prediction, i.e. the second chunk of
            `noise_pred.chunk(2)`.
        guidance_rescale: blend weight given to the rescaled prediction; `1 - guidance_rescale` goes to the
            original `noise_cfg`.

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """
    # Standard deviations over every axis except the leading (per-sample) one.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_axes, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_axes, keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * (std_text / std_cfg)

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    keep_weight = 1 - guidance_rescale
    return guidance_rescale * rescaled + keep_weight * noise_cfg
    

74-115行

调用调度器的 set_timesteps，获取时间步序列和推断步数

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Used only when
            `timesteps` is `None`; otherwise it is ignored and recomputed as `len(timesteps)`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: if custom `timesteps` are given but `scheduler.set_timesteps` has no `timesteps` parameter.
    """
    if timesteps is None:
        # Default schedule: the scheduler derives the timesteps from the requested step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule is only possible when `set_timesteps` declares a `timesteps` parameter.
    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in set_timesteps_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    custom_schedule = scheduler.timesteps
    # The step count is whatever the scheduler ended up with for the custom schedule.
    return custom_schedule, len(custom_schedule)

118-250行

类定义与 __init__：接收各组件参数，修正过时的调度器配置，并检查 unet 的版本和输入 size。

class StableDiffusionPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,  # encodes images into latents and decodes latents back into images
        text_encoder: CLIPTextModel,  # turns prompt token ids into text embeddings
        tokenizer: CLIPTokenizer,  # turns prompt text into token ids
        unet: UNet2DConditionModel,  # denoises the encoded image latents
        scheduler: KarrasDiffusionSchedulers,  # noise scheduler used together with the unet
        safety_checker: StableDiffusionSafetyChecker,  # flags potentially offensive/harmful generated images
        feature_extractor: CLIPImageProcessor,  # preprocesses generated images into safety_checker inputs
        image_encoder: CLIPVisionModelWithProjection = None,  # embeds IP-Adapter reference images
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # Outdated scheduler configs with `steps_offset != 1`: emit a deprecation warning and patch the
        # (frozen) config so that steps_offset becomes 1.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Scheduler configs with `clip_sample=True`: emit a deprecation warning and force it to False.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        # The safety checker was explicitly disabled although it is required: warn, but keep going.
        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # A safety checker without a feature extractor cannot work, because the feature extractor produces
        # the safety checker's input; this is a hard error.
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                # f-prefix added so that the class name is actually interpolated into the message.
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # True when the unet config records a diffusers version older than 0.9.0 (dev releases included).
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")

        # True when the unet config declares a sample_size below 64.
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        # Old (< 0.9.0) configs with sample_size < 64: emit a deprecation warning and patch sample_size to 64.
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

使用 register_modules 将各组件注册到 pipeline 上，之后可通过属性访问（如 self.vae、self.unet）。

# Remainder of StableDiffusionPipeline.__init__ (excerpted outside the method by the article).
# Register every component on the pipeline; afterwards they are reachable as attributes
# (`self.vae` is read right below without being assigned directly).
self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    
# VAE scale factor: 2 ** (len(block_out_channels) - 1); image height/width are floor-divided
# by this factor to obtain the latent height/width.
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Image pre/post-processor configured with the same VAE scale factor.
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Record requires_safety_checker in the pipeline config.
self.register_to_config(requires_safety_checker=requires_safety_checker)

251-279行

分片（slicing）VAE 解码与分块（tiling）VAE 解码的开启/关闭

def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    # Lower peak memory for large batches at the cost of slower decoding.
    vae = self.vae
    vae.enable_slicing()

def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    # Undo enable_vae_slicing: decode the whole batch at once again.
    vae = self.vae
    vae.disable_slicing()

def enable_vae_tiling(self):
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.
    """
    # Tiling (as opposed to slicing) splits each image spatially, which helps with large resolutions.
    vae = self.vae
    vae.enable_tiling()

def disable_vae_tiling(self):
    r"""
    Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    # Undo enable_vae_tiling.
    vae = self.vae
    vae.disable_tiling()

280-491行

对文本提示编码。该方法已弃用，请改用下一段的 encode_prompt（注意其输出格式已从拼接后的张量改为元组）。

def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """Deprecated wrapper around `encode_prompt` that returns a single concatenated tensor.

    Every argument (including extra `**kwargs`) is forwarded to `encode_prompt`. Its
    `(prompt_embeds, negative_prompt_embeds)` tuple is concatenated along dim 0 as
    `[negative_prompt_embeds, prompt_embeds]`, which is the legacy output format.
    """
    deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
    deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )
    conditional = encoded[0]
    unconditional = encoded[1]

    # concatenate for backwards comp: unconditional rows first, conditional rows second
    return torch.cat([unconditional, conditional])

新的文本编码,输出有条件embed和无条件embed

def encode_prompt(
    self,
    prompt,  # text prompt(s); may be None when prompt_embeds is given
    device,
    num_images_per_prompt,  # number of images to generate per prompt
    do_classifier_free_guidance,  # whether to also build unconditional embeddings for CFG
    negative_prompt=None,  # negative prompt(s)
    prompt_embeds: Optional[torch.FloatTensor] = None,  # pre-computed prompt embeddings
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,  # pre-computed negative prompt embeddings
    lora_scale: Optional[float] = None,  # LoRA scale for the text encoder's LoRA layers
    clip_skip: Optional[int] = None,  # number of final CLIP layers to skip
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`). A single string is applied to every element of the batch.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`, each with `batch_size * num_images_per_prompt` rows. When
        classifier-free guidance is disabled, `negative_prompt_embeds` is returned exactly as passed in
        (possibly `None`).

    Raises:
        TypeError: if `prompt` and `negative_prompt` have different types.
        ValueError: if a list `negative_prompt` does not match the batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # The batch size comes from the prompt when given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        # Tokenize, padding/truncating to the tokenizer's maximum length.
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation so any text that was cut off can be reported.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            # Decode the dropped tail of the prompt for the warning message.
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config enables `use_attention_mask`.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            # First output of the text encoder is used directly.
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            # -(clip_skip + 1) selects the layer `clip_skip` positions before the last one.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Cast the embeddings to the text encoder's dtype (falling back to the unet's, then to their own).
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method:
    # repeat along dim 1, then reshape to [bs_embed * num_images_per_prompt, seq_len, hidden].
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use the empty string for every batch element.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # Broadcast a single negative prompt over the batch. When `prompt` is a str this is still a
            # one-element list, but when only `prompt_embeds` with batch > 1 were given, a one-element list
            # would make the `.view(batch_size * num_images_per_prompt, ...)` below fail or silently produce
            # a tensor with the wrong last dimension.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Tokenize the negative prompts to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds

493-626行

对参考图作编码,输出有条件embed和无条件embed

def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode a reference image (e.g. `ip_adapter_image`) into conditional and unconditional embeddings.

    Args:
        image: the reference image. Anything that is not already a `torch.Tensor` is converted with
            `self.feature_extractor(image, return_tensors="pt").pixel_values`.
        device: device the pixel tensor is moved to.
        num_images_per_prompt: each embedding row is repeated this many times along dim 0.
        output_hidden_states: if truthy, use the image encoder's second-to-last hidden state
            (`hidden_states[-2]`) instead of `image_embeds`.

    Returns:
        `(embeds, uncond_embeds)`. In hidden-state mode the unconditional embedding is the encoder run on an
        all-zero image; otherwise it is a zero tensor shaped like `embeds`.
    """
    encoder = self.image_encoder
    # Feed the encoder pixels in the same dtype as its parameters.
    dtype = next(encoder.parameters()).dtype

    pixels = (
        image
        if isinstance(image, torch.Tensor)
        else self.feature_extractor(image, return_tensors="pt").pixel_values
    )
    pixels = pixels.to(device=device, dtype=dtype)

    if not output_hidden_states:
        embeds = encoder(pixels).image_embeds
        embeds = embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return embeds, torch.zeros_like(embeds)

    cond_hidden = encoder(pixels, output_hidden_states=True).hidden_states[-2]
    cond_hidden = cond_hidden.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_hidden = encoder(torch.zeros_like(pixels), output_hidden_states=True).hidden_states[-2]
    uncond_hidden = uncond_hidden.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden, uncond_hidden

对最终输出的图片作安全检查,输出image和has_nsfw_concept。

def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on generated images.

    Returns `(image, has_nsfw_concept)`. Without a safety checker the image is returned untouched and
    `has_nsfw_concept` is None; otherwise both values come from `self.safety_checker`.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images: tensors go through the image processor's postprocess,
    # anything else through numpy_to_pil (presumably numpy arrays — TODO confirm).
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_inputs = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=clip_inputs.pixel_values.to(dtype)
    )
    return checked, has_nsfw_concept

对潜变量做 VAE 解码。此方法已弃用，将在 1.0.0 版本中移除，请改用 VaeImageProcessor.postprocess。

def decode_latents(self, latents):
    """
    Deprecated: decode latents into a batch of float32 numpy images.

    Args:
        latents: latent tensor to decode.

    Returns:
        numpy array with values clamped to [0, 1] and the channel axis moved last (`permute(0, 2, 3, 1)`).

    Use `VaeImageProcessor.postprocess(...)` instead; this method is scheduled for removal in 1.0.0.
    """
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    # Divide out the VAE's latent scaling factor (the article notes 0.18215 for SD v1; Latent Diffusion
    # rescales latents so their std is close to 1) before decoding.
    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents, return_dict=False)[0]
    # Map pixel values from [-1, 1] to [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()

准备调度器额外的关键参数,eta和generator。

 def prepare_extra_step_kwargs(self, generator, eta):
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    
    # 检查调度器函数是否有'eta'关键字,只有DDIM要用
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta
    
    # 检查调度器函数是否有'generator'关键字
    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs

检查输入是否正确

def check_inputs(
    self,
    prompt,  # text prompt(s)
    height,  # image height in pixels
    width,  # image width in pixels
    callback_steps,  # legacy callback runs every `callback_steps` steps; None or a positive int
    negative_prompt=None,  # negative prompt(s)
    prompt_embeds=None,  # pre-computed prompt embeddings
    negative_prompt_embeds=None,  # pre-computed negative prompt embeddings
    callback_on_step_end_tensor_inputs=None,  # names of tensors to hand to the step-end callback
):
    """
    Validate the arguments of a generation call.

    Returns None when everything is consistent.

    Raises:
        ValueError: on the first invalid combination found. This covers sizes not divisible by 8, a
            non-positive or non-int `callback_steps`, unknown callback tensor names, and passing both or
            neither of `prompt`/`prompt_embeds`. It also covers a `prompt` that is not str/list, passing both
            `negative_prompt` and `negative_prompt_embeds`, and embeddings with mismatched shapes.
    """
    # Height and width must be multiples of 8.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # callback_steps, when given, must be a positive int.
    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Every requested callback tensor must be one of self._callback_tensor_inputs.
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of prompt / prompt_embeds must be given, and prompt must be a str or a list.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # negative_prompt and negative_prompt_embeds are mutually exclusive.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Directly passed positive and negative embeddings must have identical shapes.
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

准备潜变量,用于StableDiffusion模型的推断过程。

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Build the initial latent tensor that the denoising loop starts from.

    Args:
        batch_size: number of latent samples to produce (prompts x images per prompt in `__call__`).
        num_channels_latents: channel count of the latent tensor.
        height, width: image size in pixels; each is floor-divided by `self.vae_scale_factor`
            to obtain the latent spatial size.
        dtype, device: dtype and device used when fresh noise has to be sampled.
        generator: a torch random generator, or a list holding exactly one generator per sample.
        latents: optional pre-generated latents; when given they are only moved to `device`
            (their dtype is left unchanged).

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    scale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // scale, width // scale)

    # A list of generators is only valid when it supplies one generator per sample.
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # Reuse caller-supplied latents (only relocated), otherwise draw fresh Gaussian noise.
    initial = (
        latents.to(device)
        if latents is not None
        else randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
    )

    # Scale the starting noise to the standard deviation the scheduler expects.
    return initial * self.scheduler.init_noise_sigma

628-776行

这段代码定义了一个enable_freeu方法,用于启用Unet的FreeU机制。FreeU是一种改进去噪过程的方法:它在UNet的两个阶段上,用缩放因子衰减跳跃连接(skip)特征的贡献、放大主干(backbone)特征的贡献。

该方法接受四个参数:

s1:用于衰减跳跃特征贡献的阶段1的缩放因子,用于减轻增强去噪过程中的“过度平滑效应”。

s2:用于衰减跳跃特征贡献的阶段2的缩放因子,用于减轻增强去噪过程中的“过度平滑效应”。

b1:用于放大主干特征贡献的阶段1的缩放因子。

b2:用于放大主干特征贡献的阶段2的缩放因子。

def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Turn on the FreeU mechanism (https://arxiv.org/abs/2309.11497) for this pipeline's UNet.

    The numeric suffix of each factor names the stage it is applied to. Value combinations known to work well
    for pipelines such as Stable Diffusion v1, v2 and Stable Diffusion XL are listed in the
    [official repository](https://github.com/ChenyangSi/FreeU).

    Args:
        s1 (`float`):
            Stage-1 factor that attenuates the contributions of the skip features, to mitigate the
            "oversmoothing effect" in the enhanced denoising process.
        s2 (`float`):
            Same as `s1`, applied at stage 2.
        b1 (`float`): Stage-1 factor that amplifies the contributions of the backbone features.
        b2 (`float`): Same as `b1`, applied at stage 2.

    Raises:
        ValueError: if the pipeline has no `unet` attribute.
    """
    has_unet = hasattr(self, "unet")
    if has_unet:
        # The actual feature rescaling is implemented by the UNet itself.
        freeu_factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
        self.unet.enable_freeu(**freeu_factors)
        return
    raise ValueError("The pipeline must have `unet` for using FreeU.")


# Turn FreeU off again
def disable_freeu(self):
    """Disable the FreeU mechanism on the pipeline's UNet if it was enabled (delegates to the UNet)."""
    unet = self.unet
    unet.disable_freeu()

启用融合的qkv投影。对于self-attention,q、k、v三个投影矩阵都融合;对于cross-attention,只融合k和v的投影矩阵。

具体来说,unet/vae自身的fuse_qkv_projections()会把查询、键和值的投影矩阵合并为一个更大的投影矩阵,再由FusedAttnProcessor2_0使用这个融合后的投影,用一次矩阵乘法代替多次,从而提升计算效率。

# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
# (the VAE type is validated before any module is modified, so a rejected call leaves the pipeline untouched).
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Args:
        unet (`bool`, defaults to `True`): To apply fusion on the UNet.
        vae (`bool`, defaults to `True`): To apply fusion on the VAE.

    Raises:
        ValueError: if `vae` is True and `self.vae` is not an `AutoencoderKL`. This is checked before
            anything is modified, so a failing call does not leave the UNet fused as a side effect.
    """
    # Validate first. Previously this check ran after the UNet had already been fused and
    # `fusing_unet` set, so the error left the pipeline in a half-configured state.
    if vae and not isinstance(self.vae, AutoencoderKL):
        raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")

    # Reset the bookkeeping flags; each is set to True only for the component fused in this call.
    self.fusing_unet = False
    self.fusing_vae = False

    if unet:
        self.fusing_unet = True
        self.unet.fuse_qkv_projections()
        # Switch the attention layers to the processor paired with fused projections.
        self.unet.set_attn_processor(FusedAttnProcessor2_0())

    if vae:
        self.fusing_vae = True
        self.vae.fuse_qkv_projections()
        self.vae.set_attn_processor(FusedAttnProcessor2_0())

用于生成指导尺度嵌入向量

 # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
    """
    Sinusoidal embedding of guidance-scale values (fed to the UNet as `timestep_cond` in `__call__`).

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values, one per sample. The values are multiplied by 1000 before
            being embedded.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`, created on `w`'s device.
        The first half of the columns holds sines, the second half cosines; an odd `embedding_dim` gets one
        trailing zero column.

    Raises:
        ValueError: if `w` is not a 1-D tensor.
    """
    # `assert` is stripped under `python -O`; validate the input explicitly instead.
    if w.ndim != 1:
        raise ValueError(f"`w` must be a 1-D tensor, but got shape {tuple(w.shape)}.")
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # Log-spaced frequencies from 1 down to 1/10000. The divisor is clamped so that embedding_dim 2 or 3
    # (half_dim == 1) yields a single frequency of 1 instead of 0 * inf = NaN.
    emb = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequencies on w's device so that w may live on any device.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)  # (half_dim,)

    # Outer product via broadcasting: one row of phases per guidance value.
    emb = w.to(dtype)[:, None] * emb[None, :]  # (len(w), half_dim)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # (len(w), 2 * half_dim)
    if embedding_dim % 2 == 1:  # zero pad to reach an odd embedding_dim
        emb = torch.nn.functional.pad(emb, (0, 1))

    # Internal invariant (not input validation).
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb

用@property装饰器把下面的方法变成只读属性,之后可以像访问普通属性一样直接用名称读取(例如self.guidance_scale),无需加括号调用。

@property
def guidance_scale(self):
    """Classifier-free guidance weight stored by the current `__call__`."""
    return self._guidance_scale

@property
def guidance_rescale(self):
    """Guidance rescale factor stored by the current `__call__` (consumed by `rescale_noise_cfg`)."""
    return self._guidance_rescale

@property
def clip_skip(self):
    """Number of CLIP layers to skip when encoding the prompt, as stored by the current `__call__`."""
    return self._clip_skip

# `guidance_scale` plays the role of the guidance weight `w` in equation (2) of the Imagen paper
# (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` means no classifier-free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether the two-pass (unconditional + conditional) CFG scheme is used.

    True only when `guidance_scale > 1` and the UNet has no `time_cond_proj_dim`; UNets that do have one
    receive the guidance scale as an embedding (`timestep_cond`) instead.
    """
    guidance_active = self._guidance_scale > 1
    return guidance_active and self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
    """Extra keyword arguments forwarded to the attention processors for the current call."""
    return self._cross_attention_kwargs

@property
def num_timesteps(self):
    """Length of the timestep schedule used by the current call."""
    return self._num_timesteps

@property
def interrupt(self):
    """When True, the remaining denoising steps are skipped."""
    return self._interrupt

777-1063行

Stable Diffusion管线的__call__函数

第0步:定义长宽

第1步:检查参数是否正确

第2步:定义batch_size

第3步:对提示词编码,获取prompt_embed

第4步:获取时间步和推断步数

第5步:准备输入Unet的潜变量

第6步:准备调度器的eta参数、ip-adapter的image_embeds以及Guidance Scale Embedding

第7步:去噪循环

第8步:输出

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,  # text prompt(s)
    height: Optional[int] = None,  # output image height in pixels
    width: Optional[int] = None,  # output image width in pixels
    num_inference_steps: int = 50,  # number of denoising steps
    timesteps: List[int] = None,  # optional custom timestep schedule
    guidance_scale: float = 7.5,  # classifier-free guidance (CFG) weight
    negative_prompt: Optional[Union[str, List[str]]] = None,  # negative prompt(s)
    num_images_per_prompt: Optional[int] = 1,  # images generated per prompt
    eta: float = 0.0,  # DDIM eta, forwarded to the scheduler through prepare_extra_step_kwargs
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,  # RNG(s) for reproducibility
    latents: Optional[torch.FloatTensor] = None,  # optional pre-generated initial latents
    prompt_embeds: Optional[torch.FloatTensor] = None,  # pre-computed prompt embeddings
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,  # pre-computed negative prompt embeddings
    ip_adapter_image: Optional[PipelineImageInput] = None,  # reference image for IP-Adapter
    output_type: Optional[str] = "pil",  # output format
    return_dict: bool = True,  # True -> StableDiffusionPipelineOutput, False -> (image, has_nsfw_concept)
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # extra kwargs for the attention processors
    guidance_rescale: float = 0.0,  # rescale factor passed to rescale_noise_cfg
    clip_skip: Optional[int] = None,  # number of CLIP layers to skip
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,  # per-step callback
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],  # tensors handed to the callback
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`; `"latent"`
            skips VAE decoding and the safety checker and returns the latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, if present, is also passed to `encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        **kwargs:
            Accepts the deprecated `callback(step, timestep, latents)` and `callback_steps` arguments. When
            `callback` is given without `callback_steps`, it is called at every progress-bar step.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    # Legacy callback API, still accepted through **kwargs (None when absent).
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # 0. Default height and width to unet
    # e.g. 64 * 8 = 512 for a UNet with sample_size 64 and a VAE that downsamples by 8.
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # Store per-call settings; the read-only properties (guidance_scale, do_classifier_free_guidance, ...)
    # read them back.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    # batch_size is the number of prompts, or the leading dimension of prompt_embeds.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    # A "scale" entry in cross_attention_kwargs is forwarded to encode_prompt as the LoRA scale.
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes. The [negative, positive] order is what
    # `noise_pred.chunk(2)` in the denoising loop relies on.
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # IP-Adapter: encode the reference image, using the same [negative, positive] layout under CFG.
    if ip_adapter_image is not None:
        output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
        image_embeds, negative_image_embeds = self.encode_image(
            ip_adapter_image, device, num_images_per_prompt, output_hidden_state
        )
        if self.do_classifier_free_guidance:
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs (e.g. eta). TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Add image embeds for IP-Adapter
    added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

    # 6.2 Optionally get Guidance Scale Embedding
    # UNets with a time_cond_proj_dim receive (guidance_scale - 1) as an embedding through timestep_cond.
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop
    # Timesteps beyond num_inference_steps * scheduler.order are "warm-up" steps that do not advance the
    # progress bar.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    # Treat an omitted `callback_steps` as 1 (call the legacy callback at every progress step);
    # previously `i % None` raised a TypeError when only `callback` was passed.
    legacy_callback_steps = callback_steps if callback_steps is not None else 1
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once `self._interrupt` is set (e.g. from a callback), all remaining steps are skipped.
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance: split the [unconditional, conditional] halves and extrapolate
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Plain loop on purpose: `locals()` inside a comprehension would see the comprehension's own
                # scope on Python < 3.12, not this function's variables.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace any of these tensors.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # Advance the progress bar on the last step, or after warm-up on every `scheduler.order`-th step;
            # the legacy `callback` is invoked at the same points.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % legacy_callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # 8. Unless raw latents were requested, decode them with the VAE and run the safety checker.
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Denormalize every image, except those the safety checker flagged.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

更新时间 2024-06-07