【扩散模型Diffusion Model系列】0-从VAE开始(隐变量模型、KL散度、最大化似然与AIGC的关系)
VAE
VAE(Variational AutoEncoder),变分自编码器,是一种无监督学习算法,被用于压缩、特征提取和生成式任务。相比于GAN(Generative Adversarial Network),VAE在数学上有着更加良好的性质,有利于理论的分析和实现。
 
1 生成式模型的目标——KL散度和最大化似然MLE
生成式模型(Generative Model)的目标是学习一个模型,从一个简单的分布 
      
       
        
        
          p 
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         p(x) 
        
       
     p(x)中采样出数据 
      
       
        
        
          x 
         
        
       
         x 
        
       
     x,通过生成模型 
      
       
        
        
          f 
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         f(x) 
        
       
     f(x)来逼近真实数据的分布 
      
       
        
         
         
           p 
          
          
          
            d 
           
          
            a 
           
          
            t 
           
          
            a 
           
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         p_{data}(x) 
        
       
     pdata?(x),并生成样本,实现了上面这一点即使我们所希望的结果。
 自然,我们可以想到,生成模型最本质的目标就是最小化模型生成的样本分布 
      
       
        
         
         
           p 
          
         
           θ 
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         p_{\theta}(x) 
        
       
     pθ?(x)和真实样本分布 
      
       
        
         
         
           p 
          
          
          
            d 
           
          
            a 
           
          
            t 
           
          
            a 
           
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         p_{data}(x) 
        
       
     pdata?(x)之间的KL散度:
  
      
       
        
         
          
          
           
            
           
          
          
           
            
             
             
              
              
                a 
               
              
                r 
               
              
                g 
               
              
                m 
               
              
                i 
               
              
                n 
               
              
             
               θ 
              
            ?? 
             
             
               D 
              
              
              
                K 
               
              
                L 
               
              
             
            
              ( 
            ? 
             
             
               p 
              
              
              
                d 
               
              
                a 
               
              
                t 
               
              
                a 
               
              
             
            
              ( 
             
            
              x 
             
            
              ) 
            ? 
            
              ∣ 
             
            
              ∣ 
            ? 
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
            
              x 
             
            
              ) 
            ? 
            
              ) 
             
            
           
          
          
          
         
         
          
          
           
           
             = 
            
           
          
          
           
            
             
             
              
              
                a 
               
              
                r 
               
              
                g 
               
              
                m 
               
              
                i 
               
              
                n 
               
              
             
               θ 
              
             
            
              ∫ 
             
             
             
               p 
              
              
              
                d 
               
              
                a 
               
              
                t 
               
              
                a 
               
              
             
            
              ( 
             
            
              x 
             
            
              ) 
            ?? 
            
              l 
             
            
              o 
             
            
              g 
            ? 
             
              
               
               
                 p 
                
                
                
                  d 
                 
                
                  a 
                 
                
                  t 
                 
                
                  a 
                 
                
               
              
                ( 
               
              
                x 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                ) 
               
              
             
            
           
          
          
          
         
         
          
          
           
           
             = 
            
           
          
          
           
            
             
             
              
              
                a 
               
              
                r 
               
              
                g 
               
              
                m 
               
              
                a 
               
              
                x 
               
              
             
               θ 
              
             
            
              ∫ 
             
             
             
               p 
              
              
              
                d 
               
              
                a 
               
              
                t 
               
              
                a 
               
              
             
            
              ( 
             
            
              x 
             
            
              ) 
            ?? 
            
              l 
             
            
              o 
             
            
              g 
            ? 
             
              
              
                p 
               
              
                θ 
               
              
             
               ( 
              
             
               x 
              
             
               ) 
              
             
            ? 
            
              【 
             
             
             
               p 
              
              
              
                d 
               
              
                a 
               
              
                t 
               
              
                a 
               
              
             
            
              ( 
             
            
              x 
             
            
              ) 
             
            
              无参数优化】 
             
            
           
          
          
          
         
         
          
          
           
           
             = 
            
           
          
          
           
            
             
             
              
              
                a 
               
              
                r 
               
              
                g 
               
              
                m 
               
              
                a 
               
              
                x 
               
              
             
               θ 
              
             
             
             
               E 
              
              
              
                x 
               
              
                ~ 
               
               
               
                 p 
                
                
                
                  d 
                 
                
                  a 
                 
                
                  t 
                 
                
                  a 
                 
                
               
              
                ( 
               
              
                x 
               
              
                ) 
               
              
             
             
             
               [ 
              
             
               log 
              
             
               ? 
              
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                ) 
               
              
             
               ] 
              
             
            ? 
            
              【期望的定义】 
             
            
           
          
          
          
         
         
          
          
           
           
             ≈ 
            
           
          
          
           
            
             
             
              
              
                a 
               
              
                r 
               
              
                g 
               
              
                m 
               
              
                a 
               
              
                x 
               
              
             
               θ 
              
             
             
             
               1 
              
             
               m 
              
             
             
             
               ∑ 
              
              
              
                i 
               
              
                = 
               
              
                1 
               
              
             
               m 
              
             
            
              log 
             
            
              ? 
             
             
              
              
                p 
               
              
                θ 
               
              
             
               ( 
              
              
              
                x 
               
              
                i 
               
              
             
               ) 
              
             
             
            ???? 
            
              【从数据集中采样 
             
            
              m 
             
            
              个,估算期望,对应于训练过程】 
             
            
           
          
          
          
         
         
          
          
           
           
             = 
            
           
          
          
           
            
             
             
              
              
                a 
               
              
                r 
               
              
                g 
               
              
                m 
               
              
                a 
               
              
                x 
               
              
             
               θ 
              
             
             
             
               ∏ 
              
              
              
                i 
               
              
                = 
               
              
                1 
               
              
             
               m 
              
             
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
             
             
               x 
              
             
               i 
              
             
            
              ) 
             
             
             
            ??? 
            
              【最大化似然】 
             
            
           
          
          
          
         
        
       
         \begin{align} &\mathop{argmin}\limits_{\theta} \;D_{KL}(\,p_{data}(x)\,||\,p_{\theta}(x)\,) \\=&\mathop{argmin}\limits_{\theta} \int p_{data}(x)\;log\,\frac{p_{data}(x)}{p_{\theta}(x)} \\=&\mathop{argmax}\limits_{\theta} \int p_{data}(x)\;log\,{p_{\theta}(x)} \qquad\, 【p_{data}(x)无参数优化】 \\=&\mathop{argmax}\limits_{\theta}E_{x\sim p_{data}(x)}\left[\log{p_{\theta}(x)}\right] \qquad \, 【期望的定义】 \\\approx&\mathop{argmax}\limits_{\theta}\frac{1}{m}\sum\limits_{i=1}^{m}\log{p_{\theta}(x_{i})} \qquad \quad \;\; 【从数据集中采样m个,估算期望,对应于训练过程】 \\=&\mathop{argmax}\limits_{\theta} \prod\limits_{i=1}^{m}p_{\theta}(x_{i}) \qquad \qquad \qquad \;\,【最大化似然】 \end{align} 
        
       
     ===≈=?θargmin?DKL?(pdata?(x)∣∣pθ?(x))θargmin?∫pdata?(x)logpθ?(x)pdata?(x)?θargmax?∫pdata?(x)logpθ?(x)【pdata?(x)无参数优化】θargmax?Ex~pdata?(x)?[logpθ?(x)]【期望的定义】θargmax?m1?i=1∑m?logpθ?(xi?)【从数据集中采样m个,估算期望,对应于训练过程】θargmax?i=1∏m?pθ?(xi?)【最大化似然】??
