首页 > 其他分享 >Pytorch 实现 GAN 网络

Pytorch 实现 GAN 网络

时间:2023-11-08 11:25:58浏览次数:43  
标签:loss plt log nn torch 网络 GAN Pytorch SIZE

Pytorch 实现 GAN 网络

原理

GAN的基本原理其实非常简单,假设我们有两个网络,G(Generator)和D(Discriminator)。它们的功能分别是:

G 是一个生成网络,它接收一个随机的噪声z,通过这个噪声生成伪造数据,记做 G(z)。

D 是一个判别网络,判别数据是不是“真实的”。它的输入参数是x,输出记为 D(x) 代表 x 为真实的概率。如果为 1 就代表 x 为真的概率是100%,而输出为 0 代表为真概率是0% 即为假。

在训练过程中,生成网络 G 的目标就是尽量生成真实的数据去欺骗判别网络D。而 D 的目标就是尽量把 G 生成的数据和真实的数据分别开来。这样,G 和 D 构成了一个动态的“博弈过程”。

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的数据 G(z) 。对于 D 来说,它难以判定 G 生成的数据究竟是不是真实的,因此 D(G(z)) = 0.5。

当判别器真假难辨时,D_fake,D_real->0.5,G_loss=log(1-0.5)=0.6931..., 此时 D_loss=log(1-0.5)+log(0.5)= 1.3832...

实现

这里我们的任务是:构造一个GAN网络,希望 生成器 能够输入噪声生成一个二次函数曲线

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

BATCH_SIZE = 64
G_IN_SIZE = 15 #生成器 输入尺寸
G_OUT_SIZE = 15 #生成器 输出尺寸

PAINT_POINTS = np.vstack([np.linspace(-1,1, G_OUT_SIZE) for _ in range(BATCH_SIZE)]) #shape (BATCH_SIZE, G_OUT_SIZE)

plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='Real Curve')    #2 * x^2 + 1
plt.legend(loc='upper right') #标签位置
plt.show()

# 准备真实数据
def real_points():
    paints = 2 * np.power(PAINT_POINTS,2) + 1
    paints = torch.from_numpy(paints).float()
    return paints

#定义网络
G = nn.Sequential(
    nn.Linear(G_IN_SIZE,128),
    nn.ReLU(),
    nn.Linear(128,G_OUT_SIZE)
)

D = nn.Sequential(
    nn.Linear(G_OUT_SIZE,128),
    nn.ReLU(),
    nn.Linear(128,1),
    nn.Sigmoid()            #0为False,1为True  D的评估应该是在【0-1】之间的数值,所以这里采用的是Sigmod激活
)

# 优化函数
optimizer_G = torch.optim.Adam(G.parameters(),lr=0.0001)
optimizer_D = torch.optim.Adam(D.parameters(),lr=0.0001)

#训练
for step in range(10001):
    real_data = real_points() # 生成真实数据
    randn_input = torch.randn(BATCH_SIZE, G_IN_SIZE) #输入噪声

    eps = 1e-6  #防止log 0
    
    D_real = D(real_data) # 0为False,1为True,这里输入真实数据,D_real越靠近1越好
    
    #训练判别器D,根据公式 D_loss 分为两个部分:判断真实数据 log(1-D_real);判断假数据 log(D_fake) 
    # D带着G一起更新,使用D(G(input))
    D_loss = -torch.mean(torch.log(eps + 1.0 - D_real) + torch.log(eps + D(G(randn_input))))
    optimizer_D.zero_grad()
    D_loss.backward()
    optimizer_D.step()

    #训练生成器G
    G_fake_out = G(randn_input) # 生成器生成假数据
    D_fake = D(G_fake_out) # 用判别器判别假数据,最好能让判别器判断概率趋近0.5,即生成器生成的假数据,能让判别器真假难辨
    # G的损失 越接近1越好,当判别器真假难辨时,D_fake,D_real->0.5,G_loss=log(1-0.5)=0.6931..., 此时 D_loss=log(1-0.5)+log(0.5)= 1.3832...
    G_loss = -torch.mean(torch.log(1.0 - D_fake + eps))
    
    optimizer_G.zero_grad()
    G_loss.backward() #反向
    optimizer_G.step() #更新G参数
        
    if step % 1000 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_fake_out.data.numpy()[0], c='#4AD631', lw=3, label='Generated Curve',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='Real Curve')
        plt.text(-1.0, 0.4, 'G_loss= %.3f ' % G_loss.data.numpy(), fontdict={'size': 13})
        plt.text(-1.0, 0.2, 'D_loss= %.3f ' % D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)

