首页 > 其他分享 >深度学习--GAN实战

深度学习--GAN实战

时间:2023-04-27 13:01:16浏览次数:44  
标签:实战 dim nn -- self dataset GAN __ np

深度学习--GAN实战

DCGAN

import torch
from torch import  nn, optim, autograd
import  numpy as np
import visdom
import random
#用python -m visdom.server启动服务

h_dim = 400
batchsz = 512
viz = visdom.Visdom(use_incoming_socket=False)

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,2)
        )

    def forward(self,z):
        output = self.net(z)
        return output

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator,self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,1),
            nn.Sigmoid()
        )

    def forward(self,x):
        output = self.net(x)
        return output.view(-1)



#生成数据集
def data_generator():
    '''
    8-gaussian mixturn models
    :return:
    '''
    scale = 2.
    centers = [
        (1,0),
        (-1,0),
        (0,1),
        (0,-1),
        (1./np.sqrt(2),1./np.sqrt(2)),
        (1./ np.sqrt(2),-1. / np.sqrt(2)),
        (-1./np.sqrt(2),1./np.sqrt(2)),
        (-1./ np.sqrt(2),-1. / np.sqrt(2))
    ]
    #维度放缩
    centers = [(scale*x,scale*y) for x,y in centers]

    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2)*0.02
            center = random.choice(centers)
            #N(0,1) + center x1/x2
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)

        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.141
        #死循环生成器
        yield dataset

def main():

    #减小随机性
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    #[b,2]
    #print(x.shape)
    G = Generator().cuda()
    D = Discriminator().cuda()

    #print(G)
    #print(D)
    optim_G = optim.Adam(G.parameters(),lr=5e-4,betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    for epoch in range(50000):
        #train Discriminator
        for  _ in range(5):
            # 训练真实数据
            x = next(data_iter)
            x = torch.from_numpy(x).cuda()
            # [b,2] =>[b,1]
            predr = D(x)
            # max predr
            lossr = -predr.mean()
            #1.2 train on fake data
            z = torch.randn(batchsz,2).cuda()
            xf = G(z).detach()   #不算梯度
            predf = D(xf)
            lossf = predf.mean()

            # aggregate all
            loss_D = lossr +lossf

            #optimize
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        #2.genarator的训练
        z = torch.randn(batchsz,2).cuda()
        xf = G(z)
        predf = D(xf)
        #max predf.mean()
        loss_G = -predf.mean()

        optim_G.zero_grad()
        loss_G.backward()
        optim_D.step()

        if epoch%100 == 0:
            print(loss_D.item(),loss_G.item())


if __name__ == '__main__':
    main()

WGAN

相较于DCGAN,就是在损失函数上加一个惩罚项

import  torch
from    torch import nn, optim, autograd
import  numpy as np
import  visdom
from    torch.nn import functional as F
from    matplotlib import pyplot as plt
import  random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

def data_generator():

    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset

    # for i in range(100000//25):
    #     for x in range(-2, 3):
    #         for y in range(-2, 3):
    #             point = np.random.randn(2).astype(np.float32) * 0.05
    #             point[0] += 2 * x
    #             point[1] += 2 * y
    #             dataset.append(point)
    #
    # dataset = np.array(dataset)
    # print('dataset:', dataset.shape)
    # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
    #
    # while True:
    #     np.random.shuffle(dataset)
    #
    #     for i in range(len(dataset)//batchsz):
    #         yield dataset[i*batchsz : (i+1)*batchsz]


def generate_image(D, G, xr, epoch):
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))
    # (16384, 2)
    # print('p:', points.shape)

    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points).cuda() # [16384, 2]
        disc_map = D(points).cpu().numpy() # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)
    # plt.colorbar()


    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda() # [b, 2]
        samples = G(z).cpu().numpy() # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))


def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

def gradient_penalty(D, xr, xf):
    """

    :param D:
    :param xr:
    :param xf:
    :return:
    """
    LAMBDA = 0.3

    # only constrait for Discriminator
    xf = xf.detach()
    xr = xr.detach()

    # [b, 1] => [b, 2]
    alpha = torch.rand(batchsz, 1).cuda()
    alpha = alpha.expand_as(xr)

    interpolates = alpha * xr + ((1 - alpha) * xf)
    interpolates.requires_grad_()

    disc_interpolates = D(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    return gp

def main():

    torch.manual_seed(23)
    np.random.seed(23)

    G = Generator().cuda()
    D = Discriminator().cuda()
    G.apply(weights_init)
    D.apply(weights_init)

    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))


    data_iter = data_generator()
    print('batch:', next(data_iter).shape)

    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss',
                                                 legend=['D', 'G']))

    for epoch in range(50000):

        # 1. train discriminator for k steps
        for _ in range(5):
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()

            # [b]
            predr = (D(xr))
            # max log(lossr)
            lossr = - (predr.mean())

            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            # stop gradient on G
            # [b, 2]
            xf = G(z).detach()
            # [b]
            predf = (D(xf))
            # min predf
            lossf = (predf.mean())

            # gradient penalty
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            # for p in D.parameters():
            #     print(p.grad.norm())
            optim_D.step()


        # 2. train Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = (D(xf))
        # max predf
        loss_G = - (predf.mean())
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()


        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

            generate_image(D, G, xr, epoch)

            print(loss_D.item(), loss_G.item())






