首页 > 其他分享 >GAN的原理入门

GAN的原理入门

时间:2023-06-02 23:05:11浏览次数:53  
标签:真实 入门 训练 生成 GAN 分布 原理 图片

GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:

  • G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片(如正态分布,auto-encoder是中间输出是一般也是),记做G(z)。
  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。

Goodfellow从理论上证明了该算法的收敛性  ,以及在模型收敛时,生成数据具有和真实数据相同的分布(保证了模型效果)。

开发者自述:我是这样学习 GAN 的

from:https://www.leiphone.com/news/201707/1JEkcUZI1leAFq5L.html

 

 

Generative Adversarial Network,就是大家耳熟能详的 GAN,由 Ian Goodfellow 首先提出,在这两年更是深度学习中最热门的东西,仿佛什么东西都能由 GAN 做出来。我最近刚入门 GAN,看了些资料,做一些笔记。

1.Generation

什么是生成(generation)?就是模型通过学习一些数据,然后生成类似的数据。让机器看一些动物图片,然后自己来产生动物的图片,这就是生成。

以前就有很多可以用来生成的技术了,比如 auto-encoder(自编码器),结构如下图:

GAN的原理入门_数据

你训练一个 encoder,把 input 转换成 code,然后训练一个 decoder,把 code 转换成一个 image,然后计算得到的 image 和 input 之间的 MSE(mean square error),训练完这个 model 之后,取出后半部分 NN Decoder,输入一个随机的 code,就能 generate 一个 image。

但是 auto-encoder 生成 image 的效果,当然看着很别扭啦,一眼就能看出真假。所以后来还提出了比如VAE这样的生成模型,我对此也不是很了解,在这就不细说。

上述的这些生成模型,其实有一个非常严重的弊端。比如 VAE,它生成的 image 是希望和 input 越相似越好,但是 model 是如何来衡量这个相似呢?model 会计算一个 loss,采用的大多是 MSE,即每一个像素上的均方差。loss 小真的表示相似嘛?

GAN的原理入门_数据_02

比如这两张图,第一张,我们认为是好的生成图片,第二张是差的生成图片,但是对于上述的 model 来说,这两张图片计算出来的 loss 是一样大的,所以会认为是一样好的图片。

这就是上述生成模型的弊端,用来衡量生成图片好坏的标准并不能很好的完成想要实现的目的。于是就有了下面要讲的 GAN。

2.GAN

大名鼎鼎的 GAN 是如何生成图片的呢?首先大家都知道 GAN 有两个网络,一个是 generator,一个是 discriminator,从二人零和博弈中受启发,通过两个网络互相对抗来达到最好的生成效果。流程如下:

GAN的原理入门_生成图片_03

主要流程类似上面这个图。首先,有一个一代的 generator,它能生成一些很差的图片,然后有一个一代的 discriminator,它能准确的把生成的图片,和真实的图片分类,简而言之,这个 discriminator 就是一个二分类器,对生成的图片输出 0,对真实的图片输出 1。

接着,开始训练出二代的 generator,它能生成稍好一点的图片,能够让一代的 discriminator 认为这些生成的图片是真实的图片。然后会训练出一个二代的 discriminator,它能准确的识别出真实的图片,和二代 generator 生成的图片。以此类推,会有三代,四代。。。n 代的 generator 和 discriminator,最后 discriminator 无法分辨生成的图片和真实图片,这个网络就拟合了。

这就是 GAN,运行过程就是这么的简单。这就结束了嘛?显然没有,下面还要介绍一下 GAN 的原理。

3.原理

首先我们知道真实图片集的分布 Pdata(x),x 是一个真实图片,可以想象成一个向量,这个向量集合的分布就是 Pdata。我们需要生成一些也在这个分布内的图片,如果直接就是这个分布的话,怕是做不到的。

我们现在有的 generator 生成的分布可以假设为 PG(x;θ),这是一个由 θ 控制的分布,θ 是这个分布的参数(如果是高斯混合模型,那么 θ 就是每个高斯分布的平均值和方差)

假设我们在真实分布中取出一些数据,{x1, x2, ... , xm},我们想要计算一个似然 PG(xi; θ)。

对于这些数据,在生成模型中的似然就是

GAN的原理入门_算法_04

我们想要最大化这个似然,等价于让 generator 生成那些真实图片的概率最大。这就变成了一个最大似然估计的问题了,我们需要找到一个 θ* 来最大化这个似然。

GAN的原理入门_数据_05

寻找一个 θ* 来最大化这个似然,等价于最大化 log 似然。因为此时这 m 个数据,是从真实分布中取的,所以也就约等于,真实分布中的所有 x 在 P分布中的 log 似然的期望。

真实分布中的所有 x 的期望,等价于求概率积分,所以可以转化成积分运算,因为减号后面的项和 θ 无关,所以添上之后还是等价的。然后提出共有的项,括号内的反转,max 变 min,就可以转化为 KL divergence 的形式了,KL divergence 描述的是两个概率分布之间的差异。

所以最大化似然,让 generator 最大概率的生成真实图片,也就是要找一个 θ 让 P更接近于 Pdata。