扩展阅读:
生成对抗网络损失函数的理解

标签:loss,plt,log,nn,torch,网络,GAN,Pytorch,SIZE
From: https://www.cnblogs.com/gaobw/p/17816950.html

相关文章

  • Jtti:美国服务器关网事件:美国服务器的建站网络安全优势
    美国服务器在全球范围内享有广泛的受欢迎度,但也需要注意网络安全问题。美国服务器在建站网络安全方面有一些优势,但同时也面临一些挑战。以下是关于美国服务器建站网络安全的一些优势和注意事项:优势:高质量的数据中心:美国拥有众多高质量的数据中心,提供物理安全性、灾难恢复、备份电......
  • 网络数据库练习题
    练习一1  简述什么是网络数据库。 2  SQLServer2000的常见版本有哪些。 3  解释以下若干名词术语:关系,元组,属性,主键。 4  简述SQLServer2000中的4个系统数据库的主要用途。 5  简述SQLServer2000中的一些常用数据类型(datetime,int,float,money)的用法或......
  • 某音用SSL证书上了一把“安全锁”,加密保护网络传输数据安全
    依照《网络安全法》、《数据安全法》等相关法律法规,网络运营者应当按照网络安全等级保护制度的要求,采用数据加密等措施来防止网络数据泄露或者被窃取、篡改。某音作为头部的音乐创意短视频社交平台,每天都有数以亿计的用户在上面观看、发布视频,而这会产生大量包含个人账号、密码等用......
  • 工业网络交换机助力革新燃气管网监控安全与运行效率
    随着现代社会能源需求的不断增长,燃气管网监控系统的重要性日益凸显。为了确保燃气供应的安全可靠,并提升管网运行效率,工业网络交换机作为关键设备在燃气管网监控系统中发挥着重要作用。本文将深入探讨工业网络交换机在燃气管网监控系统中的应用,并展示其在革新管网安全与运行效率方面......
  • 面试官:你会如何设计QQ中的网络协议?
    引言在设计QQ这道面试题时,我们需要避免进入面试误区。这意味着我们不应该盲目地开展头脑风暴,提出一些不切实际的想法,因为这些想法可能无法经受面试官的深入追问。因此,我们需要站在前人的基础上,思考如何解决这类面试题。我们可以设计一个实际可行的QQ系统,而不是离题太远。设计细......
  • Linux服务器网络配置记录
    Linux服务器网络配置记录材料准备材料数量服务器1显示器1网线2(千兆*1)千兆交换机1插线板1网线连接从路由器LAN口引出网线到交换机任一口,再从交换机剩余任一口引出千兆网线到服务器网线插口1服务器网线插口1插入后有有灯闪烁代表网线连接正常网......
  • 一文概览NLP句法分析:从理论到PyTorch实战解读
    关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人。本文全面探讨了自然语言处理(NLP)中句法分析的理论与实践。从句法和语法的......
  • Linux操作系统 虚拟机连接网络和xshell连接虚拟机
    虚拟机连接网络:桥接模式:1.编辑--虚拟网络编辑器--桥接模式--自动或指定具体网卡 2.设置--网络适配器--桥接模式 3.查看宿主机ip地址 4.配置linuxip地址5.配置的linux,ip地址和宿主机的IP地址,子网掩码,默认网关,dns都是一样6.重启网卡7.互ping8.pingwww.baidu.co......
  • clumsy 0.3 发布,十年前推出的差网络环境模拟工具
    clumsy0.3现已发布,距离v0.1版本已经过去了十年的时间。clumsy能在Windows平台下人工造成不稳定的网络状况,方便你调试应用程序在极端网络状况下的表现。0.3二进制文件与一年半前发布的0.3RC4相同。将滞后时间上限提高到15秒改用zig0.9.0生成二进制文件......
  • 关于关于怎么样让自己的虚拟机连上网络,以及Xshell怎么连上虚拟机
    当你使用虚拟机来模拟不同的操作系统环境或进行开发和测试时,连接虚拟机到网络以及使用远程终端工具如Xshell是非常重要的。在本篇博客中,我将向你介绍如何使你的虚拟机连接到网络,以及如何使用Xshell来连接到虚拟机。连接虚拟机到网络在开始之前,确保你已经安装了虚拟机软件,比如VMwar......