首页 > 其他分享 >(二)WGAN实战

(二)WGAN实战

时间:2023-11-01 19:46:06浏览次数:31  
标签:实战 dim nn torch dataset xr np WGAN

一、WGAN实战

import  torch
from    torch import nn, optim, autograd
import  numpy as np
import  visdom
from    torch.nn import functional as F
from    matplotlib import pyplot as plt
import  random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

def data_generator():

    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset

    # for i in range(100000//25):
    #     for x in range(-2, 3):
    #         for y in range(-2, 3):
    #             point = np.random.randn(2).astype(np.float32) * 0.05
    #             point[0] += 2 * x
    #             point[1] += 2 * y
    #             dataset.append(point)
    #
    # dataset = np.array(dataset)
    # print('dataset:', dataset.shape)
    # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
    #
    # while True:
    #     np.random.shuffle(dataset)
    #
    #     for i in range(len(dataset)//batchsz):
    #         yield dataset[i*batchsz : (i+1)*batchsz]


def generate_image(D, G, xr, epoch):
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))
    # (16384, 2)
    # print('p:', points.shape)

    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points).cuda() # [16384, 2]
        disc_map = D(points).cpu().numpy() # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)
    # plt.colorbar()


    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda() # [b, 2]
        samples = G(z).cpu().numpy() # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))


def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

def gradient_penalty(D, xr, xf):
    """

    :param D:
    :param xr:
    :param xf:
    :return:
    """
    LAMBDA = 0.3

    # only constrait for Discriminator
    xf = xf.detach()
    xr = xr.detach()

    # [b, 1] => [b, 2]
    alpha = torch.rand(batchsz, 1).cuda()
    alpha = alpha.expand_as(xr)

    interpolates = alpha * xr + ((1 - alpha) * xf)
    interpolates.requires_grad_()

    disc_interpolates = D(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    return gp

def main():

    torch.manual_seed(23)
    np.random.seed(23)

    G = Generator().cuda()
    D = Discriminator().cuda()
    G.apply(weights_init)
    D.apply(weights_init)

    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))


    data_iter = data_generator()
    print('batch:', next(data_iter).shape)

    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss',
                                                 legend=['D', 'G']))

    for epoch in range(50000):

        # 1. train discriminator for k steps
        for _ in range(5):
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()

            # [b]
            predr = (D(xr))
            # max log(lossr)
            lossr = - (predr.mean())

            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            # stop gradient on G
            # [b, 2]
            xf = G(z).detach()
            # [b]
            predf = (D(xf))
            # min predf
            lossf = (predf.mean())

            # gradient penalty
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            # for p in D.parameters():
            #     print(p.grad.norm())
            optim_D.step()


        # 2. train Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = (D(xf))
        # max predf
        loss_G = - (predf.mean())
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()


        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

            generate_image(D, G, xr, epoch)

            print(loss_D.item(), loss_G.item())






if __name__ == '__main__':
    main()

 

标签:实战,dim,nn,torch,dataset,xr,np,WGAN
From: https://www.cnblogs.com/zhangxianrong/p/17803931.html

相关文章

  • django搭建平台实战教程一:生成数据库数据
    首先需要创建一个django-rest-framework项目,如何创建可以参考https://www.django-rest-framework.org/tutorial/quickstart/,不再赘述。创建完结构如图所示 settings.py配置mysql数据库...DATABASES={"default":{"ENGINE":"django.db.backends.mysql",......
  • Django实战项目-学习任务系统-自定义URL拦截器
    接着上期代码框架,6个主要功能基本实现,剩下的就是细节点的完善优化了。首先增加URL拦截器,你不会希望没有登录用户就可以进入用户主页各种功能的,所以增加URL拦截器可以解决这个问题。Django框架本身也有URL拦截器,但是因为本系统用户模型跟Django框架本身用户模型不匹配,所以没有用,......
  • Golang语言快速上手到综合实战-高并发聊天室、豆瓣电影爬虫
    Golang语言快速上手到综合实战-高并发聊天室、豆瓣电影爬虫我们公司需要快速迭代一款产品,当时,我们团队的后端框架是springmvc,该框架结构清晰,上手快,但是由于我们的产品迭代速度快,底层数据库操作接口变动频繁,导致service层工作量巨大,不胜其烦。另外,随着项目的成长,代码......
  • Go开发工程师入门到项目实战 Google架构师联合大厂架构师出品
    01|Go语言课程介绍蔡超Mobvista技术副总裁兼首席架构师,前亚马逊(中国)首席软件架构师本节内容你好,我是蔡超,目前在Mobvista担任技术副总裁兼首席架构师。在加入Mobvista前,我也曾在亚马逊,惠普等公司担任过首席软件架构师。我是从小学四年级开始学习计算机编程的,并一直从事......
  • 记一次实战渗透测试
    前言之前一个朋友叫我帮忙测试一下他们公司的网站信息收集使用插件Wappalyzer探测信息首先使用Nmap扫描了一下端口,因为扫描高危端口,往往有意外收获。开放了21,80,3306,3389,8001端口,于是尝试匿名登录ftp,无果,爆破3306和3389远程桌面,爆了半天也没爆出来然后御剑目录扫描扫描网站......
  • 解码注意力Attention机制:从技术解析到PyTorch实战
    在本文中,我们深入探讨了注意力机制的理论基础和实际应用。从其历史发展和基础定义,到具体的数学模型,再到其在自然语言处理和计算机视觉等多个人工智能子领域的应用实例,本文为您提供了一个全面且深入的视角。通过Python和PyTorch代码示例,我们还展示了如何实现这一先进的机制。关......
  • 解码注意力Attention机制:从技术解析到PyTorch实战
    在本文中,我们深入探讨了注意力机制的理论基础和实际应用。从其历史发展和基础定义,到具体的数学模型,再到其在自然语言处理和计算机视觉等多个人工智能子领域的应用实例,本文为您提供了一个全面且深入的视角。通过Python和PyTorch代码示例,我们还展示了如何实现这一先进的机制。关......
  • 基于Unity整合BEPUphysicsint物理引擎实战
    上一节我们详细的讲解BEPUphysicsint的物理事件。此物理引擎会产生了碰撞事件与非碰撞事件,碰撞事件大家好理解,非碰撞事件例如:物理Entity的update事件,Entity的activation/deactivation事件等。本节课来实战如何编译BEPUphysicsint源码到自己的项目,如何整合物理引擎与Unity图......
  • [17章+电子书]C#速成指南-从入门到进阶,实战WPF与Unity3D开发
    点击下载:[17章+电子书]C#速成指南-从入门到进阶,实战WPF与Unity3D开发  提取码:a3s5 《C#速成指南--从入门到进阶,实战WPF与Unity3D开发》完整讲解了C#语言的核心知识和高阶编程技巧,并结合WPF客户管理系统和Unity3D切水果游戏两大实战项目,帮你实现技术的精通,完成从Zero到Hero的蜕变......
  • k8s-服务网格实战-入门Istio
    背景终于进入大家都比较感兴趣的服务网格系列了,在前面已经讲解了:如何部署应用到kubernetes服务之间如何调用如何通过域名访问我们的服务如何使用kubernetes自带的配置ConfigMap基本上已经够我们开发一般规模的web应用了;但在企业中往往有着复杂的应用调用关系,应用与......