那如何来找这个最合理的 θ 呢?我们可以假设 PG(x; θ) 是一个神经网络。

首先随机一个向量 z,通过 G(z)=x 这个网络,生成图片 x,那么我们如何比较两个分布是否相似呢?只要我们取一组 sample z,这组 z 符合一个分布,那么通过网络就可以生成另一个分布 PG,然后来比较与真实分布 Pdata。

大家都知道,神经网络只要有非线性激活函数,就可以去拟合任意的函数,那么分布也是一样,所以可以用一直正态分布,或者高斯分布,取样去训练一个神经网络,学习到一个很复杂的分布。

GAN的原理入门_生成模型_06

如何来找到更接近的分布,这就是 GAN 的贡献了。先给出 GAN 的公式:

GAN的原理入门_算法_07

这个式子的好处在于,固定 G,max  V(G,D) 就表示 PG 和 Pdata 之间的差异,然后要找一个最好的 G,让这个最大值最小,也就是两个分布之间的差异最小。

GAN的原理入门_生成模型_08

表面上看这个的意思是,D 要让这个式子尽可能的大,也就是对于 x 是真实分布中,D(x) 要接近与 1,对于 x 来自于生成的分布,D(x) 要接近于 0,然后 G 要让式子尽可能的小,让来自于生成分布中的 x,D(x) 尽可能的接近 1。

现在我们先固定 G,来求解最优的 D:

GAN的原理入门_生成图片_09

GAN的原理入门_算法_10

对于一个给定的 x,得到最优的 D 如上图,范围在 (0,1) 内,把最优的 D 带入

GAN的原理入门_算法_11

可以得到:

GAN的原理入门_数据_12

GAN的原理入门_生成图片_13

JS divergence 是 KL divergence 的对称平滑版本,表示了两个分布之间的差异,这个推导就表明了上面所说的,固定 G。

GAN的原理入门_算法_14

表示两个分布之间的差异,最小值是 -2log2,最大值为 0。

现在我们需要找个 G,来最小化

GAN的原理入门_生成模型_15

观察上式,当 PG(x)=Pdata(x) 时,G 是最优的。

4.训练

有了上面推导的基础之后,我们就可以开始训练 GAN 了。结合我们开头说的,两个网络交替训练,我们可以在起初有一个 G0 和 D0,先训练 D找到 :

GAN的原理入门_生成图片_16

然后固定 D0 开始训练 G0, 训练的过程都可以使用 gradient descent,以此类推,训练 D1,G1,D2,G2,...

但是这里有个问题就是,你可能在 D0* 的位置取到了:

GAN的原理入门_生成图片_17

然后更新 G0 为 G1,可能

GAN的原理入门_算法_18

了,但是并不保证会出现一个新的点 D1* 使得

GAN的原理入门_生成模型_19

这样更新 G 就没达到它原来应该要的效果,如下图所示:

GAN的原理入门_生成图片_20

避免上述情况的方法就是更新 G 的时候,不要更新 G 太多。

知道了网络的训练顺序,我们还需要设定两个 loss function,一个是 D 的 loss,一个是 G 的 loss。下面是整个 GAN 的训练具体步骤:

GAN的原理入门_数据_21

上述步骤在机器学习和深度学习中也是非常常见,易于理解。

5.存在的问题

但是上面 G 的 loss function 还是有一点小问题,下图是两个函数的图像:

GAN的原理入门_生成图片_22

log(1-D(x)) 是我们计算时 G 的 loss function,但是我们发现,在 D(x) 接近于 0 的时候,这个函数十分平滑,梯度非常的小。这就会导致,在训练的初期,G 想要骗过 D,变化十分的缓慢,而上面的函数,趋势和下面的是一样的,都是递减的。但是它的优势是在 D(x) 接近 0 的时候,梯度很大,有利于训练,在 D(x) 越来越大之后,梯度减小,这也很符合实际,在初期应该训练速度更快,到后期速度减慢。

所以我们把 G 的 loss function 修改为

GAN的原理入门_数据_23

这样可以提高训练的速度。

还有一个问题,在其他 paper 中提出,就是经过实验发现,经过许多次训练,loss 一直都是平的,也就是

GAN的原理入门_数据_24

JS divergence 一直都是 log2,P和 Pdata 完全没有交集,但是实际上两个分布是有交集的,造成这个的原因是因为,我们无法真正计算期望和积分,只能使用 sample 的方法,如果训练的过拟合了,D 还是能够完全把两部分的点分开,如下图:

GAN的原理入门_生成图片_25

对于这个问题,我们是否应该让 D 变得弱一点,减弱它的分类能力,但是从理论上讲,为了让它能够有效的区分真假图片,我们又希望它能够 powerful,所以这里就产生了矛盾。

还有可能的原因是,虽然两个分布都是高维的,但是两个分布都十分的窄,可能交集相当小,这样也会导致 JS divergence 算出来 =log2,约等于没有交集。

解决的一些方法,有添加噪声,让两个分布变得更宽,可能可以增大它们的交集,这样 JS divergence 就可以计算,但是随着时间变化,噪声需要逐渐变小。