2 从AE到VAE
显然上述的生成式模型并不专门针对VAE,任何一个输出和输入相同分布的模型都可以得到此结论,那么不得不提的就是AE(AutoEncoder),诸如MAE、DAE、VQVAE等。
 AE的目标是最小化重构误差,即重构误差越小,则表示模型生成的数据和真实数据的分布越接近,和上述描述的生成式模型目标一致,但AE之所以不能用于生成式模型,是因为AE的Bottleneck的分布实际上是未知的,我们无法凭空采样一个符合bottleneck分布的数据,所以AE不能直接用于生成式模型。
 AE和VAE实际上都可以被视为一个隐变量模型 
     
      
       
       
         p 
        
       
         ( 
        
       
         x 
        
       
         ∣ 
        
       
         z 
        
       
         ) 
        
       
      
        p(x|z) 
       
      
    p(x∣z),认为在真实数据分布之后,存在着一个隐变量 
      
       
        
        
          z 
         
        
       
         z 
        
       
     z,其分布为 
      
       
        
        
          p 
         
        
          ( 
         
        
          z 
         
        
          ) 
         
        
       
         p(z) 
        
       
     p(z), 
      
       
        
        
          x 
         
        
       
         x 
        
       
     x和 
      
       
        
        
          z 
         
        
       
         z 
        
       
     z之间存在一个隐变量连接,即 
      
       
        
         
         
           p 
          
         
           ? 
          
         
        
          ( 
         
        
          x 
         
        
          ∣ 
         
        
          z 
         
        
          ) 
         
        
       
         p_{\phi}(x|z) 
        
       
     p??(x∣z)。
 例如可以将所有的矩形视为一个真实分布 
     
      
       
       
         p 
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        p(x) 
       
      
    p(x),而所有的长和宽的分布视为 
     
      
       
       
         p 
        
       
         ( 
        
       
         z 
        
       
         ) 
        
       
      
        p(z) 
       
      
    p(z),那么显然,当我们从 
     
      
       
       
         p 
        
       
         ( 
        
       
         z 
        
       
         ) 
        
       
      
        p(z) 
       
      
    p(z)采样一个长宽 
     
      
       
       
         z 
        
       
         即 
        
       
         z 
        
       
         ~ 
        
       
         p 
        
       
         ( 
        
       
         z 
        
       
         ) 
        
       
      
        z即z \sim p(z) 
       
      
    z即z~p(z)时,事实上也采样到了一个矩形,这是因为我们认为存在明确的 
     
      
       
        
        
          p 
         
        
          ? 
         
        
       
         ( 
        
       
         x 
        
       
         ∣ 
        
       
         z 
        
       
         ) 
        
       
      
        p_{\phi}(x|z) 
       
      
    p??(x∣z),即矩形的宽和高和矩形的分布存在一个连接。
 在AE中, 
     
      
       
       
         z 
        
       
      
        z 
       
      
    z是bottleneck特征向量,很好地表征了原始数据的特征,因此可以利用Decoder即 
     
      
       
        
        
          p 
         
        
          θ 
         
        
       
         ( 
        
       
         x 
        
       
         ∣ 
        
       
         z 
        
       
         ) 
        
       
      
        p_{\theta}(x|z) 
       
      
    pθ?(x∣z)进行复原,理论上如果我们可以采样到 
     
      
       
       
         z 
        
       
      
        z 
       
      
    z,那么就可以进行复原,但事实上我们不知道 
     
      
       
       
         z 
        
       
      
        z 
       
      
    z的分布,因此我们无法用AE进行生成式。
 而在VAE中,我们希望通过Encoder的学习,将真实的后验分布 
     
      
       
        
        
          p 
         
        
          ? 
         
        
       
         ( 
        
       
         z 
        
       
         ∣ 
        
       
         x 
        
       
         ) 
        
       
      
        p_{\phi}(z|x) 
       
      
    p??(z∣x)进行近似,即 
     
      
       
        
        
          p 
         
        
          θ 
         
        
       
         ( 
        
       
         z 
        
       
         ∣ 
        
       
         x 
        
       
         ) 
        
       
      
        p_{\theta}(z|x) 
       
      
    pθ?(z∣x),并且希望后验分布 
      
       
        
         
         
           p 
          
         
           ? 
          
         
        
          ( 
         
        
          z 
         
        
          ∣ 
         
        
          x 
         
        
          ) 
         
        
       
         p_{\phi}(z|x) 
        
       
     p??(z∣x)服从于正态分布 
      
       
        
        
          N 
         
        
          ( 
         
        
          0 
         
        
          , 
         
        
          I 
         
        
          ) 
         
        
       
         N(0,I) 
        
       
     N(0,I),这样的话,在优化足够好的Encoder,即 
     
      
       
        
        
          p 
         
        
          θ 
         
        
       
         ( 
        
       
         z 
        
       
         ∣ 
        
       
         x 
        
       
         ) 
        
       
         ≈ 
        
       
         N 
        
       
         ( 
        
       
         0 
        
       
         , 
        
       
         I 
        
       
         ) 
        
       
      
        p_{\theta}(z|x) \approx N(0,I) 
       
      
    pθ?(z∣x)≈N(0,I)时,我们有:
  
      
       
        
         
          
          
           
            
            
              p 
             
            
              ( 
             
            
              z 
             
            
              ) 
             
            
              = 
             
            
           
          
          
           
            
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
             
            
              p 
             
            
              ( 
             
            
              x 
             
            
              ) 
            ? 
            
              d 
             
            
              x 
             
            
              = 
             
            
              ∫ 
             
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
             
            
              p 
             
            
              ( 
             
            
              x 
             
            
              ) 
            ? 
            
              d 
             
            
              x 
             
            
           
          
          
          
         
         
          
          
           
           
             = 
            
           
          
          
           
            
             
            
              ∫ 
             
            
              N 
             
            
              ( 
             
            
              0 
             
            
              , 
             
            
              I 
             
            
              ) 
             
            
              p 
             
            
              ( 
             
            
              x 
             
            
              ) 
            ? 
            
              d 
             
            
              x 
             
            
              = 
             
            
              N 
             
            
              ( 
             
            
              0 
             
            
              , 
             
            
              I 
             
            
              ) 
             
            
              ∫ 
             
            
              p 
             
            
              ( 
             
            
              x 
             
            
              ) 
             
            
              d 
             
            
              x 
             
            
              = 
             
            
              N 
             
            
              ( 
             
            
              0 
             
            
              , 
             
            
              I 
             
            
              ) 
             
            
           
          
          
          
         
        
       
         \begin{align} p(z)=&\int p_{\phi}(z|x)p(x)\,dx=\int p_{\theta}(z|x)p(x)\,dx\\=&\int N(0,I)p(x)\,dx=N(0,I)\int p(x)dx=N(0,I) \end{align} 
        
       
     p(z)==?∫p??(z∣x)p(x)dx=∫pθ?(z∣x)p(x)dx∫N(0,I)p(x)dx=N(0,I)∫p(x)dx=N(0,I)??
 这样的话,我们就可以轻松地从正态分布中采样 
     
      
       
       
         z 
        
       
         ~ 
        
       
         p 
        
       
         ( 
        
       
         z 
        
       
         ) 
        
       
      
        z\sim p(z) 
       
      
    z~p(z),为此我们必须考虑对“AE的bottleneck”进行修改,从而让 
      
       
        
         
         
           p 
          
         
           θ 
          
         
        
          ( 
         
        
          z 
         
        
          ∣ 
         
        
          x 
         
        
          ) 
         
        
       
         p_{\theta}(z|x) 
        
       
     pθ?(z∣x)的分布近似于 
      
       
        
        
          N 
         
        
          ( 
         
        
          0 
         
        
          , 
         
        
          I 
         
        
          ) 
         
        
       
         N(0,I) 
        
       
     N(0,I),这也是为什么VAE输出的是正态分布的参数 
      
       
        
        
          μ 
         
        
          , 
         
         
         
           σ 
          
         
           2 
          
         
        
       
         \mu,\sigma^2 
        
       
     μ,σ2。
 理论上,我们通过重参数技巧 
      
       
        
        
          x 
         
        
          = 
         
        
          μ 
         
        
          + 
         
        
          σ 
        ? 
        
          ? 
         
        
          , 
         
        
          ? 
         
        
          ~ 
         
        
          N 
         
        
          ( 
         
        
          0 
         
        
          , 
         
        
          I 
         
        
          ) 
         
        
       
         x=\mu+\sigma\,\epsilon,\epsilon \sim N(0,I) 
        
       
     x=μ+σ?,?~N(0,I),即可实现输出为 
     
      
       
       
         N 
        
       
         ( 
        
       
         μ 
        
       
         , 
        
        
        
          σ 
         
        
          2 
         
        
       
         ) 
        
       
      
        N(\mu,\sigma^2) 
       
      
    N(μ,σ2),且将采样这一不可导的操作转为可导。
 若是不对编码器 
      
       
        
         
         
           p 
          
         
           θ 
          
         
        
          ( 
         
        
          z 
         
        
          ∣ 
         
        
          x 
         
        
          ) 
         
        
       
         p_{\theta}(z|x) 
        
       
     pθ?(z∣x)加以限制,只使用MSE进行训练,VAE会逐渐退化为AE,因为网络一定会倾向于将 
      
       
        
         
         
           σ 
          
         
           2 
          
         
        
          → 
         
        
          0 
         
        
       
         \sigma^2 \rightarrow 0 
        
       
     σ2→0,因为这最有利于重建,那么我们最直接的想法就是使用另外2个MSE,强迫 
      
       
        
        
          μ 
         
        
          → 
         
        
          0 
         
        
          , 
        ? 
         
         
           σ 
          
         
           2 
          
         
        
          → 
         
        
          I 
         
        
       
         \mu \rightarrow 0,\,\sigma^2\rightarrow I 
        
       
     μ→0,σ2→I,但这样3个MSE之间的比例就会十分难以调整,容易顾此失彼,因此,我们继续从MLE出发,继续推导VAE的损失函数。
