首页 > 其他分享 >GAN 的基本形式

GAN 的基本形式

时间:2022-11-17 22:24:45浏览次数:72  
标签:基本 log over 形式 GAN pmb mathcal align sim

目录

GAN

GAN 即生成式对抗网络,这个网络包括两个部分:生成器 \(G\) 和鉴别器 \(D\)。\(D\) 的目标是在生成器生成的图像(或其他输出)和真实图像中鉴别出两者,即 \(\mathcal L_D:=L[D(G(x)), \text{fake}]+L[D(y),\text{valid}]\);而 \(G\) 的目标则是相对地生成不会被鉴别器发现的图像,即 \(\mathcal L_G:=L[D(G(x)),\text{valid}]\)。

\[\begin{align*} \mathcal L_D&:=L[\text{fake},D(G(x))]+L[\text{valid},D(y)]\\ \mathcal L_G&:=L[\text{valid},D(G(x))]\\ \end{align*} \]

其中 \(L\) 为衡量两个参数差异的函数;\(D\) 为鉴别器输出,判断图像(生成的图像或者真实图像)的真实性;\(G(x)\) 代表生成器生成的图片;\(y\) 代表真实图片。

在生成器和鉴别器的竞争中,生成器生成的图片逐渐接近真实图片,也就是说生成器学习到了样本的真实分布。

基本形式

对于一个简单的、应用于图片生成的 GAN,生成器的输入是一个服从分布 \(\mathcal Z\) 的高维随机噪声 \(\pmb z\sim\mathcal Z\);而真实图片则为一个服从分布 \(\mathcal X\) 的数据 \(\pmb x\sim\mathcal X\)。

令 \(D()\) 输出值区间为 \([0,1]\) 表示真实性,则 \(\begin{cases}\text{fake}\leftarrow0\\\text{valid}\leftarrow1\end{cases}\),选择 \(L\) 为二分类的交叉熵函数。

\[\begin{align*} L(y,\hat y)&=-y\log\hat y-(1-y)\log(1-\hat y)\\[0.5em] \mathcal L_D&=L[\text{fake},D(G(\pmb z))]+L[\text{valid},D(\pmb x)]\\ &=-\log[1-D(G(\pmb z))]-\log D(\pmb x)\\ \mathcal L_G&=L[\text{valid},D(G(\pmb z))]\\ &=-\log D(G(\pmb z)) \end{align*} \]

改写为平均形式(期望),即加上 \(\mathbb E\)。

\[\begin{align*} \mathcal L_D&=-\frac1N\sum_{i=1}^N\log[1-D(G(\pmb z_i))]+\log D(\pmb x_i)\\ &\xlongequal{N\to\infty}-\sum_{\pmb z\sim\mathcal Z}p_{\pmb z\sim\mathcal Z}(\pmb z)\log[1-D(G(\pmb z))]-\sum_{\pmb x\sim\mathcal X}p_{\pmb x\sim\mathcal X}(\pmb x)\log D(\pmb x_i)\\[0.7em] &=-\mathbb E_{\pmb z\sim\mathcal Z}\log[1-D(G(\pmb z))]-\mathbb E_{\pmb x\sim\mathcal X}\log D(\pmb x)\\[1em] \mathcal L_G&=-\frac1N\sum_{i=1}^N\log D(G(\pmb z_i))\\ &\xlongequal{N\to\infty}-\sum_{\pmb z\sim\mathcal Z}p_{\pmb z\sim\mathcal Z}(\pmb z)\log D(G(\pmb z))\\[0.5em] &=-\mathbb E_{\pmb z\sim\mathcal Z}\log D(G(\pmb z)) \end{align*} \]

显然 \(\log[1-D(G(\pmb z))]\) 和 \(\log[D(G(\pmb z))]\) 是对立的,一者增加则另一者减小,不妨将 \(\mathcal L_G\) 改为 \(\mathbb E_{z\sim\mathcal Z}\log[1-D(G(\pmb z))]\),此时可以将 \(\mathcal L_D\) 和 \(\mathcal L_G\) 合为一体。

