WGAN is a major improvement over the original GAN. The main changes, each marked with an "Important Difference" comment in the code below, are:
1. Remove the sigmoid from the discriminator's last layer, turning it into a "critic" that outputs an unbounded score.
2. Use RMSprop (or SGD) as the optimizer instead of a momentum-based one such as Adam.
3. Drop the log from the loss: the critic's raw scores are used directly.
4. Clip the critic's weights into a small fixed range after each update (here [-0.01, 0.01]).
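In the notation of the WGAN paper, points 3 and 4 amount to optimising a direct difference of critic scores under a Lipschitz constraint on D (which the weight clipping crudely enforces):

$$\min_G \max_{\|D\|_L \le 1} \; \mathbb{E}_{x \sim p_{\mathrm{data}}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]$$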
The example test code is below.
Once again I'll use my 16 Naruto photos to play with this; every improvement point above is marked with an "Important Difference" comment in the code.
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10
import os
import torchvision.datasets as datasets
from torchvision.utils import save_image
# For the differences between WGAN and the original GAN, search this file for "Important Difference"
# step 1: ========================================== define the parameters this program needs
class WGAN_Config:
lr = 0.0001
nz = 100 # noise dimension
image_size = 64
nc = 3 # number of image channels
ngf = 64 # generator channel
ndf = 64 # discriminator channel
batch_size = 16
max_epoch = 5000 # set to 1 when debugging
clamp_num = 0.01 # WGAN weight-clipping threshold
wgan_opt = WGAN_Config()
def deprocess_img(img):
out = 0.5 * (img + 1)
out = out.clamp(0, 1)
out = out.view(-1, 3, wgan_opt.image_size, wgan_opt.image_size)
return out
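# Quick sanity check (my own addition): deprocess_img should invert the
# Normalize([0.5] * 3, [0.5] * 3) transform used below, mapping the generator's
# tanh output from [-1, 1] back to the displayable range [0, 1].
assert deprocess_img(t.full((1, 3, 64, 64), -1.0)).min().item() == 0.0
assert deprocess_img(t.full((1, 3, 64, 64), 1.0)).max().item() == 1.0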
# step 2: ========================================== the usual routine: load the dataset.
# data preprocess
transform = transforms.Compose([
transforms.Resize(wgan_opt.image_size),
# Resize with an int keeps the aspect ratio; add transforms.CenterCrop(wgan_opt.image_size) after it if your images are not square
transforms.ToTensor(),
transforms.Normalize([0.5] * 3, [0.5] * 3)
])
# dataset = CIFAR10(root='cifar10/', transform=transform, download=True)
# dataloader = t.utils.data.DataLoader(dataset, wgan_opt.batch_size, shuffle=True)
data_path = os.path.abspath("D:/software/Anaconda3/doc/3D_Naruto")
print(os.listdir(data_path))
# Note: put all the images in a subdirectory under data_path. ImageFolder reads
# data from subdirectories; otherwise the next step will raise an error.
dataset = datasets.ImageFolder(root=data_path, transform=transform)
dataloader = t.utils.data.DataLoader(dataset, batch_size=wgan_opt.batch_size, shuffle=True)
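# The layout ImageFolder expects looks like this (each subfolder is treated as
# one class; the subfolder's name doesn't matter for this script):
#   3D_Naruto/
#       naruto/
#           001.jpg
#           002.jpg
#           ...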
# step 3: ========================================== define the models for WGAN's G and D networks
class generator(nn.Module):
def __init__(self):
super(generator, self).__init__()
# DCGAN-style generator: five transposed convolutions upsample an nz x 1 x 1 noise vector to a 3 x 64 x 64 image
self.netg = nn.Sequential(
nn.ConvTranspose2d(wgan_opt.nz, wgan_opt.ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(wgan_opt.ngf * 8),
nn.ReLU(True),
nn.ConvTranspose2d(wgan_opt.ngf * 8, wgan_opt.ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(wgan_opt.ngf * 4),
nn.ReLU(True),
nn.ConvTranspose2d(wgan_opt.ngf * 4, wgan_opt.ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(wgan_opt.ngf * 2),
nn.ReLU(True),
nn.ConvTranspose2d(wgan_opt.ngf * 2, wgan_opt.ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(wgan_opt.ngf),
nn.ReLU(True),
nn.ConvTranspose2d(wgan_opt.ngf, wgan_opt.nc, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, imgs):
out = self.netg(imgs)
return out
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
# DCGAN-style critic: five convolutions downsample a 3 x 64 x 64 image to a single score
self.netd = nn.Sequential(
nn.Conv2d(wgan_opt.nc, wgan_opt.ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(wgan_opt.ndf, wgan_opt.ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(wgan_opt.ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(wgan_opt.ndf * 2, wgan_opt.ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(wgan_opt.ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(wgan_opt.ndf * 4, wgan_opt.ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(wgan_opt.ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(wgan_opt.ndf * 8, 1, 4, 1, 0, bias=False),
# Important Difference 1: do not use a sigmoid here any more; the critic outputs an unbounded score
# nn.Sigmoid()
)
def forward(self, imgs):
out = self.netd(imgs)
return out.view(imgs.shape[0])
netd = discriminator()
netg = generator()
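# Shape sanity check (my own addition): the first transposed conv maps the
# 1x1 noise to 4x4, and every following kernel-4/stride-2/padding-1 layer
# doubles (or, in the critic, halves) the spatial size: 4 -> 8 -> 16 -> 32 -> 64.
_z = t.randn(2, wgan_opt.nz, 1, 1)
assert netg(_z).shape == (2, wgan_opt.nc, wgan_opt.image_size, wgan_opt.image_size)
assert netd(netg(_z)).shape == (2,)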
# step 4: ========================================== initialize the parameters of both networks
# This step was new to me: the weight-initialization procedure.
def weight_init(m):
# weight_initialization: important for wgan
class_name = m.__class__.__name__
if class_name.find('Conv') != -1:
m.weight.data.normal_(0, 0.02)
elif class_name.find('Norm') != -1:
m.weight.data.normal_(1.0, 0.02)
netd.apply(weight_init)
netg.apply(weight_init)
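# Note (my own addition, following common DCGAN practice rather than the
# original post): this init leaves the norm layers' bias untouched; DCGAN-style
# init usually also zeroes it:
#   elif class_name.find('Norm') != -1:
#       m.weight.data.normal_(1.0, 0.02)
#       m.bias.data.fill_(0)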
# step 5: ========================================== define the optimizers; RMSprop is used here, not Adam
# SGD is also a recommended choice.
# Important Difference 2: use RMSprop instead of Adam
# optimizer
optimizerD = RMSprop(netd.parameters(), lr=wgan_opt.lr)
optimizerG = RMSprop(netg.parameters(), lr=wgan_opt.lr)
# Important Difference 3: no log in the loss
# criterion
# criterion = nn.BCELoss()  # not needed any more
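# For reference: with this script's sign convention, the backward(one) /
# backward(mone) calls in the training loop below minimise (up to a constant
# batch-size factor) the explicit losses
#   loss_D = netd(real_imgs).mean() - netd(fake_imgs).mean()
#   loss_G = netd(fake_imgs).mean()
# i.e. plain differences of critic scores, with no log anywhere.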
# step 6: ========================================== start training
# begin training
rand_noise = Variable(t.FloatTensor(wgan_opt.batch_size, wgan_opt.nz, 1, 1).normal_(0, 1))
iter_count = 0
# Replace the BCE loss with a log-free loss: following the paper, the critic's outputs are compared directly against 1 and -1.
one = t.ones(wgan_opt.batch_size)
mone = -1 * one
for epoch in range(wgan_opt.max_epoch):
for ii, data in enumerate(dataloader, 0):
imgs = data[0] # real images
noise = Variable(t.randn(imgs.size(0), wgan_opt.nz, 1, 1)) # noise for generating fake images
# Important Difference 4: clip the critic's weights
for param in netd.parameters():
param.data.clamp_(-wgan_opt.clamp_num, wgan_opt.clamp_num)
# ----- train discriminator (critic) network -----
netd.zero_grad()
output = netd(imgs) # train netd with real images
output.backward(one) # compare against the +1 target
fake_pic = netg(noise).detach() # train netd with fake images; detach stops gradients from flowing back into netg
output2 = netd(fake_pic)
output2.backward(mone) # compare against the -1 target
optimizerD.step()
# ------ train generator later -------
# Train the discriminator more often than the generator: the better netd is,
# the better netg can become. With `% 1` below, netg is updated every
# iteration; the WGAN paper suggests several critic steps per netg step (e.g. `% 5`).
if (ii + 1) % 1 == 0:
netg.zero_grad()
noise.data.normal_(0, 1)
fake_pic = netg(noise)
output = netd(fake_pic)
output.backward(one) # compare against the +1 target
optimizerG.step()
if iter_count % 50 == 0:
rand_imgs = netg(rand_noise)
rand_imgs = deprocess_img(rand_imgs.data)
save_image(rand_imgs, 'D:/software/Anaconda3/doc/3D_Img/wgan2/test_%d.png' % (iter_count))
iter_count = iter_count + 1
print('iter_count: ', iter_count)
The results are as follows:
The final images are all generated from random noise. Training takes quite a long time, and the number of iterations here really wasn't enough.
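Since a run like this takes so long, it is worth checkpointing the two networks so training can be resumed later. A minimal sketch (the file names here are my own, hypothetical choices):

# Save the weights at the end of (or periodically during) training:
t.save(netg.state_dict(), 'wgan_netg.pth')
t.save(netd.state_dict(), 'wgan_netd.pth')
# ...and restore them before resuming:
netg.load_state_dict(t.load('wgan_netg.pth'))
netd.load_state_dict(t.load('wgan_netd.pth'))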