还有一个问题叫 Mode Collapse,如下图:

GAN的原理入门_生成图片_26

这个图的意思是,data 的分布是一个双峰的,但是学习到的生成分布却只有单峰,我们可以看到模型学到的数据,但是却不知道它没有学到的分布。

造成这个情况的原因是,KL divergence 里的两个分布写反了

GAN的原理入门_算法_27

这个图很清楚的显示了,如果是第一个 KL divergence 的写法,为了防止出现无穷大,所以有 Pdata 出现的地方都必须要有 PG 覆盖,就不会出现 Mode Collapse。

6.参考

这是对 GAN 入门学习做的一些笔记和理解,后来太懒了,不想打公式了,主要是参考了李宏毅老师的视频:

http://t.cn/RKXQOV0

标签:真实,入门,训练,生成,GAN,分布,原理,图片
From: https://blog.51cto.com/u_11908275/6405474

相关文章

  • ASP.NET Core MVC 从入门到精通之自动映射(一)
    随着技术的发展,ASP.NETCoreMVC也推出了好长时间,经过不断的版本更新迭代,已经越来越完善,本系列文章主要讲解ASP.NETCoreMVC开发B/S系统过程中所涉及到的相关内容,适用于初学者,在校毕业生,或其他想从事ASP.NETCoreMVC系统开发的人员。经过前几篇文章的讲解,初步了解ASP.NETCore......
  • mysql(一):基本原理
    Innodb是如何实现事务的Innodb通过BufferPool,LogBuffer,RedoLog,UndoLog来实现事务,以一个update语句为例:Innodb在收到一个update语句后,会先根据条件找到数据所在的页,并将该页缓存在BufferPool中执行update语句,修改BufferPool中的数据,也就是内存中的数据针对update语句生......
  • AI入门(重实践)书籍推荐
    AI书籍推荐我最近看了下https://book.douban.com/subject/30147778/ 另外如果要看电子书的话建议看这个https://book.douban.com/subject/27154347/评价也非常高从项目着手机器学习和深度学习都有并且难度也不高其中文翻译电子版可以看这里https://github.com/it-ebooks......
  • SMB协议原理抓包分析——本质上和FTP下载文件的思路是一样的
     目录:1.SMB概述2.SMB原理3.SMB配置一、SMB概述SMB(全称是ServerMessageBlock)是一个协议名,可用于在计算机间共享文件、打印机、串口等,电脑上的网上邻居就是靠它实现的。SMB是一种客户机/服务器、请求/响应协议。通过SMB协议,客户端应用程序可以在各种网络环境下读、写服务器......
  • Fragment原理解析androidx版本&ViewPager与Fragment
    资料Fragment生命周期为什么要通过Fragment.setArguments(Bundle)传递参数单独问题:动态方式,静态方式添加随Activity启动动态添加回退栈onSaveInstance静态方式添加FragmentmHost是这个finalFragmentControllermFragments=FragmentController.createController(newHostCallb......
  • Jasypt入门
    Jasypt是一个java库,它允许开发人员以最小的工作量为他/她的项目添加基本的加密功能,而不需要对密码学的工作原理有深入的了解。一、特性Jasypt为您提供了简单的单向(摘要)和双向加密技术。用于任何JCE提供程序的开放API,而不仅仅是默认的JavaVM提供程序。Jasypt可以很......
  • KOOM原理分析之一些基础知识
    文章目录资料Profile工具的使用内存性能分析器概览内存计算方式查看内存分配情况(Record一段)查看全局JNI引用原生内存性能分析器将堆转储另存为HPROF文件HPROFAgentBinaryDumpFormat(format=b)HandlingofArrays资料使用内存性能分析器查看应用的内存使用情况HPROFAgentPr......
  • 一站式元数据治理平台——Datahub入门宝典
    随着数字化转型的工作推进,数据治理的工作已经被越来越多的公司提上了日程。作为新一代的元数据管理平台,Datahub在近一年的时间里发展迅猛,大有取代老牌元数据管理工具Atlas之势。国内Datahub的资料非常少,大部分公司想使用Datahub作为自己的元数据管理平台,但可参考的资料太少。所以整......
  • QR防伪溯源系统追溯原理是什么?
    本文分享自天翼云开发者社区《QR防伪溯源系统追溯原理是什么?》,作者:SD万QR防伪溯源系统是一种基于QR技术的防伪技术,通过为每件产品生成唯一的QR标签,并将其与产品信息、生产信息、物流信息等进行关联,实现产品的全程追溯。本文将从追溯原理、系统构成、应用场景等方面对QR防伪溯源......
  • Angular Google Charts教程_编程入门自学教程_菜鸟教程-免费教程分享
    教程简介GoogleCharts是一个纯粹的基于JavaScript的图表库,旨在通过添加交互式图表功能来增强Web应用程序.它支持各种图表.在Chrome,Firefox,Safari,InternetExplorer(IE)等标准浏览器中使用SVG绘制图表.在传统的IE6中,VML用于绘制图形.AngularGoogleCharts是一个基于开源角度......