首页 > 其他分享 >【生成对抗网络GAN】最全的关于生成对抗网络Generative Adversarial Networks,GAN的介绍!!

【生成对抗网络GAN】最全的关于生成对抗网络Generative Adversarial Networks,GAN的介绍!!

时间:2024-09-17 14:49:31浏览次数:3  
标签:判别 nn 生成器 生成 GAN 图像 对抗

【生成对抗网络GAN】最全的关于生成对抗网络Generative Adversarial Networks,GAN的介绍!!

【生成对抗网络GAN】最全的关于生成对抗网络Generative Adversarial Networks,GAN的介绍!!


文章目录


前言

生成对抗网络(Generative Adversarial Networks,GAN)自2014年由Ian Goodfellow提出以来,成为图像生成领域的核心技术之一。它通过对抗训练生成器和判别器两个网络,极大提升了生成图像的质量。

1.GAN的基础理论

1.1背景与概念

GAN由生成器(Generator)和判别器(Discriminator)两个对抗网络组成。

  • 生成器:尝试生成逼真的假图像,输入通常是一个随机噪声向量。
  • 判别器:判别输入图像是真实图像还是生成器生成的假图像。

GAN的目标是通过博弈的方式让生成器生成的图像逐步逼近真实图像,直到判别器无法区分两者为止。这个过程可以用一个零和博弈来表示,生成器试图“欺骗”判别器,而判别器试图准确区分真实图像与生成图像。

1.2训练过程

GAN的训练过程包含两个主要步骤:

  • 训练判别器:让判别器尽可能区分真实图像与生成器生成的假图像。
  • 训练生成器:通过更新生成器,使其生成的图像逐渐接近真实图像,从而迷惑判别器。

GAN的损失函数是基于博弈论中的最小最大损失函数,具体表达式如下:
在这里插入图片描述
其中, D ( x ) D(x) D(x)是判别器对真实图像的预测, G ( z ) G(z) G(z)是生成器从噪声 z z z生成的图像

2.GAN的实际用途

GAN在图像生成任务中有着广泛的应用,以下是一些主要应用场景:

  • 图像生成:生成高质量的照片、艺术作品等,常用于艺术创作、游戏开发等。
  • 图像修复:修复破损或缺失部分的图像。
  • 超分辨率重建:将低分辨率图像重建为高分辨率图像。
  • 图像到图像翻译:如从素描生成彩色图像,从夏天风景图像生成冬天场景等。
  • 视频生成:生成具有逼真运动的连续图像序列。

3.GAN的代码实现:使用GAN生成手写数字

以下是一个使用PyTorch实现简单GAN的代码示例,用于生成手写数字(如MNIST数据集)。

代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义生成器
class Generator(nn.Module):
    def __init__(self, input_size=100, output_size=784):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, output_size),
            nn.Tanh()  # 将输出限制在[-1, 1]之间,适合生成图像数据
        )
    
    def forward(self, x):
        return self.model(x)

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, input_size=784):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出一个介于0到1之间的概率值
        )
    
    def forward(self, x):
        return self.model(x)

# 数据准备
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

mnist = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)

# 初始化模型
generator = Generator()
discriminator = Discriminator()

