
Exponential Moving Average (EMA) in Stable Diffusion

1. Moving Average in Stable Diffusion (SMA & EMA)

References:
1. Moving average
2. 移动平均值 (a Chinese-language article on moving averages)
3. How We Trained Stable Diffusion for Less than $50k (Part 3)

Moving Average
In statistics, a moving average is a calculation used to analyze data points by creating a series of averages of different subsets of the full data set.


Given a sequence of numbers and a fixed subset size, the first element of the moving average is obtained by averaging the initial fixed subset of the number series. The subset is then modified by "shifting forward"; that is, excluding the first number of the series and including the next value after the subset.

This explanation of moving averages is taken from 移动平均值 (reference 2 above).

1.1 Simple Moving Average (SMA, an unweighted MA)
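As a minimal sketch of the sliding-window average described above (the data values here are invented purely for illustration), each output element is the unweighted mean of a fixed-size window that moves forward one position at a time:

def simple_moving_average(values, window):
    # Average each consecutive window of `window` values (all points weighted equally).
    return [sum(values[i:i + window]) / window
            for i in range(len(values) - window + 1)]

print(simple_moving_average([1.0, 2.0, 3.0, 4.0, 5.0], window=3))  # [2.0, 3.0, 4.0]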


1.2 Exponential Moving Average (EMA, a weighted MA)

In the context of Stable Diffusion, the Exponential Moving Average (EMA) is a technique used during the training of machine learning models, particularly neural networks, to stabilize and improve the model’s performance.

The Exponential Moving Average is a method of averaging that gives more weight to recent data points, making it more responsive to recent changes compared to a simple moving average, which treats all data points equally.
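Written as a recurrence, with smoothing factor $\alpha \in (0, 1]$ applied to a series $x_t$ (a larger $\alpha$ discounts older observations faster):

$$\text{EMA}_t = \alpha \cdot x_t + (1 - \alpha) \cdot \text{EMA}_{t-1}, \qquad \text{EMA}_0 = x_0$$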

1.2.1 EMA in Stable Diffusion

In the context of Stable Diffusion, EMA is applied to the model parameters during training to create a smoothed version of the model. This is particularly useful in machine learning because the training process can be noisy, with the model parameters oscillating as they converge towards an optimal solution. By maintaining an EMA of the model parameters, the training process can benefit from the following:

- Smoothing: EMA smooths out the parameter updates, reducing the impact of noise and making the training process more stable.
- Better Generalization: The EMA version of the model often generalizes better on unseen data than the model with the raw parameters, because EMA favors parameter values that are consistent over time.
- Preventing Overfitting: By averaging the parameters over time, EMA can help mitigate overfitting, especially when the model might otherwise converge too quickly to a suboptimal solution.
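Applied to model weights rather than a data series, the same recurrence keeps a shadow copy $\theta^{\text{EMA}}$ of the parameters $\theta$, where $\beta$ is the decay factor (i.e. $\beta = 1 - \alpha$); this is exactly what the `update_average` method in the code further below computes:

$$\theta^{\text{EMA}}_t = \beta \cdot \theta^{\text{EMA}}_{t-1} + (1 - \beta) \cdot \theta_t$$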

My personal understanding:
The loss function is a function of the parameters (weights & biases); that is, each loss value corresponds to one set of parameter values. When the loss oscillates, the model parameters are changing as well. While training Stable Diffusion, the MSE loss oscillates up and down during gradient descent, and the corresponding model parameters oscillate with it. EMA can be used to obtain an intermediate value of these oscillating parameters; this averaged value better represents the typical level of the parameters across all time steps, which gives the model better generalization.

Stable Diffusion 2 uses Exponential Moving Averaging (EMA), which maintains an exponential moving average of the weights. At every time step, the EMA model is updated by taking 0.9999 times the current EMA model plus 0.0001 times the new weights after the latest forward and backward pass. By default, the EMA algorithm is applied after every gradient update for the entire training period. However, this can be slow due to the memory operations required to read and write all the weights at every step.
Applying EMA to all parameters at every time step is costly, because the model's entire set of weights has to be read and written at every step.
$$\text{EMA}_t = 0.0001 \cdot x_t + 0.9999 \cdot \text{EMA}_{t-1}$$
To reduce the cost of computing the EMA, it is applied only during the final phase of training.
To avoid this costly procedure, we start with a key observation: since the old weights are decayed by a factor of 0.9999 at every batch, the early iterations of training only contribute minimally to the final average. This means we only need to take the exponential moving average of the final few steps. Concretely, we train for 1,400,000 batches and only apply EMA for the final 50,000 steps, which is about 3.5% of the training period. The weights from the first 1,350,000 iterations decay away by (0.9999)^50000, so their aggregate contribution would have a weight of less than 1% in the final model. Using this technique, we can avoid adding overhead for 96.5% of training and still achieve a nearly equivalent EMA model.
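The decay factor in this argument can be checked directly:

$$(0.9999)^{50000} = e^{50000 \ln(0.9999)} \approx e^{-5.0} \approx 0.0067,$$

so the weights from the first 1,350,000 iterations indeed retain well under 1% of the weight in the final average.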

1.2.2 Implementation in Stable Diffusion

During the training of a diffusion model, the EMA of the model’s weights is updated alongside the regular updates. Here’s a typical process:

- Initialize EMA Weights: At the start of training, initialize the EMA weights to be the same as the model's initial weights.
- Update During Training: After each batch update, update the EMA weights using the formula above. This requires storing a separate set of weights for the EMA.
- Use for Inference: At the end of training, use the EMA weights for inference instead of the raw model weights, since they represent a more stable and often better-performing version of the model.

1.2.3 Practical Considerations

- Choosing $\alpha$: The smoothing factor $\alpha$ is a hyperparameter that needs to be chosen carefully. A common practice is to set $\alpha$ based on the number of iterations or epochs, for example $\alpha = \frac{2}{N+1}$, where $N$ is the number of iterations.
- Performance Overhead: Maintaining EMA weights requires additional memory and computational overhead, but the benefits in model stability and performance often outweigh these costs.
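As a quick worked example of that rule of thumb (the choice of $N$ here is purely illustrative): for $N = 399$ iterations, $\alpha = \frac{2}{399 + 1} = 0.005$, which corresponds to a decay factor $\beta = 1 - \alpha = 0.995$, the value passed to the `EMA` class in the code below.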

module.py

class EMA:
    def __init__(self, beta):
        # Initializes the EMA object with a smoothing factor (beta) and a step counter (step).
        super().__init__()
        self.beta = beta  # Smoothing factor for the exponential moving average
        self.step = 0  # Step counter to keep track of the number of updates

    def update_model_average(self, ma_model, current_model):
        # Updates the moving average (EMA) of the parameters of the EMA model (ma_model)
        # based on the current model (current_model).
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        # Computes the exponentially weighted average of the old and new parameters.
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # Resets the EMA model parameters to match the current model while the step count
        # is below step_start_ema, otherwise updates the EMA parameters from the current model.
        # Increments the step counter after each call.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1  # Increment the step counter

    def reset_parameters(self, ema_model, model):
        # Copies the current model's parameters into the EMA model to (re)initialize it.
        ema_model.load_state_dict(model.state_dict())

train.py

import copy
import logging
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from module import EMA

# UNET, diffusion, ddpm_sampler, train_loader and timesteps_to_time_emb are
# assumed to be defined elsewhere in the project.


def train(args):
    device = args.device  # Get the device to run the training on
    model = UNET().to(device)  # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # Set up the optimizer with AdamW
    mse = nn.MSELoss()  # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)

    # EMA: Exponential Moving Average
    ema = EMA(0.995)  # Exponential Moving Average with decay rate 0.995
    # At the start of training, initialize the EMA weights to be the same as the model's initial weights.
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)  # EMA copy of the model, eval mode, no gradients

    print("Starting the training loop!")
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")  # Log the start of the epoch
        progress_bar = tqdm(train_loader)  # Progress bar for the dataloader
        optimizer.zero_grad()  # Explicitly zero the gradient buffers
        accumulation_steps = 4
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)  # Move images to the device
            # The dataloader adds a batch dimension, but the VAE and CLIP inputs are already batched,
            # so squeeze out the extra dimension and keep only the dataloader's batch size.
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)  # Move captions to the device
            text_embeddings = torch.squeeze(captions, dim=1)  # Squeeze the extra batch dimension
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)  # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t (batch_size, channel, Height/8, Width/8) (bs, 4, 256/8, 256/8)
            # caption (batch_size, seq_len, dim) (bs, 77, 768)
            # t (batch_size, channel) (batch_size, 1280)
            # UNET output: (bs, 320, H/8, W/8)
            # The forward pass must track gradients so that loss.backward() can update the UNET.
            last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            # (bs, 4, H/8, W/8)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)
            loss = mse(noises, predicted_noise)  # Compute the loss
            loss.backward()  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:  # Wait for several backward passes
                optimizer.step()  # Now we can do an optimizer step
                optimizer.zero_grad()  # Reset gradients to zero
                # EMA: Exponential Moving Average update after each optimizer step
                ema.step_ema(ema_model, model)
            progress_bar.set_postfix(MSE=loss.item())  # Update the progress bar with the loss
            # Log the loss to TensorBoard
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
        # Save the model checkpoint
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(),
                   os.path.join("models", args.run_name, "optim.pt"))  # Save the optimizer state
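The loop above checkpoints only the raw model weights. Since section 1.2.2 recommends using the EMA weights for inference, a minimal addition (a sketch that reuses the same output directory; the file name is an arbitrary choice, not fixed by the article) would be to checkpoint the EMA model as well:

# Also save the EMA weights so they can be loaded for inference later
# ("stable_diffusion_ema.ckpt" is an illustrative file name).
torch.save(ema_model.state_dict(),
           os.path.join("models", args.run_name, "stable_diffusion_ema.ckpt"))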

Summary

### Article Summary: Moving Averages in Stable Diffusion (SMA & EMA)
#### 1. Introduction to Moving Averages
- A **moving average (MA)** is a statistical method for analyzing data points by creating a series of averages of different subsets of the full data set.
- The basic variants are the **Simple Moving Average (SMA)** and the **Exponential Moving Average (EMA)**.
#### 2. Simple Moving Average (SMA)
- The SMA averages the data points within a fixed subset, with all data points weighted equally.
#### 3. Exponential Moving Average (EMA)
- The **EMA** is a weighted average that gives more weight to recent data points, making it more responsive to recent changes.
- **Use in Stable Diffusion**: EMA is applied during training to stabilize and improve model performance. Applying EMA to the model parameters produces a smoothed version of the parameters, which reduces noise in the training process and improves stability and generalization.
#### 4. What EMA Does in Stable Diffusion
- **Smoothing**: Reduces fluctuations in parameter updates and the impact of noise, making training more stable.
- **Better generalization**: The EMA version of the model tends to perform better on unseen data.
- **Preventing overfitting**: Averaging the parameters over time helps keep the model from converging too quickly to a suboptimal solution.
#### 5. Reducing the Cost
- To lower the cost of applying EMA to all parameters at every time step, EMA can be applied only in the final phase of training. For example, after training for 1,400,000 batches, EMA is applied only for the final 50,000 batches (about 3.5% of the training period), avoiding the overhead for 96.5% of training while still obtaining a nearly equivalent EMA model.
#### 6. Implementing EMA
- **Initialize the EMA weights**: At the start of training, set the EMA weights equal to the model's initial weights.
- **Update the EMA weights**: After each batch update, update the EMA weights with the EMA formula.
- **Use for inference**: At the end of training, use the EMA weights for inference.
#### 7. Practical Considerations
- **Choosing the smoothing factor (α)**: It should be chosen carefully, for example based on the number of iterations or epochs.
- **Performance overhead**: Maintaining EMA weights adds memory and compute overhead, but the gains in stability and performance usually outweigh the cost.
#### 8. Code Examples
- The article provides a Python implementation of EMA (the `EMA` class) and example code integrating EMA into the Stable Diffusion training loop (`train.py`), showing how to update the EMA weights and use them for inference.
#### 9. Conclusion
- The article explains the concept of moving averages, and in particular how EMA is applied in Stable Diffusion training and how its cost can be reduced. Used properly, EMA can noticeably improve model stability and generalization without significantly increasing compute cost.
