首页 > 其他分享 >DDMP中的损失函数

DDMP中的损失函数

时间:2024-06-16 18:59:26浏览次数:13  
标签:bar 函数 DDMP frac 损失 mu theta alpha Sigma

接着扩散模型 简述训练扩散模型过程中用到的损失函数形式。完整的观察数据\(x\)的对数似然如下:

\[\begin{aligned} \mathrm{log}\ p(x) &\geq \mathbb{E}_{q_{\phi}(z_{1:T}|z_0)} \mathrm{log} \frac{p(z_T)\prod_{t=0}^{T-1}p_{\theta}(z_t|z_{t+1})}{\prod_{t=0}^{T-1}q_{\phi}(z_{t+1}|z_t)} \\ &= \mathbb{E}_{q_{\phi}(z_{1}|z_0)} [\mathrm{log}\ p_{\theta}(z_0|z_1) ] - \mathbb{D}_{KL}(q_{\phi}(z_T|z_0)||p(z_T)) - \sum_{t=2}^{T} \mathbb{E}_{q_{\phi}(z_t|z_0)} [ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] \end{aligned} \tag {1} \]

其中,\(q_{\phi}(z_{t-1}|z_t,z_0)\)为了便于计算,已经近似为高斯分布

\[\mathcal N(\mu_q(z_t,z_0), \Sigma_q(t)) \tag {2}\]

\[\mu_q(z_t, z_0) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_0 }{ 1 - \bar {\alpha}_t^2 } \tag {3} \]

\[\Sigma_q(t) = \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }I \tag {4} \]

形式一

为了使得去噪过程\(p_{\theta}(z_{t-1}|z_t)\)和“真实”的\(q_{\phi}(z_{t-1}|z_t,z_0)\)尽可能接近,因此也可以将\(p_{\theta}(z_{t-1}|z_t)\)建模为一个高斯分布。又由于所有的\(\alpha\)项在每个时间步都是固定的,因此可以将其方差设计与“真实”的\(q(z_{t-1}|z_t,z_0)\)的方差是一样的。且这个高斯分布与初始值\(z_0\)是无关的,因此可以将其均值设计为关于\(z_t, t\)的函数,即设为\(\mu_{\theta}(z_t,t)\).

  考虑两个高斯分布的KL散度等于

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(x;\mu_x,\Sigma_x) || \mathcal N(y;\mu_y,\Sigma_y)) \\ & = \frac{1}{2}[log\frac{|\Sigma_y|}{|\Sigma_x|} - d + tr(\Sigma_y^{-1}\Sigma_x) + (\mu_y-\mu_x)^T\Sigma_y^{-1}(\mu_y-\mu_x)] \end{aligned} \tag {5} \]

应用到公式(1)中的第三项,因此有

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)}||\mu_{\theta}(x_t,t) - \mu_{q}(x_t,x_0)||^2 \end{aligned} \tag {6} \]

其中\(\sigma_{q}^2(t)\)是公式(4)前的系数即\(\sigma_{q}^2(t)= \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }\)

由于\(\mu_{\theta}(x_t,t)\)也是\(x_t\)的函数,因此,可以参考公式(3)的形式,将进一步假设

\[\mu_{\theta}(x_t, t) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_{\theta}(z_t, t) }{ 1 - \bar {\alpha}_t^2 } \tag {7} \]

这样公式(6)进一步化简为

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)} \frac{\bar{\alpha}_{t-1}^2( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)^2} ||z_{\theta}(z_t,t) - z_0||^2 \end{aligned} \tag {8} \]

至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出其原来的样本。最终最小化公式(1)中的第三项,等价于最小化关于时间步的期望,因此有

\[arg min \mathbb{E}_{t \sim U\{2,T\}} [ \mathbb{E}_{q_{\phi}(z_t|z_0)}[ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] ] \]

形式二

\[z_t = \bar \alpha_t z_0 + \sqrt{1-\bar {\alpha}_t^2} \bar \epsilon_t \tag {9} \]

可得

\[z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}_t}{\bar {\alpha}_t} \tag {10} \]

再代入公式(3)得

\[\mu_q(x_t,x_0) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \bar \epsilon_t \tag{11} \]

参考形式一中的假设方式,可以假设

\[\mu_{\theta}(x_t,t) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \epsilon_{\theta}(z_t, t) \tag{12} \]

再代入公式(6)可以得到

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)} \frac{( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)\alpha_t^2} ||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 \end{aligned} \tag {12} \]

至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出按照公式(10)添加的原始噪音。

形式三

由公式(8)和公式(12)可以得到

\[||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} ||z_{\theta}(z_t,t) - z_0||^2 \tag{13} \]

由于\(\bar {\alpha_t}, \sqrt{1-\bar {\alpha_t}^2}\) 分别是\(t\)时间步的加噪信号公式(9)中的原始信号和噪音信号系数,因此将信噪比SNR(t)定义为系数平方之比,即

\[SNR(t) = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} \tag {14} \]

