首页 > 其他分享 >深度学习《WGAN模型》

深度学习《WGAN模型》

时间:2022-12-14 16:31:07浏览次数:43  
标签:opt ngf nn 模型 netd 深度 WGAN data wgan


WGAN是一个对原始GAN进行重大改进的网络

深度学习《WGAN模型》_.net

主要是在如下方面做了改进

深度学习《WGAN模型》_WGAN_02

实例测试代码如下:

还是用我16张鸣人的照片搞一波事情,每一个上述的改进点,我在代码中都是用 Difference 标注的。

import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10
from pylab import plt
import os
import torchvision.datasets as datasets
from torchvision.utils import save_image


# 至于 WGAN和GAN的区别请全文搜索 Importment Difference 即可查看

# step 1: ========================================== hyper-parameters for this run
class WGAN_Config:
    """All tunable knobs of the WGAN experiment, grouped in one namespace."""
    lr = 0.0001        # RMSprop learning rate for both networks
    nz = 100           # noise dimension (generator input)
    image_size = 64    # images are resized to image_size x image_size

    nc = 3             # channels per image (RGB)
    ngf = 64           # base channel width of the generator
    ndf = 64           # base channel width of the discriminator

    batch_size = 16
    max_epoch = 5000   # set to 1 when debugging
    clamp_num = 0.01   # WGAN weight-clipping bound for the critic

wgan_opt = WGAN_Config()


def deprocess_img(img):
    """Map generator output from [-1, 1] back to a displayable [0, 1] batch.

    Inverts the Normalize([0.5]*3, [0.5]*3) preprocessing, clamps numeric
    spill-over, and reshapes to (N, C, H, W) for torchvision's save_image.
    """
    out = 0.5 * (img + 1)
    out = out.clamp(0, 1)
    # Consistency fix: derive the channel count from the config instead of
    # hard-coding 3, so this tracks wgan_opt.nc like the networks do.
    out = out.view(-1, wgan_opt.nc, wgan_opt.image_size, wgan_opt.image_size)
    return out



# step 2: ========================================== usual pipeline: load the dataset.
# data preprocess: resize, convert to tensor, then normalise each RGB channel
# to [-1, 1] so real images match the generator's Tanh output range.
transform = transforms.Compose([
transforms.Resize(wgan_opt.image_size),
transforms.ToTensor(),
transforms.Normalize([0.5] * 3, [0.5] * 3)
])

# dataset = CIFAR10(root='cifar10/', transform=transform, download=True)
# dataloader = t.utils.data.DataLoader(dataset, wgan_opt.batch_size, shuffle=True)

data_path = os.path.abspath("D:/software/Anaconda3/doc/3D_Naruto")
print (os.listdir(data_path))
# NOTE: put the images in a subdirectory of data_path — ImageFolder reads
# samples from per-class subdirectories and raises an error otherwise.
dataset = datasets.ImageFolder(root=data_path, transform=transform)
dataloader = t.utils.data.DataLoader(dataset, batch_size=wgan_opt.batch_size, shuffle=True)