# 损失函数和优化器
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# 训练过程
for epoch in range(50):
    for i, (real_images, _) in enumerate(dataloader):
        batch_size = real_images.size(0)
        
        # 训练判别器
        real_images = real_images.view(batch_size, -1)  # 展平图像
        real_labels = torch.ones(batch_size, 1)  # 真实标签为1
        fake_labels = torch.zeros(batch_size, 1)  # 假标签为0
        
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)  # 判别器对真实图像的损失
        
        noise = torch.randn(batch_size, 100)  # 随机噪声
        fake_images = generator(noise)  # 生成假图像
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)  # 判别器对假图像的损失
        
        d_loss = d_loss_real + d_loss_fake  # 总损失
        optimizer_d.zero_grad()
        d_loss.backward()  # 反向传播
        optimizer_d.step()  # 更新判别器
        
        # 训练生成器
        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)  # 生成器希望判别器输出1
        
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()
        
    print(f'Epoch [{epoch+1}/50], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

代码解释:

  • 1.class Generator(nn.Module):定义生成器网络,其输入是随机噪声,输出为生成的图像。
  • 2.class Discriminator(nn.Module):定义判别器网络,输入为图像,输出一个概率值,表示图像是真实图像的概率。
  • 3.criterion = nn.BCELoss():使用二分类交叉熵损失函数,适用于二分类问题。
  • 4.optimizer_g = optim.Adam(generator.parameters(), lr=0.0002):使用Adam优化器更新生成器的权重。
  • 5.optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002):Adam优化器更新判别器的权重。
  • 6.d_loss = d_loss_real + d_loss_fake:判别器的总损失包括真实图像的损失和生成图像的损失。
  • 7.g_loss = criterion(outputs, real_labels):生成器的损失,目标是让判别器认为生成图像为真实图像。
  • 8.noise = torch.randn(batch_size, 100):生成随机噪声作为生成器的输入,用于生成假图像。

4.GAN的相关论文推荐

(1)Generative Adversarial Networks,2014

论文地址:https://arxiv.org/pdf/1406.2661

主要内容:

  • “GAN之父” Ian Goodfellow 发表的第一篇提出 GAN 的论文,这应该是任何开始研究学习 GAN 的都该阅读的一篇论文,它提出了 GAN 这个模型框架,讨论了非饱和的损失函数,然后对于最佳判别器(optimal discriminator)给出其导数,然后进行证明;最后是在 Mnist、TFD、CIFAR-10 数据集上进行了实验。
    在这里插入图片描述

(2)Conditional GANs,2014

论文地址:https://arxiv.org/pdf/1411.1784

主要内容:

  • 如果说上一篇 GAN 论文是开始出现 GAN 这个让人觉得眼前一亮的模型框架,这篇 cGAN 就是当前 GAN 模型技术变得这么热门的重要因素之一,事实上 GAN 开始是一个无监督模型,生成器需要的仅仅是随机噪声,但是效果并没有那么好,在 14 年提出,到 16 年之前,其实这方面的研究并不多,真正开始一大堆相关论文发表出来,第一个因素就是 cGAN,第二个因素是等会介绍的 DCGAN;
  • cGAN 其实是将 GAN 又拉回到监督学习领域,如下图所示,它在生成器部分添加了类别标签这个输入,通过这个改进,缓和了 GAN 的一大问题–训练不稳定,而这种思想,引入先验知识的做法,在如今大多数非常有名的 GAN 中都采用这种做法,后面介绍的生成图片的 BigGAN,或者是图片转换的 Pix2Pix,都是这种思想,可以说 cGAN 的提出非常关键。
    在这里插入图片描述

(3)DCGAN,2015

论文地址:https://arxiv.org/pdf/1511.06434
主要内容:

  • 其实原作者推荐第一篇论文应该是阅读这篇 DCGAN 论文,2015年发表的。这是第一次采用 CNN 结构实现 GAN 模型,它介绍如何使用卷积层,并给出一些额外的结构上的指导建议来实现。另外,它还讨论如何可视化 GAN 的特征、隐空间的插值、利用判别器特征训练分类器以及评估结果。下图是 DCGAN 的生成器部分结构示意图
    在这里插入图片描述

(4)Improved Techniques for Training GANs,2016

论文地址:https://arxiv.org/pdf/1606.03498
主要内容:

  • 这篇论文的作者之一是 Ian Goodfellow,它介绍了很多如何构建一个 GAN 结构的建议,它可以帮助你理解 GAN 不稳定性的原因,给出很多稳定训练 DCGANs 的建议,比如特征匹配(feature matching)、最小批次判别(minibatch discrimination)、单边标签平滑(one-sided label smoothing)、虚拟批归一化(virtual batch normalization)等等,利用这些建议来实现 DCGAN 模型是一个很好学习了解 GANs 的做法。
    在这里插入图片描述

(5)Pix2Pix,2016