这个信噪比在时间步初期其值较大,代表真实信号占比多噪音占比少;在时间步后期其值较小,代表真实信号占比少噪音占比多。因为推理过程是完全从高斯分布随机取样,为了保证推理与训练保持一致,训练过程采取特定的\(\bar {\alpha}_t\)使得T步得到的是完全噪音,不包含任何原始信号。此时信噪比是0.

当预测发送在信噪比接近0(\(\bar \alpha_t \to 0\))时,模型原始预测是噪音\(\bar \epsilon\),因此根据公式(10)预估对应的原始信号

\[\bar z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}}{\bar {\alpha}_t} \]

这样网络预测的微小差异就会被放大很多倍,因此在论文[3]模型蒸馏过程,这就不是一个稳定的设计。为了避免这个问题,作者提出了3种解决办法。

  • 直接预测\(z\),而非噪音\(\epsilon\)
  • 同时预测\(z, \epsilon\),通过两个独立的输出通道\(z, \epsilon\)。由于根据公式(10)可以再由\(\epsilon\)再推断出\(z^{'}\),然后可以根据\(\bar \alpha_t^2, 1-\bar \alpha_t^2\)对这两个值进行差值。
  • 预测混合体 \(v=\alpha_t\epsilon - \sqrt{1-\alpha_t^2}z\)

参考

[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models

标签:bar,函数,DDMP,frac,损失,mu,theta,alpha,Sigma
From: https://www.cnblogs.com/wolfling/p/18250729

相关文章

  • 简述回调函数的意义和作用
    回调函数是一种在程序中广泛使用的机制,它的意义和作用主要包括以下几个方面:异步操作:在一些需要异步执行的任务中,如网络请求、文件读写等,回调函数可以在任务完成后被调用,以便进行后续的处理。这样可以避免阻塞程序的执行,提高程序的响应性和效率。事件处理:回调函数可以用于处理各......
  • 编写多个函数的ROP链
    我们已经学会了编写单个和两个简单函数的ROP链,在这里我们说一下,编写ROP链多个需要注意的问题之前我们在学习两个函数的ROP时,编写了这样的payload我们当时没有考虑,参数冲突和栈溢出大小,现在我们来说一说举个例子,如果我们上次学习的两个函数的ROP中没有gets函数,而是read函数我们......
  • pytorch动态量化函数
    PyTorch动态量化APIPyTorch提供了丰富的动态量化API,可以帮助开发者轻松地将模型转换为动态量化模型。主要API包括:torch.quantization.quantize_dynamic:将模型转换为动态量化模型。torch.quantization.QuantStub:观察模型层的输入和输出分布。torch.quantization.Observer......
  • Unity的生命周期函数
    在Unity中,各个生命周期函数是在特定的时机被调用的,它们的执行顺序如下:1.Awake:当脚本实例被加载时调用,用于初始化数据。如果物体上有多个脚本,它们的Awake方法会在Start方法之前执行。2.OnEnable:当对象变为活动状态(enabled)或脚本被启用时调用。如果在场景加载后对象已经......
  • 编写单个函数的ROP链
    什么是ROP链在我初识栈溢出那篇博客已经详细的讲了函数的调用过程(基于X86框架),不了解的可以看一下,没有这个理论基础,是学不好ROP的。现在我们说一下什么是ROP。ROP链就是通过返回地址的修改来完成的编程,调用特定的函数的一种编程模式。我们可以联想一下你做的最简单的栈溢出的题,返......
  • 要将URL参数转换为JSON对象,可以使用以下函数:
    要将URL参数转换为JSON对象,可以使用以下函数:javascriptfunctiongetQueryParams(url){//使用正则表达式提取URL参数constparamsString=url.split('?')[1];if(!paramsString){return{};}//将参数字符串分割成数组,并解析键值对constparams=......
  • 6、Oracle中的分组函数
    最近项目要用到Oracle,奈何之前没有使用过,所以在B站上面找了一个学习视频,用于记录学习过程以及自己的思考。视频链接:【尚硅谷】Oracle数据库全套教程,oracle从安装到实战应用如果有侵权,请联系删除,谢谢。学习目标:了解组函数。描述组函数的用途。使用GROUPBY子句对数据分......
  • 【Linux】fork()函数详解|多进程
    ......
  • 关于ES6的箭头函数和展开运算符
    使用ES6的箭头函数和展开运算符(...)可以简化使用逻辑与(&&)运算符的代码。这种方法通常用于当你有一组变量,并且想要在单个表达式中检查它们是否都满足特定条件时。以下是一个示例,展示如何使用箭头函数和展开运算符来简化检查多个变量是否都已定义且不为空的代码://假设有以下变量co......
  • 机器视觉入门学习:YOLOV5自定义数据集部署、网络详解、损失函数(学习笔记)
     前言源码学习资源:YOLOV5预处理和后处理,源码详细分析-CSDN博客网络学习资源:YOLOv5网络详解_yolov5网络结构详解-CSDN博客YOLOv5-v6.0学习笔记_yolov5的置信度损失公式-CSDN博客 本文为个人学习,整合各路大佬的资料进行V5-6.0版本的网络分析,在开始学习之前最好先去学习YOL......