本篇是《Diffusion Model (扩散生成模型)的基本原理详解(一)Denoising Diffusion Probabilistic Models(DDPM)》的续写,继续介绍有关diffusion的另一个相关模型,同理,参考文献和详细内容与上一篇相同,读者可自行查阅,本篇着重介绍Score-Based Generative Modeling(SGM)的部分,本篇的理论部分参考与上一节相同,当然涉及了一些原文的理论部分,笔者在这里为了更能让各位读懂,略掉了原文的一些理论证明,感兴趣读者可以自行阅读Song Yang et al.SGM原文。笔者只介绍重要思想和重要理论,省略了较多细节篇幅。下一节介绍本基础系列最后一部重点:Stochastic Differential Equation(SDE)。
2、Score-Based Generative Models(SGM)
不同于DDPM,这是一个基于分数的Model,换而言之通过预测评分来获取最终的信息。在阅读本篇之前,这样我们显然会产生两个问题:
一、Network应该以什么为基准的评分?
二、DDPM网络是一个可以直接给出的后验分布(去噪链),那么如果我得到了一个基于评分的网络,如何进行所谓的“采样”,这里没有直接对分布进行所谓的预测。
 下面开始进入正题介绍,先从评分函数讲起。
2.1、Score-Function(评分函数)
SGMS的核心思想在于评分函数的定义,让我们来一起看一下它的Score-Function的定义是怎么样定义的,以下是它的定义:
假若有一个概率密度函数 
     
      
       
       
         p 
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        p(x) 
       
      
    p(x),定义“Stein-ScoreFunction”如下:
 
      
       
        
        
          S 
         
        
          t 
         
        
          e 
         
        
          i 
         
        
          n 
         
        
          − 
         
        
          S 
         
        
          c 
         
        
          o 
         
        
          r 
         
        
          e 
         
        
          F 
         
        
          u 
         
        
          n 
         
        
          c 
         
        
          t 
         
        
          i 
         
        
          o 
         
        
          n 
         
        
          = 
         
         
         
           ∇ 
          
         
           x 
          
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
        
          p 
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
          ) 
         
        
       
         Stein-ScoreFunction=\nabla_xlog(p(x)) 
        
       
     Stein−ScoreFunction=∇xlog(p(x))显然的,该score表示了概率密度函数增长的快慢程度。