\[\begin{align*} \mathcal L&:=\mathbb E_{\pmb z\sim\mathcal Z}\log[1-D(G(\pmb z))]+\mathbb E_{\pmb x\sim\mathcal X}\log D(\pmb x)\\ \end{align*} \]

此时生成器和鉴别器的目标合成为 \(\min_G\max_D\mathcal L\)。

实际训练时,\(D\) 和 \(G\) 是分开训练的,应该分别采用 \(\mathcal L_D,\mathcal L_G\) 训练。

若 \(\theta_G\) 设为 \(G\) 的权重,根据 \(D(G(\pmb z))\) 的在最开始训练的情况(\(D(G(\pmb z))\) 接近 0),应该选择 \(\mathbb E_{\pmb z\sim\mathcal Z}\log D(G(\pmb z))\) ,这样会有更高的导数。也可以看到,两个函数的极值点是相同的,都满足 \(\nabla_{\theta_G}D(G(\pmb z))=\pmb0\)。

\[\begin{align*} \nabla_{\theta_G}\mathbb E_{\pmb z\sim\mathcal Z}\log D(G(\pmb z))&=\frac1{D(G(\pmb z))}\nabla_{\theta_G}D(G(\pmb z))\\ \nabla_{\theta_G}\mathbb E_{\pmb z\sim\mathcal Z}\log[1-D(G(\pmb z))]&=\frac1{D(G(\pmb z))-1}\nabla_{\theta_G}D(G(\pmb z))\\ \end{align*} \]

最优鉴别器

已知 \(\pmb z\sim\mathcal Z\),\(\pmb x\sim\mathcal X\),而 \(\pmb z\) 经过生成器的处理也产生了一个分布 \(G(\pmb z)\sim\mathcal Z_G\)。这些分布都是在(可能不同的)高维空间的一种分布。将这些分布合成为 \(\mathcal T:=\mathcal X\cup\mathcal Z_G\)。

\[\begin{align*} \mathcal L&=\mathbb E_{\pmb z\sim\mathcal Z}\log[1-D(G(\pmb z))]+\mathbb E_{\pmb x\sim\mathcal X}\log D(\pmb x)\\ &=\int_{\mathcal Z}p_{\mathcal Z}(\pmb z)\log[1-D(G(\pmb z))]\mathrm d\pmb z+\int_{\mathcal X}p_{\mathcal X}(\pmb x)\log D(\pmb x)\mathrm d\pmb x\\ &=\int_{\mathcal Z_G}p_{\mathcal Z_G}(\pmb z)\log[1-D(\pmb z)]\mathrm d\pmb z+\int_{\mathcal X}p_{\mathcal X}(\pmb x)\log D(\pmb x)\mathrm d\pmb x\\ &=\int_{\mathcal T}p_{\mathcal Z_G}(\pmb x)\log[1-D(\pmb x)]+p_{\mathcal X}(\pmb x)\log D(\pmb x)\mathrm d\pmb x\\ \end{align*} \]

考虑到 \(a\log x+b\log(1-x),0<a,b,x<1\) 达到最大值时,自变量 \(x\) 为 \(\frac a{a+b}\)。因此最优的鉴别器为 \(D(\pmb x)={p_{\mathcal X}(\pmb x)\over p_{\mathcal Z_G}(\pmb x)+p_{\mathcal X}(\pmb x)}\)。

\[\begin{align*} \max_D\mathcal L&=\mathbb E_{\pmb z\sim\mathcal Z}\log[1-D(G(\pmb z))]+\mathbb E_{\pmb x\sim\mathcal X}\log D(\pmb x)\\ &=\mathbb E_{\pmb x\sim\mathcal T}\log\left[{p_{\mathcal Z_G}(\pmb x)\over p_{\mathcal Z_G}(\pmb x)+p_{\mathcal X}(\pmb x)}\right]+\mathbb E_{\pmb x\sim\mathcal T}\log\left[{p_{\mathcal X}(\pmb x)\over p_{\mathcal Z_G}(\pmb x)+p_{\mathcal X}(\pmb x)}\right]\\ &=2D_{JS}\left(p_{\mathcal X}\parallel p_{\mathcal Z_G}\right)-\log4 \end{align*} \]

即在最优鉴别器下,损失函数等同于 JS 散度。

学习过程