论文地址:https://arxiv.org/pdf/1611.07004
主要内容:

  • Pix2Pix 的目标是实现图像转换的应用,如下图所示。这个模型在训练时候需要采用成对的训练数据,并对 GAN 模型采用了不同的配置。其中它应用到了 PatchGAN 这个模型,PatchGAN 对图片的一块 70*70 大小的区域进行观察来判断该图片是真是假,而不需要观察整张图片。
  • 此外,生成器部分使用 U-Net 结构,即结合了 ResNet 网络中的 skip connections 技术,编码器和解码器对应层之间有相互连接,它可以实现如下图所示的转换操作,比如语义图转街景,黑白图片上色,素描图变真实照片等。
    在这里插入图片描述

(6)CycleGAN,2017

论文地址:https://arxiv.org/pdf/1703.10593
主要内容:

  • 上一篇论文 Pix2Pix 的问题就是训练数据必须成对,即需要原图片和对应转换后的图片,而现实就是这种数据非常难寻找,甚至有的不存在这样一对一的转换数据,因此有了 CycleGAN,仅仅需要准备两个领域的数据集即可,比如说普通马的图片和斑马的图片,但不需要一一对应。这篇论文提出了一个非常好的方法–循环一致性(Cycle-Consistency)损失函数,如下图所示的结构:
    在这里插入图片描述
  • 这种结构在接下来图片转换应用的许多 GAN 论文中都有利用到,cycleGAN 可以实现如下图所示的一些应用,普通马和斑马的转换、风格迁移(照片变油画)、冬夏季节变换等等。
    在这里插入图片描述

(7)Progressively Growing of GANs,2017

论文地址:https://arxiv.org/pdf/1710.10196
主要内容:

  • 这篇论文必读的原因是因为它取得非常好的结果以及对于 GAN 问题的创造性方法。它利用一个多尺度结构,从 44 到 88 一直提升到 1024*1024 的分辨率,如下图所示的结构,这篇论文提出了一些如何解决由于目标图片尺寸导致的不稳定问题。
    在这里插入图片描述

(8)StackGAN,2017

论文地址:https://arxiv.org/pdf/1612.03242
主要内容:

  • StackGAN 和 cGAN 、 Progressively GANs 两篇论文比较相似,它同样采用了先验知识,以及多尺度方法。整个网络结构如下图所示,第一阶段根据给定文本描述和随机噪声,然后输出 6464 的图片,接着将其作为先验知识,再次生成 256256 大小的图片。相比前面 推荐的 7 篇论文,StackGAN 通过一个文本向量来引入文本信息,并提取一些视觉特征
    在这里插入图片描述

(9)BigGAN,2018

论文地址:https://arxiv.org/pdf/1809.11096
主要内容:

  • BigGAN 应该是当前 ImageNet 上图片生成最好的模型了,它的生成结果如下图所示,非常的逼真,但这篇论文比较难在本地电脑上进行复现,它同时结合了很多结构和技术,包括自注意机制(Self-Attention)、谱归一化(Spectral Normalization)等,这些在论文都有很好的介绍和说明。
    在这里插入图片描述

(10)StyleGAN,2018

论文地址:https://arxiv.org/pdf/1812.04948
主要内容:

  • StyleGAN 借鉴了如 Adaptive Instance Normalization (AdaIN)的自然风格转换技术,来控制隐空间变量 z 。其网络结构如下图所示,它在生产模型中结合了一个映射网络以及 AdaIN 条件分布的做法,并不容易复现,但这篇论文依然值得一读,包含了很多有趣的想法。
    在这里插入图片描述

另外,再推荐一个收集了大量 GAN 论文的 Github 项目,并且根据应用方向划分论文

https://github.com/zhangqianhui/AdversarialNetsPapers

以及 3 个复现多种 GANs 模型的 github 项目,分别是目前主流的三个框架,TensorFlow、PyTorch 和 Keras:

  • TensorFlow 版本:https://github.com/TwistedW/tensorflow-GANs
  • PyTorch 版本:https://github.com/eriklindernoren/PyTorch-GAN
  • Keras 版本:https://github.com/eriklindernoren/Keras-GAN

论文推荐部分转自:https://zhuanlan.zhihu.com/p/72745900

总结