2.2、SGM forward Markov Chain(加噪链)—— q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xt∣xt−1)
我们仍旧假设原始数据 
     
      
       
        
        
          x 
         
        
          0 
         
        
       
      
        x_0 
       
      
    x0是从某一分布 
     
      
       
        
        
          x 
         
        
          0 
         
        
       
         ~ 
        
       
         q 
        
       
         ( 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
      
        x_0~q(x_0) 
       
      
    x0~q(x0)中采样得到。不同于之前DDPM,SGM使用另外一种更直接化的分布来进行采样,具体操作如下:假设生成噪声步长为 
     
      
       
       
         T 
        
       
      
        T 
       
      
    T,给予一组逐步弱化的噪声: 
     
      
       
       
         [ 
        
        
        
          σ 
         
        
          1 
         
        
       
         , 
        
        
        
          σ 
         
        
          2 
         
        
       
         , 
        
        
        
          σ 
         
        
          3 
         
        
       
         ⋅ 
        
       
         ⋅ 
        
       
         ⋅ 
        
        
        
          σ 
         
        
          T 
         
        
       
         ] 
        
       
      
        [\sigma_1,\sigma_2,\sigma_3···\sigma_T] 
       
      
    [σ1,σ2,σ3⋅⋅⋅σT],则会有如下结论:
 在第 
     
      
       
       
         i 
        
       
      
        i 
       
      
    i步噪声数据 
     
      
       
        
        
          x 
         
        
          i 
         
        
       
      
        x_i 
       
      
    xi满足从如下分布中进行采样:
 
      
       
        
         
         
           x 
          
         
           i 
          
         
        
          ~ 
         
        
          N 
         
        
          ( 
         
         
         
           x 
          
         
           0 
          
         
        
          , 
         
         
         
           σ 
          
         
           i 
          
         
           2 
          
         
        
          I 
         
        
          ) 
           
        
          ⟺ 
           
        
          q 
         
        
          ( 
         
         
         
           x 
          
         
           i 
          
         
        
          ∣ 
         
         
         
           x 
          
         
           0 
          
         
        
          ) 
         
        
          = 
         
        
          N 
         
        
          ( 
         
         
         
           x 
          
         
           0 
          
         
        
          , 
         
         
         
           σ 
          
         
           i 
          
         
           2 
          
         
        
          I 
         
        
          ) 
           
        
          ⟺ 
           
        
          q 
         
        
          ( 
         
         
         
           x 
          
         
           i 
          
         
        
          ) 
         
        
          = 
         
        
          ∫ 
         
        
          q 
         
        
          ( 
         
         
         
           x 
          
         
           i 
          
         
        
          ∣ 
         
         
         
           x 
          
         
           0 
          
         
        
          ) 
         
        
          q 
         
        
          ( 
         
         
         
           x 
          
         
           0 
          
         
        
          ) 
         
        
          d 
         
        
          ( 
         
         
         
           x 
          
         
           0 
          
         
        
          ) 
         
        
       
         x_i~N(x_0,\sigma_i^2I)\iff q(x_i|x_0)=N(x_0,\sigma_i^2I)\iff q(x_i)=\int q(x_i|x_0)q(x_0)d(x_0) 
        
       
     xi~N(x0,σi2I)⟺q(xi∣x0)=N(x0,σi2I)⟺q(xi)=∫q(xi∣x0)q(x0)d(x0)
2.3、Score-Network—— s θ ( x t , t ) s_\theta(x_t,t) sθ(xt,t)
2.3.1、SGM与DDPM的一致性&Loss-Function
不同于DDPM,这里并不是直接训练出一个去噪链直接解决问题并生成数据,我们的目的是想训练出一个可以模拟Score-Function的良好网络,再基于该Score-Function进行反向的采样,即想要设计一个Network : 
     
      
       
        
        
          s 
         
        
          θ 
         
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         , 
        
       
         t 
        
       
         ) 
        
       
      
        s_\theta(x_t,t) 
       
      
    sθ(xt,t)用来模拟当前的Score-Function: 
     
      
       
        
        
          ∇ 
         
         
         
           x 
          
         
           t 
          
         
        
       
         l 
        
       
         o 
        
       
         g 
        
       
         ( 
        
       
         q 
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         ∣ 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
         ) 
        
       
      
        \nabla_{x_{t}} log(q(x_t|x_0)) 
       
      
    ∇xtlog(q(xt∣x0))。那么显然地,目标函数变为:
 
      
       
        
        
          L 
         
        
          o 
         
        
          s 
         
        
          s 
         
        
          = 
         
        
          ∣ 
         
        
          ∣ 
         
         
         
           ∇ 
          
          
          
            x 
           
          
            t 
           
          
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
        
          q 
         
        
          ( 
         
         
         
           x 
          
         
           t 
          
         
        
          ∣ 
         
         
         
           x 
          
         
           0 
          
         
        
          ) 
         
        
          ) 
         
        
          − 
         
         
         
           s 
          
         
           θ 
          
         
        
          ( 
         
         
         
           x 
          
         
           t 
          
         
        
          , 
         
        
          t 
         
        
          ) 
         
        
          ∣ 
         
         
         
           ∣ 
          
         
           2 
          
         
        
       
         Loss=||\nabla_{x_{t}} log(q(x_t|x_0))-s_\theta(x_t,t)||^2 
        
       
     Loss=∣∣∇xtlog(q(xt∣x0))−sθ(xt,t)∣∣2
 而我们已经知道了 
     
      
       
       
         q 
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         ∣ 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
         = 
        
       
         N 
        
       
         ( 
        
        
        
          x 
         
        
          0 
         
        
       
         , 
        
        
        
          σ 
         
        
          t 
         
        
          2 
         
        
       
         I 
        
       
         ) 
        
       
         = 
        
        
        
          1 
         
         
          
           
           
             2 
            
           
             π 
            
           
          
          
          
            σ 
           
          
            t 
           
          
         
        
        
        
          e 
         
         
         
           − 
          
          
           
           
             ( 
            
            
            
              x 
             
            
              t 
             
            
           
             − 
            
            
            
              x 
             
            
              0 
             
            
            
            
              ) 
             
            
              2 
             
            
           
           
           
             2 
            
            
            
              σ 
             
            
              t 
             
            
              2 
             
            
           
          
         
        
       
      
        q(x_t|x_0)=N(x_0,\sigma_t^2I)=\frac{1}{\sqrt{2\pi}\sigma_t}e^{-\frac{(x_t-x_0)^2}{2\sigma_t^{2}}} 
       
      
    q(xt∣x0)=N(x0,σt2I)=2π 
                   σt1e−2σt2(xt−x0)2
 那么显然地,我们会有 
     
      
       
        
        
          ∇ 
         
         
         
           x 
          
         
           t 
          
         
        
       
         l 
        
       
         o 
        
       
         g 
        
       
         [ 
        
       
         q 
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         ∣ 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
         ] 
        
       
         = 
        
       
         − 
        
        
         
          
          
            x 
           
          
            t 
           
          
         
           − 
          
          
          
            x 
           
          
            0 
           
          
         
         
         
           σ 
          
         
           t 
          
         
           2 
          
         
        
       
      
        \nabla_{x_{t}}log[q(x_t|x_0)]=-\frac{x_t-x_0}{\sigma_t^{2}} 
       
      
    ∇xtlog[q(xt∣x0)]=−σt2xt−x0
 则会有当前优化目标可以视为如下的函数
 
      
       
        
        
          L 
         
        
          o 
         
        
          s 
         
        
          s 
         
        
          = 
         
        
          ∣ 
         
        
          ∣ 
         
        
          − 
         
         
          
           
           
             x 
            
           
             t 
            
           
          
            − 
           
           
           
             x 
            
           
             0 
            
           
          
          
          
            σ 
           
          
            t 
           
          
            2 
           
          
         
        
          − 
         
         
         
           s 
          
         
           θ 
          
         
        
          ( 
         
         
         
           x 
          
         
           t 
          
         
        
          , 
         
        
          t 
         
        
          ) 
         
        
          ∣ 
         
         
         
           ∣ 
          
         
           2 
          
         
        
       
         Loss=||-\frac{x_t-x_0}{\sigma_t^{2}}-s_\theta(x_t,t)||^2 
        
       
     Loss=∣∣−σt2xt−x0−sθ(xt,t)∣∣2
 注意到第一项,这可视为从正态分布进行的采样,差了一个非网络参数 
     
      
       
        
        
          σ 
         
        
          t 
         
        
       
      
        \sigma_t 
       
      
    σt即:
 
      
       
        
        
          L 
         
        
          o 
         
        
          s 
         
         
         
           s 
          
         
           ∗ 
          
         
        
          = 
         
         
         
           σ 
          
         
           t 
          
         
           2 
          
         
        
          L 
         
        
          o 
         
        
          s 
         
        
          s 
         
        
          = 
         
        
          ∣ 
         
        
          ∣ 
         
         
          
           
           
             x 
            
           
             t 
            
           
          
            − 
           
           
           
             x 
            
           
             0 
            
           
          
          
          
            σ 
           
          
            t 
           
          
         
        
          + 
         
         
         
           σ 
          
         
           t 
          
         
         
         
           s 
          
         
           θ 
          
         
        
          ( 
         
         
         
           x 
          
         
           t 
          
         
        
          , 
         
        
          t 
         
        
          ) 
         
        
          ∣ 
         
         
         
           ∣ 
          
         
           2 
          
         
        
       
         Loss^*=\sigma_t^2Loss=||\frac{x_t-x_0}{\sigma_t}+\sigma_ts_\theta(x_t,t)||^2 
        
       
     Loss∗=σt2Loss=∣∣σtxt−x0+σtsθ(xt,t)∣∣2
 令 
     
      
       
       
         − 
        
        
        
          σ 
         
        
          t 
         
        
        
        
          s 
         
        
          θ 
         
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         , 
        
       
         t 
        
       
         ) 
        
       
         = 
        
        
        
          z 
         
        
          θ 
         
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         , 
        
       
         t 
        
       
         ) 
        
       
      
        -\sigma_ts_\theta(x_t,t)=z_\theta(x_t,t) 
       
      
    −σtsθ(xt,t)=zθ(xt,t)
 
      
       
        
        
          L 
         
        
          o 
         
        
          s 
         
         
         
           s 
          
         
           ∗ 
          
         
        
          = 
         
         
         
           σ 
          
         
           t 
          
         
           2 
          
         
        
          L 
         
        
          o 
         
        
          s 
         
        
          s 
         
        
          = 
         
        
          ∣ 
         
        
          ∣ 
         
        
          z 
         
        
          − 
         
         
         
           z 
          
         
           θ 
          
         
        
          ∣ 
         
         
         
           ∣ 
          
         
           2 
          
         
        
       
         Loss^*=\sigma_t^2Loss=||z-z_\theta||^2 
        
       
     Loss∗=σt2Loss=∣∣z−zθ∣∣2