那么根据两个分布的 JS 散度,能否从一个分布学习到另一个分布呢?

生成器 \(G(\pmb z;\theta_G)\) 的学习方法是求 \(\nabla_{\theta_G}\mathcal L\) 或 \(\nabla_{\theta_G}\mathcal L_G\),现在将其转化为生成器生成的分布和损失函数的求导关系,并简化 \(\mathcal L,p_{\mathcal Z_G},p_{\mathcal X}\) 的写法:

\[\begin{align*} \nabla_p\mathcal L_{maxD}(p, q)&:=\nabla_{p_{\mathcal Z_G}}\max_D\mathcal L(p_{\mathcal Z_G},p_{\mathcal X}) \end{align*} \]

由于分布是概率的函数,所以关于分布的导数就是关于一个函数的导数,这个函数服从约束 \(p\in[0,1]\wedge\int p=1\)。

现在采用拉格朗日乘数法求这个导数的极值点

\[\begin{align*} &&{\mathrm d\over\mathrm d\epsilon}\mathcal L_{maxD}(p+\epsilon\eta,q)&=\int_{\mathcal T}{\mathrm d\over\mathrm d\epsilon}\left[(p+\epsilon\eta)\log\left({p+\epsilon\eta\over p+q+\epsilon\eta}\right)+q\log\left({q\over p+q+\epsilon\eta}\right)\right]\mathrm d\pmb x\\ &&&=\int_{\mathcal T}\eta\log{p+\epsilon\eta\over p+q+\epsilon\eta}\mathrm d\pmb x\\ &&&\xlongequal{\epsilon=0}\int_{\mathcal T}\eta\log{p\over p+q}\mathrm d\pmb x\\ \text{because}&&\int_{\mathcal T}\eta{\delta\mathcal L_{maxD}\over\delta p(\pmb x)}\mathrm d\pmb x&={\mathrm d\over\mathrm d\epsilon}\mathcal L_{maxD}(p+\epsilon\eta,q)\\ \text{therefore}&&\nabla_{p_{\mathcal Z_G}}\max_D\mathcal L(p_{\mathcal Z_G},p_{\mathcal X})&=\nabla_p\mathcal L_{maxD}(p, q)={\delta\mathcal L_{maxD}\over\delta p(\pmb x)}=\log{p\over p+q}=\log{p_{\mathcal Z_G}\over p_{\mathcal Z_G}+p_{\mathcal X}} \end{align*} \]

生成器生成的分布和损失函数的求导关系已被求出,现在加入约束条件求极值点

\[\begin{align*} \text{let}&&F&:=\mathcal L_{maxD}(p, q)+a\left(\int_{\mathcal T}p\mathrm d\pmb x-1\right)\\ &&{\delta F\over\delta p}&=\log{p\over p+q}+a\\ &&{\partial F\over\partial a}&=\int_{\mathcal T}p\mathrm d\pmb x-1=0\\ \text{let}&&{\delta F\over\delta p}&=0\\ \text{then}&&p&={q\over e^a-1}\\ \text{so}&&{\partial F\over\partial a}&=\int_{\mathcal T}{q\over e^a-1}\mathrm d\pmb x-1={1\over e^a-1}-1=0\\ &&a&=\log2\\ &&p&={q\over e^{\log2}-1}=q \end{align*} \]

因此极值点为 \(p=q\) 即 \(p_{\mathcal Z_G}=p_{\mathcal X}\),也恰好为最小值点。

因此对于 \(p_{\mathcal Z_G}\),这是可以用梯度下降法求得最小值点的求导,在这个最小值点,实际上就有 \(p_{\mathcal Z_G}=p_{\mathcal X}\)。

那么在梯度下降的过程中约束是否会被破坏呢?我们知道 \(p_{\mathcal Z_G}\) 的改变实际上是通过改变 \(G(\pmb z;\theta_G)\) 中的 \(\theta_G\) 来实现的。如果 \(\pmb z\sim\mathcal Z\),那么 \(G(\pmb z;\theta_G)\sim\mathcal Z_G\),因此如果 \(p_{\mathcal Z}(\pmb z)\) 服从约束 \(p\in[0,1]\wedge\int p=1\),那么 \(p_{\mathcal Z_G}(\pmb z)\) 就服从约束。