3 VAE的损失函数
承接第一节,我们已经确认了生成式网络的最终目标就是最大化 
      
       
        
         
         
           p 
          
         
           θ 
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         p_{\theta}(x) 
        
       
     pθ?(x)的似然,而正如常识所知,直接最大化 
     
      
       
        
        
          p 
         
        
          θ 
         
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        p_{\theta}(x) 
       
      
    pθ?(x)太过困难,我们采用隐变量模型建构,那么公式如下:
  
      
       
        
         
          
          
           
            
            
              l 
             
            
              o 
             
            
              g 
             
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
            
              x 
             
            
              ) 
             
            
           
          
          
           
            
             
            
              = 
             
            
              l 
             
            
              o 
             
            
              g 
             
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
            
              x 
             
            
              ) 
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              d 
             
            
              z 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              l 
             
            
              o 
             
            
              g 
             
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
            
              x 
             
            
              ) 
            ? 
            
              d 
             
            
              z 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              l 
             
            
              o 
             
            
              g 
             
             
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                , 
               
              
                z 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
            ? 
            
              d 
             
            
              z 
             
             
            
              【条件概率的定义】 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              l 
             
            
              o 
             
            
              g 
             
             
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                , 
               
              
                z 
               
              
                ) 
              ? 
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
              ? 
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
            ? 
            
              d 
             
            
              z 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              l 
             
            
              o 
             
            
              g 
             
             
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                , 
               
              
                z 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
            ? 
            
              d 
             
            
              z 
             
            
              + 
             
            
              ∫ 
             
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              l 
             
            
              o 
             
            
              g 
             
             
              
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
            ? 
            
              d 
             
            
              z 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               E 
              
              
              
                z 
               
              
                ~ 
               
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
             
            
              [ 
             
            
              l 
             
            
              o 
             
            
              g 
             
             
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                , 
               
              
                z 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
             
            
              ] 
             
            
              + 
             
             
             
               D 
              
              
              
                K 
               
              
                L 
               
              
             
            
              ( 
            ? 
             
             
               p 
              
             
               ? 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              ∣ 
             
            
              ∣ 
            ? 
             
             
               p 
              
             
               θ 
              
             
            
              ( 
             
            
              z 
             
            
              ∣ 
             
            
              x 
             
            
              ) 
            ? 
            
              ) 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              ≥ 
             
             
             
               E 
              
              
              
                z 
               
              
                ~ 
               
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
             
            
              [ 
             
            
              l 
             
            
              o 
             
            
              g 
             
             
              
               
               
                 p 
                
               
                 θ 
                
               
              
                ( 
               
              
                x 
               
              
                , 
               
              
                z 
               
              
                ) 
               
              
              
               
               
                 p 
                
               
                 ? 
                
               
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
             
            
              ] 
             
            ?? 
            
              【 
             
            
              K 
             
            
              L 
             
            
              散度 
             
            
              ≥ 
             
            
              0 
             
            
              ,可利用 
             
            
              ? 
             
            
              l 
             
            
              n 
             
            
              x 
             
            
              ≥ 
             
            
              1 
             
            
              ? 
             
            
              x 
             
            
              证明】 
             
            
           
          
          
          
         
        
       
         \begin{align} log p_{\theta}(x)&=log p_{\theta}(x) \int p_{\phi}(z|x)\,dz \\&=\int p_{\phi}(z|x)\,log p_{\theta}(x)\,dz \\&=\int p_{\phi}(z|x)\,log \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)}\,dz\quad【条件概率的定义】 \\&=\int p_{\phi}(z|x)\,log \frac{p_{\theta}(x,z)\,p_{\phi}(z|x)}{p_{\theta}(z|x)\,p_{\phi}(z|x)}\,dz \\&=\int p_{\phi}(z|x)\,log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}\,dz+\int p_{\phi}(z|x)\,log \frac{p_{\phi}(z|x)}{p_{\theta}(z|x)}\,dz \\&=E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}]+D_{KL}(\,p_{\phi}(z|x)\,||\,p_{\theta}(z|x)\,) \\& \ge E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}]\qquad\;【KL散度\ge0,可利用-lnx \ge 1-x证明】 \end{align} 
        
       
     logpθ?(x)?=logpθ?(x)∫p??(z∣x)dz=∫p??(z∣x)logpθ?(x)dz=∫p??(z∣x)logpθ?(z∣x)pθ?(x,z)?dz【条件概率的定义】=∫p??(z∣x)logpθ?(z∣x)p??(z∣x)pθ?(x,z)p??(z∣x)?dz=∫p??(z∣x)logp??(z∣x)pθ?(x,z)?dz+∫p??(z∣x)logpθ?(z∣x)p??(z∣x)?dz=Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]+DKL?(p??(z∣x)∣∣pθ?(z∣x))≥Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]【KL散度≥0,可利用?lnx≥1?x证明】??
 最终我们可认为损失函数为:
  
      
       
        
         
          
          
            a 
           
          
            r 
           
          
            g 
           
          
            m 
           
          
            a 
           
          
            x 
           
          
         
           θ 
          
        ? 
        
          l 
         
        
          o 
         
        
          g 
         
         
         
           p 
          
         
           θ 
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
          = 
         
         
          
          
            a 
           
          
            r 
           
          
            g 
           
          
            m 
           
          
            a 
           
          
            x 
           
          
         
           θ 
          
         
         
         
           E 
          
          
          
            z 
           
          
            ~ 
           
           
           
             p 
            
           
             ? 
            
           
          
            ( 
           
          
            z 
           
          
            ∣ 
           
          
            x 
           
          
            ) 
           
          
         
        
          [ 
         
        
          l 
         
        
          o 
         
        
          g 
         
         
          
           
           
             p 
            
           
             θ 
            
           
          
            ( 
           
          
            x 
           
          
            , 
           
          
            z 
           
          
            ) 
           
          
          
           
           
             p 
            
           
             ? 
            
           
          
            ( 
           
          
            z 
           
          
            ∣ 
           
          
            x 
           
          
            ) 
           
          
         
        
          ] 
         
        
       
         \mathop{argmax}\limits_{\theta}\,log p_{\theta}(x) = \mathop{argmax}\limits_{\theta} E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}] 
        
       
     θargmax?logpθ?(x)=θargmax?Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]
  
      
       
        
        
          L 
         
        
          = 
         
        
          ? 
         
         
         
           E 
          
          
          
            z 
           
          
            ~ 
           
           
           
             p 
            
           
             ? 
            
           
          
            ( 
           
          
            z 
           
          
            ∣ 
           
          
            x 
           
          
            ) 
           
          
         
        
          [ 
         
        
          l 
         
        
          o 
         
        
          g 
         
         
          
           
           
             p 
            
           
             θ 
            
           
          
            ( 
           
          
            x 
           
          
            , 
           
          
            z 
           
          
            ) 
           
          
          
           
           
             p 
            
           
             ? 
            
           
          
            ( 
           
          
            z 
           
          
            ∣ 
           
          
            x 
           
          
            ) 
           
          
         
        
          ] 
         
        
       
         L= -E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}] 
        
       
     L=?Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]
 对于上式我们可以有2种理解:
- 最大化 p θ ( x ) p_{\theta}(x) pθ?(x)转化为了最大化下界ELBO(Evidence Lower Bound),因此我们只需要去优化 E z ~ p ? ( z ∣ x ) [ l o g p θ ( x , z ) p ? ( z ∣ x ) ] E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}] Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]
-  
       
        
         
          
          
            E 
           
           
           
             z 
            
           
             ~ 
            
            
            
              p 
             
            
              ? 
             
            
           
             ( 
            
           
             z 
            
           
             ∣ 
            
           
             x 
            
           
             ) 
            
           
          
         
           [ 
          
         
           l 
          
         
           o 
          
         
           g 
          
          
           
            
            
              p 
             
            
              θ 
             
            
           
             ( 
            
           
             x 
            
           
             , 
            
           
             z 
            
           
             ) 
            
           
           
            
            
              p 
             
            
              ? 
             
            
           
             ( 
            
           
             z 
            
           
             ∣ 
            
           
             x 
            
           
             ) 
            
           
          
         
           ] 
          
         
           = 
          
         
           l 
          
         
           o 
          
         
           g 
          
          
          
            p 
           
          
            θ 
           
          
         
           ( 
          
         
           x 
          
         
           ) 
          
         
           ? 
          
          
          
            D 
           
           
           
             K 
            
           
             L 
            
           
          
         
           ( 
         ? 
          
          
            p 
           
          
            ? 
           
          
         
           ( 
          
         
           z 
          
         
           ∣ 
          
         
           x 
          
         
           ) 
         ? 
         
           ∣ 
          
         
           ∣ 
         ? 
          
          
            p 
           
          
            θ 
           
          
         
           ( 
          
         
           z 
          
         
           ∣ 
          
         
           x 
          
         
           ) 
         ? 
         
           ) 
          
         
        
          E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}]=log p_{\theta}(x)-D_{KL}(\,p_{\phi}(z|x)\,||\,p_{\theta}(z|x)\,) 
         
        
      Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]=logpθ?(x)?DKL?(p??(z∣x)∣∣pθ?(z∣x)),最小化损失函数 
       
        
         
         
           L 
          
         
        
          L 
         
        
      L即最大化 
       
        
         
          
          
            E 
           
           
           
             z 
            
           
             ~ 
            
            
            
              p 
             
            
              ? 
             
            
           
             ( 
            
           
             z 
            
           
             ∣ 
            
           
             x 
            
           
             ) 
            
           
          
         
           [ 
          
         
           l 
          
         
           o 
          
         
           g 
          
          
           
            
            
              p 
             
            
              θ 
             
            
           
             ( 
            
           
             x 
            
           
             , 
            
           
             z 
            
           
             ) 
            
           
           
            
            
              p 
             
            
              ? 
             
            
           
             ( 
            
           
             z 
            
           
             ∣ 
            
           
             x 
            
           
             ) 
            
           
          
         
           ] 
          
         
        
          E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}] 
         
        
      Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]时,会最大化似然 
       
        
         
         
           l 
          
         
           o 
          
         
           g 
          
          
          
            p 
           
          
            θ 
           
          
         
           ( 
          
         
           x 
          
         
           ) 
          
         
        
          log p_{\theta}(x) 
         
        
      logpθ?(x),即让生成图片更真实的同时;最小化Encoder建模的 
       
        
         
          
          
            p 
           
          
            θ 
           
          
         
           ( 
          
         
           z 
          
         
           ∣ 
          
         
           x 
          
         
           ) 
          
         
        
          p_{\theta}(z|x) 
         
        
      pθ?(z∣x)和真实隐变量后验分布 
       
        
         
          
          
            p 
           
          
            ? 
           
          
         
           ( 
          
         
           z 
          
         
           ∣ 
          
         
           x 
          
         
           ) 
          
         
        
          p_{\phi}(z|x) 
         
        
      p??(z∣x)之间的KL散度(当然是事实上是二者trade off)
 若是我们使得 p θ ( z ∣ x ) → N ( 0 , I ) p_{\theta}(z|x)\rightarrow N(0,I) pθ?(z∣x)→N(0,I),即大功告成。于是我们继续分解:
 E z ~ p ? ( z ∣ x ) [ l o g p θ ( x , z ) p ? ( z ∣ x ) ] = ∫ p ? ( z ∣ x ) ? l o g p θ ( x , z ) p ? ( z ∣ x ) ? d z = ∫ p ? ( z ∣ x ) ? l o g p θ ( x ∣ z ) ? p ( z ) p ? ( z ∣ x ) ? d z = ∫ p ? ( z ∣ x ) ? l o g p θ ( x ∣ z ) ? d z + ∫ p ? ( z ∣ x ) ? l o g p ( z ) p ? ( z ∣ x ) ? d z = E z ~ p ? ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] ? D K L ( ? p ? ( z ∣ x ) ? ∣ ∣ ? p ( z ) ? ) ≈ E z ~ p θ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] ? D K L ( ? p θ ( z ∣ x ) ? ∣ ∣ ? p ( z ) ? ) \begin{align} E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}]&=\int p_{\phi}(z|x)\,log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}\,dz \\&=\int p_{\phi}(z|x)\,log \frac{p_{\theta}(x|z)\,p(z)}{p_{\phi}(z|x)}\,dz \\&=\int p_{\phi}(z|x)\,log p_{\theta}(x|z)\,dz + \int p_{\phi}(z|x)\,log \frac{p(z)}{p_{\phi}(z|x)}\,dz \\&=E_{z\sim p_{\phi}(z|x)}[log p_{\theta}(x|z)]-D_{KL}(\,p_{\phi}(z|x)\,||\,p(z)\,) \\& \approx E_{z\sim p_{\theta}(z|x)}[log p_{\theta}(x|z)]-D_{KL}(\,p_{\theta}(z|x)\,||\,p(z)\,) \end{align} Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]?=∫p??(z∣x)logp??(z∣x)pθ?(x,z)?dz=∫p??(z∣x)logp??(z∣x)pθ?(x∣z)p(z)?dz=∫p??(z∣x)logpθ?(x∣z)dz+∫p??(z∣x)logp??(z∣x)p(z)?dz=Ez~p??(z∣x)?[logpθ?(x∣z)]?DKL?(p??(z∣x)∣∣p(z))≈Ez~pθ?(z∣x)?[logpθ?(x∣z)]?DKL?(pθ?(z∣x)∣∣p(z))??
 其中 E z ~ p θ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] E_{z\sim p_{\theta}(z|x)}[log p_{\theta}(x|z)] Ez~pθ?(z∣x)?[logpθ?(x∣z)]为最大似然,我们假设最终为正态分布,最大似然就完全等价于最小化重建损失MSE
 而 D K L ( ? p θ ( z ∣ x ) ? ∣ ∣ ? p ( z ) ? ) D_{KL}(\,p_{\theta}(z|x)\,||\,p(z)\,) DKL?(pθ?(z∣x)∣∣p(z))则为正则项,用于约束Encoder的输出,具体公式如下:
 D K L ( ? p θ ( z ∣ x ) ? ∣ ∣ ? p ( z ) ? ) = ? ∫ p ? ( z ∣ x ) ? l o g p ( z ) p ? ( z ∣ x ) ? d z = ∫ p ? ( z ∣ x ) ? [ ? z 2 2 ? l o g 1 2 π ? ( z ? μ θ ( x ) ) 2 2 σ θ ( x ) 2 ? + l o g 1 2 π σ θ ( x ) 2 ] d z = 1 2 ∫ p ? ( z ∣ x ) ? [ ? z 2 ? ( z ? μ θ ( x ) σ θ ( x ) ) 2 ? ? l o g σ θ ( x ) 2 ] d z = 1 2 [ ? ? 1 + μ θ ( x ) 2 + σ θ ( x ) 2 ? l o g σ θ ( x ) 2 ? ] 【 E ( z 2 ) = μ 2 + σ 2 ,用于解答 z 2 和 ( z ? μ σ ) 2 】 \begin{align} D_{KL}(\,p_{\theta}(z|x)\,||\,p(z)\,)&=-\int p_{\phi}(z|x)\,log \frac{p(z)}{p_{\phi}(z|x)}\,dz \\&=\int p_{\phi}(z|x)\,[\,\frac{z^2}{2}-log\frac{1}{\sqrt{2\pi}}-\frac{(z-\mu_{\theta}(x))^2}{2{\sigma_{\theta}(x)}^2}\,+log\frac{1}{\sqrt{2\pi{\sigma_{\theta}(x)}^2}}]dz \\&=\frac{1}{2}\int p_{\phi}(z|x)\,[\,z^2-(\frac{z-\mu_{\theta}(x)}{{\sigma_{\theta}(x)}})^2\,-log{\sigma_{\theta}(x)}^2]dz \\&=\frac{1}{2}[\,-1+{\mu_{\theta}(x)}^2+{\sigma_{\theta}(x)}^2-log{\sigma_{\theta}(x)}^2\,]\qquad\qquad【E(z^2)=\mu^2+\sigma^2,用于解答z^2和(\frac{z-\mu}{\sigma})^2】 \end{align} DKL?(pθ?(z∣x)∣∣p(z))?=?∫p??(z∣x)logp??(z∣x)p(z)?dz=∫p??(z∣x)[2z2??log2π?1??2σθ?(x)2(z?μθ?(x))2?+log2πσθ?(x)2?1?]dz=21?∫p??(z∣x)[z2?(σθ?(x)z?μθ?(x)?)2?logσθ?(x)2]dz=21?[?1+μθ?(x)2+σθ?(x)2?logσθ?(x)2]【E(z2)=μ2+σ2,用于解答z2和(σz?μ?)2】??
 综上,我们得到了VAE的损失函数如下:
 L v a e = ? E z ~ p ? ( z ∣ x ) [ l o g p θ ( x , z ) p ? ( z ∣ x ) ] = ? E z ~ p θ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] + D K L ( ? p θ ( z ∣ x ) ? ∣ ∣ ? p ( z ) ? ) = M S E ( x , p θ ( x , ? ) ) + 1 2 [ ? ? 1 + μ θ ( x ) 2 + σ θ ( x ) 2 ? l o g σ θ ( x ) 2 ? ] \begin{align} L_{vae}&=-E_{z\sim p_{\phi}(z|x)}[log \frac{p_{\theta}(x,z)}{p_{\phi}(z|x)}] \\&=-E_{z\sim p_{\theta}(z|x)}[log p_{\theta}(x|z)]+D_{KL}(\,p_{\theta}(z|x)\,||\,p(z)\,) \\&=MSE(x,p_{\theta}(x,\epsilon))+\frac{1}{2}[\,-1+{\mu_{\theta}(x)}^2+{\sigma_{\theta}(x)}^2-log{\sigma_{\theta}(x)}^2\,]\qquad \end{align} Lvae??=?Ez~p??(z∣x)?[logp??(z∣x)pθ?(x,z)?]=?Ez~pθ?(z∣x)?[logpθ?(x∣z)]+DKL?(pθ?(z∣x)∣∣p(z))=MSE(x,pθ?(x,?))+21?[?1+μθ?(x)2+σθ?(x)2?logσθ?(x)2]??
 具体实现上,即是Encoder后接两层Linear,分别预测 μ θ ( x ) 和 σ θ ( x ) 2 \mu_{\theta}(x)和\sigma_{\theta}(x)^2 μθ?(x)和σθ?(x)2,然后通过重参数化技巧,采样一个 x ′ = μ θ ( x ) + σ θ ( x ) ? ? x'=\mu_{\theta}(x)+\sigma_{\theta}(x)\,\epsilon x′=μθ?(x)+σθ?(x)?输入Decoder,重建x,当然在细节上,我们可以选择预测 l o g ? σ θ ( x ) 2 log\,\sigma_{\theta}(x)^2 logσθ?(x)2,从而避免了网络输出为负的情况。
4 结语
现在准备开始写Diffusion Model的博客,算是一个总结,也算是对学习知识的回顾,学到现在真的得到了太多人博客的帮忙,希望自己也能成其中的一员。
 Reference:
 苏剑林.《变分自编码器(一):原来是这么一回事》
 苗思奇.《机器学习方法—优雅的模型(一):变分自编码器(VAE)》
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!