首页 > 其他分享 >CS231N assignment 3 _ GAN 学习笔记 & 解析

CS231N assignment 3 _ GAN 学习笔记 & 解析

时间:2023-04-30 09:33:12浏览次数:43  
标签:loss CS231N Linear nn assignment GAN fake size

这篇文章之所以来的比较早, 是因为我们机器人比赛字符识别数据集不够, 想自己造点数据集其实

课程内容总结

所谓GAN, 原理很简单, 我们有一个生成器网络和鉴别器网络, 生成器生成假的数据, 鉴别器分辨真假, 二者知己知彼互相优化自己, 从而达到博弈的效果.

实际操作中, 我们一般是训练k步鉴别器, 随后训练一步生成器(或者一步&多步, 这东西其实不绝对, 现在很多GAN变种解决了k超参数问题). 生成器从噪声生成目标的图像等,  例如下面就是一个简单的生成器结构(卷积/反卷积)

我们优化的函数是下面的表现形式:

这个公式有两层含义: 首先我们要优化鉴别器参数, 尽可能提高真实数据判断为真的概率, 以及生成数据判断为假的概率, 随后要优化生成器, 尽可能骗过生成器.

但是实际上, 考虑到函数在初期梯度不够大, 不符合一般的梯度优化规律, 所以我们会适当对优化的具体形式做一些变化:

实际上, 两个模型同步是最大的问题, 因此GAN往往是比较难训练出来的, 其对于超参数很敏感. 课程上的插图也能说明这一点

那么下面我们就开始激动人心的实战吧.

MNIST数据集

想必各位入坑深度学习第一次接触的就是MNIST数据集吧. 其是一个28*28的手写图片的系列集合, 我们可以看到部分的内容:

我们初始生成的数据是一个-1,1之间的完全随机的数, 根据dim要求完成代码很简单.

两个网络

其实两个网络的组成和一般的网络并无二致, 仅是训练方式有所区别. 在初期为了训练方便我们仅采用了普通的全连接网络模型. 考虑到题目提示已经很明白了, 所以我们可以写出代码:

(这里的非线性函数选择中, tanh是为了数据规范化, 和前面对数据的预处理格式一致, 而其他函数可能是经过实践得到的最好结果)

# discriminator
    model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784,256),
            nn.LeakyReLU(),
            nn.Linear(256,256),
            nn.LeakyReLU(),
            nn.Linear(256,1)
    )
# generator
    model = nn.Sequential(
            nn.Linear(noise_dim,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,784),
            nn.Tanh()
    )

 其中discriminator的输出为一维数据, 数据是一个连续的值(∈R), 也就表明我们最终输出的并不是一个绝对的真假判断, 而是一个打分值[注意: 实际上sigmoid包含在了bce的loss内, 但是这里没写出来]. 

GAN误差

我们来看原本作业的解说:

可以看到, 我们考虑到了s和y的实际取值范围, 我们要求s越大越好, 与此同时1-s越小越好, 意即真实的数据打分要比较高, 虚假的数据打分要比较低. 注意到s取值范围, 1-s>0和s>0在log下不能严格成立, 所以需要进行一些修改. 可能这个不好看出来, 这里将输出打分的sigmoid合并在一起了, 大概就是这么个思路:

结果的表达形式的要求就是ln内的数值>0, 因此我们分为两个部分, 其中s>0多处一项并且符号相反, 整合了sigmoid之后能够有这么简单的表达形式, 可以说大道至简,非常巧妙.

 # input.clamp(min=0) => max(input,0)
    loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 

这个函数包含了所有真实和虚假的情况, 现在因为真实和虚假是分开投入的, 所以target只需要分别设置为全0和全1即可.

def discriminator_loss(logits_real, logits_fake):
    size=logits_real.size()
    true_labels = torch.ones(size).type(dtype)
    fake_labels = torch.zeros(size).type(dtype)
    loss = bce_loss(logits_real,true_labels)+bce_loss(logits_fake,fake_labels-1)
    return loss

def generator_loss(logits_fake):
    size=logits_fake.size()
    true_labels = torch.ones(size).type(dtype)
    loss = bce_loss(logits_fake,true_labels)
    return loss

训练

随后我们按照题目设置adam优化器就可以开始优化流程了. 优化流程可以说很简单, 就是应用loss. 需要注意, 在这里我们直接D和G相继训练一次, 也就是超参数k=1.

我们监控一下流程. 不难发现, 最初得到的就是纯纯的噪声, 随后慢慢成型. 然而目前我们发现生成图像杂点还是很多, 且D和G的误差是不收敛的.

LS-GAN

LS-GAN的区别仅仅在于误差函数不同, 原本的误差函数是Log, 现在我们变成了平方.

所以我们现在也就不必担心数值稳定性了. 这里因为已知真假, 所以函数只和打分相关.

loss=0.5*(torch.pow(scores_real-1,2).mean()+torch.pow(scores_fake,2).mean()) # discriminator
loss = 0.5*torch.pow(scores_fake-1,2).mean() # generator

因为误差表现形式的不同, 所以误差绝对值会小一些,但是仍然不收敛, 且训练过程表现也不是很一样, 最终效果依然称不上好, 这应该就是网络瓶颈.

DC-GAN

DC-GAN在误差上相对于前两者来说并无区别, 但是采用的D和G网络换成了卷积网络. D网络和常规的LeNet-5基本架构是一致的, 根据题目提示容易得到答案:

        Unflatten(batch_size,1,28,28),
        nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5,stride=1),
        nn.LeakyReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,stride=1),
        nn.LeakyReLU(),
        nn.MaxPool2d(2),
        Flatten(),
        nn.Linear(4*4*64,4*4*64),
        nn.LeakyReLU(),
        nn.Linear(4*4*64,1)