GAN通过生成器和判别器的对抗训练,极大地提升了生成图像的质量,并且通过诸如DCGAN、CycleGAN和StyleGAN等扩展变体不断优化生成效果。其广泛应用于图像生成、图像翻译、超分辨率重建等任务,推动了计算机视觉领域的快速发展。

标签:判别,nn,生成器,生成,GAN,图像,对抗
From: https://blog.csdn.net/gaoxiaoxiao1209/article/details/142304585

相关文章

  • 软件工程结对项目 3:python实现自动生成小学四则运算题目的程序
    这个作业属于哪个课程广工计院计科34班软工这个作业要求在哪里作业要求团队成员1庄崇立3122004633团队成员2罗振烘3122004748这个作业的目标结对合作完成小学四则运算题目的程序,熟悉项目开发流程,提高团队合作能力一、GitHub地址二、需求1.题目:实现一......
  • win2012服务器使用 Certbot 生成 Let's Encrypt 的域名证书
    1、安装windows版本的certbot,目前最新版是Certbot2.9.02、命令行输入[email protected]:\website\xxx\-dwww.xxx.cn其中,[email protected]为电子邮箱地址,d:\website\xxx\为网站根目录,www.xxx.cn为域名3、后面会有两次输入,第一......
  • 图像生成领域老牌的GAN模型简要回顾
    ......
  • 图像生成大模型Imagen
    图像生成大模型ImagenImagen是由GoogleResearch开发的一款基于深度学习的图像生成模型,其在文本到图像(Text-to-Image)的转换技术上取得了显著突破。Imagen通过结合大型Transformer语言模型的强大能力和高保真图像生成技术,实现了前所未有的照片级真实感和深度语言理解能力,成......
  • 利用扣子(coze.cn)平台配置工作流生成小红书风格的文案
    文章目录前言一、扣子是什么?二、使用步骤1.创建工作流2.配置工作流节点2.运行工作流,发布总结前言当你品尝了一道美味的甜品,却不知道怎么用生动的语言分享给朋友们。这时候,扣子大模型就派上用场啦!它能瞬间为你打造出一段让人垂涎欲滴的美食文案,让你的小红书瞬间收获......
  • 【智能算法应用】粒子群算法求解最小生成树问题
    目录1.最小生成树MST2.算法原理3.算法过程4.结果展示5.参考文献6.代码获取1.最小生成树MST最小生成树(MinimumSpanningTree,MST)是在给定的加权无向图中寻找一个边的子集,使得这些边构成的树包含图中的所有顶点,并且边的总权重尽可能小。如果图......
  • 一个使用 PyTorch 实现的中文聊天机器人对话生成模型916
    这是一个使用PyTorch实现的中文聊天机器人对话生成模型。1数据准备代码假设有两个文件:questions.txt和answers.txt,它们分别包含输入和输出序列。load_data函数读取这些文件并返回一个句子列表。build_vocab函数通过遍历句子来构建词汇表字典word2index和index2......
  • MobaXterm 密钥生成器
    1、MobaXterm密钥生成器,代码仓库地址:https://gitcode.com/gh_mirrors/mo/MobaXterm-keygen/blob/master/MobaXterm-Keygen.py2、也可以用我打包好的exe程序,不用安装python环境:https://pan.baidu.com/s/1jo85pQc_kfWhcYmZcc49CQ提取码:ws103、exe程序使用:随意输入用户名,输入......
  • 深入理解Python生成器、装饰器和异常处理
    一、Python生成器1.1什么是生成器?生成器(Generator)是Python中一种特殊的迭代器,它允许你在遍历大型数据集时节省内存。与普通函数不同,生成器函数使用yield关键字返回值,而不是return。生成器每次被调用时,函数的执行会在yield语句处暂停,并保存函数的状态,下一次再调用时从上次......
  • 纯C 生成二叉树广义表 根据广义表重构二叉树
    讲解很多都写在注释里了,重构二叉树的过程后面单独拿出来讲直接上代码:#include<stdio.h>#include<time.h>#include<stdlib.h>#include<limits.h>typedefstructBiTree{ intdata; structBiTree*next[2];}BiTree;BiTree*BiTree_init(intval)//节点初始化{......