# step 3: ========================================== 定义WGAN的G网络和D网络的模型
class generator(nn.Module):
    """DCGAN-style generator: (nz, 1, 1) noise -> (nc, 64, 64) image in [-1, 1]."""

    def __init__(self):
        super(generator, self).__init__()
        layers = [
            # project the noise vector onto a 4x4 feature map
            nn.ConvTranspose2d(wgan_opt.nz, wgan_opt.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(wgan_opt.ngf * 8),
            nn.ReLU(True),
        ]
        # three upsampling stages (4->8->16->32), halving the channel width each time
        for mult in (8, 4, 2):
            layers += [
                nn.ConvTranspose2d(wgan_opt.ngf * mult, wgan_opt.ngf * mult // 2,
                                   4, 2, 1, bias=False),
                nn.BatchNorm2d(wgan_opt.ngf * mult // 2),
                nn.ReLU(True),
            ]
        # final stage (32->64) produces the RGB image; Tanh bounds it to [-1, 1]
        layers += [
            nn.ConvTranspose2d(wgan_opt.ngf, wgan_opt.nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.netg = nn.Sequential(*layers)

    def forward(self, imgs):
        """Map a batch of noise vectors to a batch of generated images."""
        return self.netg(imgs)

class discriminator(nn.Module):
    """DCGAN-style critic: (nc, 64, 64) image -> one unbounded score per sample."""

    def __init__(self):
        super(discriminator, self).__init__()
        slope = 0.2
        layers = [
            # 64x64 -> 32x32; DCGAN convention: no batch-norm on the first layer
            nn.Conv2d(wgan_opt.nc, wgan_opt.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(slope, inplace=True),
        ]
        # three downsampling stages (32->16->8->4), doubling the channel width
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(wgan_opt.ndf * mult, wgan_opt.ndf * mult * 2,
                          4, 2, 1, bias=False),
                nn.BatchNorm2d(wgan_opt.ndf * mult * 2),
                nn.LeakyReLU(slope, inplace=True),
            ]
        # 4x4 -> 1x1: one raw score per image.
        # Importment Difference 1: unlike a vanilla GAN, there is NO final
        # sigmoid — the WGAN critic emits an unbounded score.
        layers.append(nn.Conv2d(wgan_opt.ndf * 8, 1, 4, 1, 0, bias=False))
        self.netd = nn.Sequential(*layers)

    def forward(self, imgs):
        """Return a 1-D tensor of critic scores, one per input image."""
        return self.netd(imgs).view(imgs.shape[0])

# Instantiate the critic and the generator defined above.
netd = discriminator()
netg = generator()



# step 4: ========================================== 初始化两个网络的参数
# 这一步是新学习的。参数权重初始化过程
def weight_init(m):
    """DCGAN-style weight initialisation, meant to be used via net.apply(weight_init).

    Conv / ConvTranspose weights are drawn from N(0, 0.02); (Batch)Norm scale
    from N(1.0, 0.02) with the affine bias reset to 0.  Careful initialisation
    is important for WGAN training stability.
    """
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        # Fix: the DCGAN recipe also zeroes the norm layer's bias; the
        # original left it at whatever value it previously held.
        if m.bias is not None:
            m.bias.data.fill_(0)

# Recursively apply the DCGAN-style initialisation to every submodule.
netd.apply(weight_init)
netg.apply(weight_init)





# step 5: ========================================== optimisers: RMSprop, not Adam
# (plain SGD is also a recommended choice for WGAN)
# Importment Difference 2: Use RMSprop instead of Adam
# optimizer
optimizerD = RMSprop(netd.parameters(), lr=wgan_opt.lr)
optimizerG = RMSprop(netg.parameters(), lr=wgan_opt.lr)

# Importment Difference: No Log in loss
# criterion (BCE would re-introduce the log; WGAN back-props raw scores instead)
# criterion = nn.BCELoss()




# step 6: ========================================== training loop
# A fixed noise batch, re-used at every checkpoint so the saved image grids
# are directly comparable across iterations.
rand_noise = Variable(t.FloatTensor(wgan_opt.batch_size, wgan_opt.nz, 1, 1).normal_(0, 1))

iter_count = 0
# Importment Difference 3: no BCE/log loss — the critic's raw scores are
# back-propagated directly against constant gradient vectors of +1 and -1.
one = t.ones(wgan_opt.batch_size)
mone = -1 * one

for epoch in range(wgan_opt.max_epoch):
    for ii, data in enumerate(dataloader, 0):
        imgs = data[0]  # a batch of real images
        noise = Variable(t.randn(imgs.size(0), wgan_opt.nz, 1, 1))  # latent input for fakes
        print(imgs.shape)

        # Importment Difference 4: clip every critic parameter to enforce the
        # Lipschitz constraint required by the Wasserstein formulation.
        for parm in netd.parameters():
            parm.data.clamp_(-wgan_opt.clamp_num, wgan_opt.clamp_num)

        # ----- train the critic (discriminator) -----
        netd.zero_grad()

        output = netd(imgs)   # scores for real images
        output.backward(one)  # constant gradient +1 per real score

        fake_pic = netg(noise).detach()  # detach: no gradient into the generator here
        output2 = netd(fake_pic)
        output2.backward(mone)  # constant gradient -1 per fake score

        optimizerD.step()

        # ----- train the generator -----
        # NOTE(review): `(ii + 1) % 1 == 0` is always true, so despite the
        # original comment about training the critic more often, the
        # generator is in fact updated on every iteration.
        if (ii + 1) % 1 == 0:
            netg.zero_grad()
            noise.data.normal_(0, 1)  # fresh noise for the generator step

            fake_pic = netg(noise)
            output = netd(fake_pic)

            output.backward(one)  # same +1 gradient convention as the real batch
            optimizerG.step()

        if iter_count % 50 == 0:
            # periodically save a grid generated from the fixed noise batch
            rand_imgs = netg(rand_noise)
            rand_imgs = deprocess_img(rand_imgs.data)
            save_image(rand_imgs, 'D:/software/Anaconda3/doc/3D_Img/wgan2/test_%d.png' % (iter_count))

        iter_count = iter_count + 1
        print('iter_count: ', iter_count)

效果如下:

最后都是用随机噪音产生的图片,时间太长了,训练次数不太够啊。

深度学习《WGAN模型》_2d_03


标签:opt,ngf,nn,模型,netd,深度,WGAN,data,wgan
From: https://blog.51cto.com/u_12419595/5937537

相关文章