如果读者已经度过了笔者写过的(一)DDPM部分,读者会惊奇的发现,这与DDPM的优化目标是一致的,从原理上,它们的目的是相同的。
2.3.2、SGM-Sampling(Langevin Monte Carlo)
SGM的采样办法有很多种,不同于DDPM的那种“一步一步”反向估计后验估计的办法,这里首先介绍使用Langevin Monte Carlo(基于Langevin Dynamics的一种办法)进行采样,这里介绍算法过程,有关Langevin Monte Carlo理论部分的介绍可见随机过程中的Important-Sampling等会有一些介绍,笔者之后会给予一些简单的补充资料来对Langevin Monte Carlo理论进行说明,这里读者可以认为它的目的是去模拟原始分布 
     
      
       
       
         q 
        
       
         ( 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
      
        q(x_0) 
       
      
    q(x0),然后直接用该分布采样生成数据。
先来介绍Langevin Monte Carlo采样算法的过程:
 假设我们已经训练好了一个网络 
     
      
       
        
        
          s 
         
        
          θ 
         
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         , 
        
       
         t 
        
       
         ) 
        
       
      
        s_\theta(x_t,t) 
       
      
    sθ(xt,t),它可以作为 
     
      
       
        
        
          ∇ 
         
         
         
           x 
          
         
           t 
          
         
        
       
         l 
        
       
         o 
        
       
         g 
        
       
         ( 
        
       
         q 
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         ∣ 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
         ) 
        
       
      
        \nabla_{x_{t}} log(q(x_t|x_0)) 
       
      
    ∇xtlog(q(xt∣x0))的近似,我们下面要利用该网络进行分布预估:给定比步长(固定)为 
      
       
        
         
         
           s 
          
         
           ∗ 
          
         
        
       
         s^* 
        
       
     s∗ 
      
       
        
        
          ( 
         
         
         
           s 
          
         
           ∗ 
          
         
        
       
         (s^* 
        
       
     (s∗足够小 
      
       
        
        
          ) 
         
        
       
         ) 
        
       
     ),迭代次数为 
      
       
        
        
          N 
         
        
       
         N 
        
       
     N。下面进行反向生成过程,笔者将其总结为SGM算法:
SGM算法(Langevin Monte Carlo法):
①、随机采样一个样本 
      
       
        
         
         
           x 
          
         
           T 
          
         
           0 
          
         
        
          ~ 
         
        
          N 
         
        
          ( 
         
        
          0 
         
        
          , 
         
        
          1 
         
        
          ) 
         
        
       
         x_T^{0}~N(0,1) 
        
       
     xT0~N(0,1) 
      
       
        
        
          ( 
         
        
          T 
         
        
       
         (T 
        
       
     (T足够大 
      
       
        
        
          ) 
         
        
       
         ) 
        
       
     ),记录当前时间 
      
       
        
        
          t 
         
        
          = 
         
        
          T 
         
        
       
         t=T 
        
       
     t=T
②、迭代 
      
       
        
         
         
           x 
          
         
           t 
          
          
          
            i 
           
          
            + 
           
          
            1 
           
          
         
        
          = 
         
         
         
           x 
          
         
           t 
          
         
           i 
          
         
        
          + 
         
         
         
           1 
          
         
           2 
          
         
         
         
           s 
          
         
           ∗ 
          
         
         
         
           s 
          
         
           θ 
          
         
        
          ( 
         
         
         
           x 
          
         
           t 
          
         
           i 
          
         
        
          , 
         
        
          t 
         
        
          ) 
         
        
          + 
         
         
          
          
            s 
           
          
            ∗ 
           
          
         
        
          z 
         
        
       
         x_t^{i+1}=x_t^{i}+\frac{1}{2}s^*s_\theta(x_t^{i},t)+\sqrt{s^*}z 
        
       
     xti+1=xti+21s∗sθ(xti,t)+s∗ 
            z 直到 
      
       
        
        
          i 
         
        
          = 
         
        
          N 
         
        
       
         i=N 
        
       
     i=N
③、 
      
       
        
         
         
           x 
          
          
          
            t 
           
          
            − 
           
          
            1 
           
          
         
           0 
          
         
        
          = 
         
         
         
           x 
          
         
           t 
          
         
           T 
          
         
        
          , 
         
        
          t 
         
        
          = 
         
        
          t 
         
        
          − 
         
        
          1 
         
        
       
         x_{t-1}^{0}=x_{t}^T,t=t-1 
        
       
     xt−10=xtT,t=t−1
④、重复②~③直到 
      
       
        
        
          t 
         
        
          = 
         
        
          0 
         
        
       
         t=0 
        
       
     t=0