if __name__ == '__main__':
    main()

标签:实战,dim,nn,--,self,dataset,GAN,__,np
From: https://www.cnblogs.com/ssl-study/p/17358620.html

相关文章

  • 两数之和
    给定一个整数数组nums 和一个整数目标值target,请你在该数组中找出和为目标值target 的那 两个 整数,并返回它们的数组下标。你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。你可以按任意顺序返回答案输入:nums=[2,7,11,15],target......
  • 2023.4.27——软件工程日报
    所花时间(包括上课):3h代码量(行):0行博客量(篇):1篇今天,上午学习,下午学习并开会。我了解到的知识点:1.了解了一些数据库的知识;2.了解了一些python的知识;3.了解了一些英语知识;5.了解了一些Javaweb的知识;4.了解了一些数学建模的知识;6.了解了一些计算机网络的知识;......
  • 记录一次未初始化漏洞_four
    对一道关于未初始化漏洞的题目的总结,来自前几天的DASCTF。这道题总体不算难,我觉得更多的考了代码审计能力(也有可能是本人初学,看伪c没经验,所以觉得很复杂,中间看了看wp对这道题才恍然大悟)因为作为一道栈题来说,伪c算挺长的了。题目链接:https://pan.baidu.com/s/1oLz7BPI5oyJlrO2a5......
  • .NET使用一行命令轻松生成EF Core项目框架
    dotnetef是EntityFrameworkCore(EFCore)的一个命令行工具,用于管理EFCore应用程序的数据库和代码。除了提供管理数据库的命令之外,dotnetef还可以生成和管理实体和上下文代码。本文将介绍如何使用dotnetef动态生成代码。一、环境准备1、项目准备用vs2022新建一个.NET6的asp.......
  • 《Effective C#》系列之(一)——异常处理与资源管理
    请注意,《EffectiveC#》中的异常处理与资源管理部分实际上是第四章的内容。以下是关于该章节的详细解释。第四章:异常处理与资源管理一.了解异常处理机制异常处理机制使程序员能够在程序运行过程中处理错误情况。C#提供了try-catch-finally语句块来捕获和处理异常。了解不同类......
  • 出售金鱼
     一、问题描述   小明将养的一缸金鱼分五次出售;第一次卖出全部的一半加1/2条;第二次卖出余下的三分之一加1/3条;第三次卖出余下的四分之一加1/4条;第四次卖出余下的五分之一加1/5条;最后卖出余下的11条。求原来有几条。二、设计思路    金鱼分五次出售,每次卖出的方式相同......
  • 配置.husky和commitlint以及Eslint
    代码规范ESLint+Prettier(项目是基于uniapp+vue3+ts)无脑执行以下操作,让你在vue3+ts的项目中愉快的使用eslint和prettier。npminstalleslintprettier--save-devnpminstalleslint-config-prettiereslint-plugin-prettiereslint-plugin-vue--save-devnpminstall......
  • 云主机AK/SK泄露利用
    https://github.com/iiiusky/alicloud-toolsAK/SK利用行云管家直接大部分主流的云厂商。https://yun.cloudbility.com/ 使用方式该工具主要是方便快速使用阿里云api执行一些操作Usage:AliCloud-Tools[flags]AliCloud-Tools[command]AvailableCommands:ecs......
  • 记一次线上服务器问题排查过程
    问题描述前几天我们更新线上服务器,使用对应的新版客户端连接时,怎么都连不上,如果直接连接其他服,比如我们内部的测试服或者审核服,却一切正常。同时,如果使用老包连接服务器,也是正常的这个问题查的头疼,每一步都超出我的理解范围排查过程首先第一步,我们在服务器接口的必经之路上......
  • Golang单元测试
    1.前言2.先决条件3.创建单元测试的示例程序4.创建单元测试5.使用gotest运行测试6.Table-driven的单元测试7.测试覆盖率8.Go基准测试9.为代码写示例10.总结11.参考文档1.前言原文:HowToWriteUnitTestsinGoAuthor:TobiBalogun译者:philoenglis......