而G网络需要进行上采样. 这个步骤是从pytorch函数convTranspose2d来实现的, 原理就是大家看过的名场面:

这一过程可以表示为: (假设输入和输出的通道数都是1)

upsample = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1)

 得到的图像尺寸利用下面的公式得到: (公式源自pytorch文档)

Hout​=(Hin​−1)×stride[0]−2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1

Wout​=(Win​−1)×stride[1]−2×padding[1]+dilation[1]×(kernel_size[1]−1)+output_padding[1]+1

最终的架构如下:

    return nn.Sequential(
        nn.Linear(noise_dim, 1024), 
        nn.ReLU(),
        nn.BatchNorm1d(1024),  # 对于全连接向量使用batchnorm1d而非2d
        nn.Linear(1024, 7*7*128),
        nn.ReLU(),
        nn.BatchNorm1d(7*7*128),
        Unflatten(batch_size, 128, 7, 7),
        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(64), 
        nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
        nn.Tanh(), # 规范化输出范围
        Flatten(),
    )

可以看出, 最终DC-GAN的生成图像质量明显更高,即使在较少的训练次数下也有相对不错的表现.

inline question

问题1

求解:

最终回到了这个点. 由此我们就看出, 在这种情况下问题构成了一个循环, 互相对抗无法达到双方的最优值, 这可能也是某些时候GAN难以训练的原因之一吧.

答案很显然不是, 因为此时我们G可以很好地骗过D, 但是你说D本身段位不够, 我就能说G很高明吗? 肯定不是这样. 只有当双方都收敛,才是我们希望看到的.

标签:loss,CS231N,Linear,nn,assignment,GAN,fake,size
From: https://www.cnblogs.com/360MEMZ/p/17364219.html

相关文章

  • Gangsters UVA - 672
     一家饭店,有一扇大小会变得门,变化范围为[0,k]。每过一单位时间你可以让门的大小+1,-1,或者不变。客人会在不同的时间来吃饭,但是如果门的大小和他们希望的值不一样,他们就不会进来并且直接消失。吃饭要花钱,现在问饭店最多能赚多少钱。  F[i][j]=max(F[i-1][j]+v,F[i-1][j-1......
  • 深度学习--GAN实战
    深度学习--GAN实战DCGANimporttorchfromtorchimportnn,optim,autogradimportnumpyasnpimportvisdomimportrandom#用python-mvisdom.server启动服务h_dim=400batchsz=512viz=visdom.Visdom(use_incoming_socket=False)classGenerator(nn.Module......
  • Cycle GAN:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
    paper:https://arxiv.org/pdf/1703.10593.pdf[2017]code参考:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pixhttps://zhuanlan.zhihu.com/p/79221194https://blog.csdn.net/fangjin_kl/article/details/128117396https://www.bilibili.com/video/BV1kb4y197P......
  • 使用encoder编码器-decoder解码器加GAN网络的生成式图像修复
    论文链接https://openaccess.thecvf.com/content_cvpr_2016/papers/Pathak_Context_Encoders_Feature_CVPR_2016_paper.pdf简介作者提出了一种基于上下文像素预测的无监督视觉特征学习算法,它既完成了特征提取,也完成了图像修复。通过与自动编码器的类比,提出了上下文编码器(Conte......
  • [干货满满] CIFAR10炼丹记后篇 - CS231N 番外
    期中考试结束了,来填坑,因为真正接触到了玄学和银河的部分,也算是试验了几天的成果把(在上一个文章中,我们已经提到了,通过本课程学到的各种技巧,我们将准确度提升到了80%,这已经超过了大多数CS231N博客的效果了.但是毕竟这个是在基本的卷积网络架构去操作的,所以后续想要......
  • Linux安装基于rsyslog+loganalyzer的日志系统
    一、 关闭防火墙和selinuxsetenforce0vim/etc/selinux/config将配置文件中的SELINUX=enforcing 修改为systemctl stop firewalldsystemctl status firewalldsystemctl disable firewalld二、安装LAMPyuminstallmysql-servermysql-develhttpdphp-mysql phpphp-gdp......
  • CMU 提出全新 GAN 结构,GAN 自此迈入预训练大军!
    文|林锐众所周知,现在GAN的应用是越来越宽泛了,尤其是在CV领域。不仅可以调个接口生成新头像图一乐,也可以用GAN做数据增强让模型更加健壮。▲嘉然你带我走吧嘉然!在CV领域,不像分类、目标检测等任务可以使用预训练好的backbone来加速训练、提升精度,GAN的训练基本上是从头开始......
  • gganimate|让你的图动起来!!!
    这是ggplot中十分可爱的一个扩增包,目的只有一个,就是让你的图动起来!就是酱紫!!gganimate扩展了ggplot2实现的图形语法,包括动画描述。它通过提供一系列新的语法类来实现这一点,这些类可以添加到绘图对象中,以便自定义它应该如何随时间变化。下面是他的parameter:transition_*()定义了数据......
  • UBantu 无法运行 Ganache 解决方案
    问题描述直接在UBantu上执行ganache-2.5.4-linux-x86_64.AppImage程序可能因为权限问题而无法运行解决办法可以将ganache-2.5.4-linux-x86_64.AppImage进行解压,如下:$./ganache-2.5.4-linux-x86_64.AppImage--appimage-extract解压以后会创建squashfs-root文件夹,......
  • [oeasy]python0132_变量含义_meaning_声明_declaration_赋值_assignment
    变量定义回忆上次内容上次回顾了一下历史python是如何从无到有的看到Guido长期的坚持和努力 编程语言的基础都是变量声明python是如何声明变量的呢? 变量想要定义变量首先明确什么是变量变量就是数值能变的量英文名称varia......