参考公式

KL 散度

\[\begin{align*} D_{KL}(p\parallel q)&=\sum_{x\in\mathcal X}p(x)\log{p(x)\over q(x)}=\mathbb E_p\log{p(x)\over q(x)} \end{align*} \]

如果存在某个 \(x\) 使得 \(p(x)\ne0\wedge q(x)=0\),则 \(\log{p(x)\over q(x)}=\infty\),说明此时的 KL 散度无穷大,但实际上这两个分布的差别不一定有这么大。可以选择一种更平滑的 KL 散度公式:将一个极小量(噪声) \(\epsilon\) 添加到分布中,使得原来 \(q(x)=0\) 的量变为 \(q(x)=\epsilon\)。

JS 散度

\[\begin{align*} m&=\frac12(p+q)\\ D_{JS}(p\parallel q)&=\frac12D_{KL}(p\parallel m)+\frac12D_{KL}(q\parallel m)\\ &=\frac12\mathbb E_p\log{p(x)\over p(x)+q(x)}+\frac12\mathbb E_q\log{q(x)\over p(x)+q(x)}+\log2\\ \end{align*} \]

同样,如果 \(p\) 与 \(q\) 的分布完全没有重合时,\(D_{JS}(p\parallel q)\) 为常数。\(D_{JS}(p\parallel q)\ge0\) 且仅当 \(p,q\) 完全相同时取等号。

参考论文

标签:基本,log,over,形式,GAN,pmb,mathcal,align,sim
From: https://www.cnblogs.com/violeshnv/p/16901210.html

相关文章

  • 黏包现象,UPD基本代码使用,并发编程理论之操作系统发展史,多道技术,进程理论及调度算法
    目录黏包现象,UPD基本代码使用,并发编程理论之操作系统发展史,多道技术,进程理论及调度算法今日内容概要今日内容详细黏包现象struct模块黏包代码实战UDP协议并发编程理论多道......
  • Rundeck部署和基本使用【转】
    rundeck介绍Rundeck是一款能在数据中心或云环境中的日常业务中使程序自己主动化的开源软件。Rundeck 提供了大量功能。能够减轻耗时繁重的体力劳动。团队能够相互协作......
  • 02.编程基本概念
    一、变量与可变性1、变量在Rust语言中,变量默认是不可变的(immutable)。当变量不可变时,一旦值被绑定到一个名称上,你就不能改变这个值。fnmain(){letmutx=5;......
  • 正则表达式基本语法的详解
    正则表达式基本语法的详解本文给给大家介绍正则表达式的基本语法,需要的朋友可以参考下 正则表达式是一种文本模式,包括普通字符(例如,a到z之间的字母)和特殊字符(称为“......
  • python windows psutil获取基本监控指标
    #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++#@auhorbyruiy####pipinstallparamiko-ihttps://pypi.tuna.tsinghua.edu.cn/simple##p......
  • centos7安装及基本配置
    镜像源:https://mirrors.tuna.tsinghua.edu.cn/centos-vault/7.0.1406/isos/x86_64/ centos镜像后缀详解linux发行版:linux内核基础上+系统层(系统库,设备驱动程序,......
  • Java 基本程序设计结构
    1.基本程序框架package*;//表示这个文件属于哪个包import*;//引入一些库,不必重复造轮子/**name:Java基本程序框架*describe:NULL*///public称为......
  • GO学习笔记之基本数据类型
    整型普通整型类型描述uint8无符号8位整型(0到255)uint16无符号16位整型(0到65535)uint32无符号32位整型(0到4294967295)uint64无符号......
  • 三、排序基本概念和方法概述
    一、排序的稳定性  当排序记录中的关键字${K_i}(i=1,2,...,n)$都不相同时,则任何一个记录的无序序列经排序后得到的结果唯一;反之,当待排序的序列中存在两个或两个以上......
  • Java中 String与基本数据类型,包装类,char[],byte[]之间的转换
    String与基本数据类型,包装类之间的转换。String转换为基本数据类型,包装类:调用包装类的parseXxx(str)方法Stringstr1="456";//string转换为int类型intstr......