paper:https://arxiv.org/pdf/1703.10593.pdf [2017]
code 参考:
1 整体架构
整体架构主要由两个 生成器(G 和 F)、两个判别器(Dy 和 Dx)组成。
这里借用了语言翻译领域中循环一致性的思想,即将一个句子从中文翻译到英文,然后再将其翻译回中文,应该得到与原始的中文相同的句子。
所以这篇 paper 的做法是将 x 经由 G 生成得到 \(\hat{y}\),再经由 F 生成得到 \(\hat{x}\),这样看的话,这里其实就有了三个数据域:原始图像 x 所属的数据域 X,G 生成的中间图像 \(\hat{y}\) 所属的数据域 \(\hat{Y}\),以及 F 生成的重建图像 \(\hat{x}\) 的数据域 \(\hat{X}\)。
其中域 X 和域\(\hat{X}\)在训练过程中是期望它们的分布尽量接近的(所以其实这里也可以理解为总共有 2 个域),而我们所需要的风格转换后的图像其实就是中间图像 \(\hat{y}\)。
整个过程如上图 (a) 所示,两个生成器 G 和 F 循环生成,然后两个判别器 Dx 和 Dy 在各自的域上进行判断。
2 核心创新
提出了循环一致性损失,训练不再需要成对的图像数据,成为后续 GANs 相关论文的重要参考。
3 损失函数
1)传统 GAN 的生成&判别损失
由于有两个生成器(Generator)G 和 F、两个判别器(Discriminator)Dx 和 Dy,所以生成&判别损失其实是有两个部分:
以第一个公式为例,鉴别器 Dy 越大,表明预测结果越接近真实图片(也即生成的结果越接近真实)。
其中鉴别器的损失计算采用 L1 loss。
2)G 和 D 的优化目标
对于 Generator 来说,其目的就是为了使得生成的图像越接近真实图像越好,所以其优化目标是使得 Discriminator 的判别概率越大越好,也即 \(max_{D_Y}\)。
而对于 Discriminator 来说,其目的是为了尽量鉴别出由 Generator 生成的非真实图片(对于生成图片,给出低概率),所以其目标是使得对 Generator 生成的图片赋低分,也即 \(min_{G}\)。
由于有两个 Generator 和 两个 Discriminator,所以优化目标也有两组:
3)循环一致性损失
循环一致性损失主要作用就是控制在使用非对称样本时,生成结果别跑偏了,所以需要控制 F 的重建结果和原始图片的一致性,其定义如下:
其实就是两个 Generator 组成的两阶段生成的过程中,第二阶段的重建结果与第一阶段输出图片的一致性之和。
循环一致性损失采用的是 MES loss。
4)整体损失函数
就是两组 生成&对抗 损失加上一个 带权重的 循环一致性损失。
4 代码解读&实现
4.1 前置知识(可选)
在直接阅读代码之前,为了保证阅读代码的流畅性,有必要将一些可能引起疑惑的操作函数进行说明,主要包括:
- nn.ConvTranspose2d()
- nn.InstanceNorm2d()
- nn.detach()
- albumentations 图像增强库
- ReflectionPad2d
这些函数如果已经知道的可以直接跳过。
1)nn.ConvTranspose2d()
也叫转置卷积、反卷积,和卷积对应,其目的是将低 size 的 feature map 转为 高 size 的 feature map,是图像重建过程中恢复图像原来尺寸常用的操作。
尺寸变换公式:
其实大多数时候,我们在执行卷积/反卷积的时候,只是期望能将 feature size 缩放/放大 为原来的一倍,所以这里可以简化的去记变换规则,即:
当我们希望得到 输入特征图大小/输出特征图大小 = stride 的话,需要 padding = (kernel_size - stride + output_padding )/2,进一步的则 output_padding 应该取值为 stride - 1。
常用的一组参数为:kernel_size=3, stride=2, padding=1, output_padding=1,这样正好使得 feature map 被反卷积上采样为 2 倍原尺寸大小。
更多的解释可以参考:
2)nn.InstanceNorm2d()
又叫实例归一化,其是对每个样本沿着通道方向独立对各个通道进行计算,而批量归一化则是对所有样本沿着batch 的方向对各个通道分别进行计算。
举个例子:当输入特征图形状为 (2,3,256,512),表示有两个 256×512 的特征图,特征图通道数为 3,假设为 RGB 三个通道。
那么实例归一化会依次对样本 1,样本 2 分别计算 R、G、B 三个通道的均值、方差,每次计算其实是对 256×512 个元素值进行计算。
而批量归一化则是对整个批次的样本,对各个通道分别求出均值和方差,每次计算其实是对 2×256×512 个元素值进行计算。
论文图示:
至于为啥风格转换任务中要使用 IN,摘录知乎回答如下:
BN 适用于判别模型中,比如图片分类模型。因为 BN 注重对每个batch进行归一化,从而保证数据分布的一致性,而判别模型的结果正是取决于数据整体分布。但是BN对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布;
IN适用于生成模型中,比如图片风格迁移。因为图片生成的结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,在风格迁移中使用 Instance Normalization 不仅可以加速模型收敛,并且可以保持每个图像实例之间的独立。
参考:
3)nn.detach()
举个例子来说明一下detach有什么用。 如果 A 网络的输出被喂给 B 网络作为输入,如果我们希望在梯度反传(loss.backward()
)的时候只更新 B 中参数的值,而不更新 A 中的参数值,这时候就可以使用 detach(),代码示例:
...
fake_a = gen_A(domain_b_img)
D_A_fake_prob = disc_A(fake_a.detach())
loss = mse(D_A_fake_prob, label_a_prob)
...
loss.backward()
...
这样在进行 backward 时,disc_A 网络会更新参数,但 gen_A 网络不会。
4)albumentations
albumentations 是一个第三方的图像增强操作库,其主要特点就是快,也封装了很多常规的图像增强方式(例如翻转、随机裁剪等)。
其使用方式也很简单:
import albumentations as A
from albumentations.pytorch import ToTensorV2
transforms = A.Compose(
[
A.Resize(width=256, height=256),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
ToTensorV2(),
],
additional_targets={'image0': 'image'}
)
其中需要解释的是 additional_targets 参数,其实就是将多个成对的 image 对象绑定到一起,例如将两个附加图片都绑定到原始图片上(使他们成对):{'image0': 'image', 'image1': 'image'}
,这样 image0 和 image1 就会执行和 image 相同的 transform 操作。
详见:
5)ReflectionPad2d()
这个填充函数不同全零填充(padding),而是采用输入边界的反射来填充输入张量,说人话就是用图像矩阵中其他位置的像素值来填充(扩充)边界,从而增大图像尺寸。
- 填充一层时:m = nn.ReflectionPad2d(1)
填充顺序是:左、右、上、下
- 填充多层时
示例:nn.ReflectionPad2d((1, 1, 2, 0))
中,这几个数字表示左右上下分别要填充的层数
之所以使用反射填充,一个主要的原因是如果我们直接使用全零填充,会导致图像产生黑边,影响视觉模型的训练效果,因为黑边其实是个很明显的结构特征。
在视觉代码中通常可以通过如下的写法替代原始 Conv 中默认的全零填充方式:
nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
6)clamp()
常用 out.clamp(0, 1) 将 out 中各个数的取值范围压缩到 0-1 之间。
7)Automatic mixed precision
Automatic mixed precision(amp),自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。
自动预示着 Tensor 的 dtype 类型会自动变化,也就是框架按需自动调整 tensor 的 dtype。
更多详见:https://blog.csdn.net/Z2572862506/article/details/128800233
4.2 Code
本来想直接从官方代码入手的,但是官方的封装的有点复杂,对新手不太友好,无法直观的关注 cGAN 核心的逻辑,所以参考了多份代码,下面的内容其实就是参照着其中我感觉比较好的一份实现来写得,原代码作者视频:
https://www.bilibili.com/video/BV1kb4y197PE/?spm_id_from=333.337.search-card.all.click&vd_source=bda72e785d42f592b8a2dc6c2aad2409
4.3 Generator module
Generator 由多层卷积与残差模块堆叠而成,顺序依次为:初始转置卷积 + 2层下采样 + 9 个残差模块 + 2层上采样 + 最后一层卷积。
图像经过 Generator 后,输出与原图像保持相同尺寸。
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
"""
:param in_channels:
:param out_channels:
:param down: 是否下采样
:param use_act: 是否使用激活函数
:param kwargs:
"""
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels),
# inplace = False 时,不会修改输入对象的值,而是返回一个新创建的对象,所以打印出对象存储地址不同
# inplace = True 时,会修改输入对象的值,所以打印出对象存储地址相同
# inplace = True ,会改变输入数据的值,节省反复申请与释放内存的空间与时间,只是将原来的地址传递,效率更好
nn.ReLU(inplace=True) if use_act else nn.Identity() # nn.Identity() 这里其实就是个占位,当不使用激活函数时,表明什么都不做
)
def forward(self, x):
return self.conv(x)
class ResidualBlock(nn.Module):
def __init__(self, channels):
"""
这里的 channel、ks=3,pad=1,保证了输入数据和输出数据的维度不会发生改变,只是单纯的做 residual
:param channels:
"""
super().__init__()
self.block = nn.Sequential(
ConvBlock(channels, channels, kernel_size=3, padding=1),
ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
)
def forward(self, x):
return x + self.block(x)
class Generator(nn.Module):
def __init__(self, img_channels, num_features=64, num_residuals=9):
"""
Generator 经过 下采样、残差连接、上采样,其输出尺寸和输入尺寸是一致的(当然也只有这样,才能使用 l1 计算循环一致性损失)
:param img_channels: 输入图像通道数,默认为 3
:param num_features: Generator 编码时的图像尺寸基数,后面会基于该基数转换尺寸
:param num_residuals: 堆叠的残差模块个数
"""
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
nn.ReLU(inplace=True)
)
# 两层卷积下采样
self.down_blocks = nn.ModuleList(
[
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
]
)
# num_residuals 层堆叠的残差模块
self.residual_blocks = nn.Sequential(
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
)
self.up_blocks = nn.ModuleList(
[
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
]
)
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect')
def forward(self, x):
x = self.initial(x)
for layer in self.down_blocks:
x = layer(x)
x = self.residual_blocks(x)
for layer in self.up_blocks:
x = layer(x)
return torch.tanh(self.last(x))
def test():
img_channels = 3
img_size = 256
x = torch.randn((2, img_channels, img_size, img_size))
gen = Generator(img_channels, 9)
print(gen)
print('-' * 90)
y = gen(x)
print(y.shape) # torch.Size([2, 3, 256, 256])
if __name__ == '__main__':
test()
4.3 Discriminator module
Discriminator 由多层卷积组成,原论文中是直接将图片输入转为标量的概率输出,但是这里实现并没有转为标量,而是直接作为 vector 输出,之后在计算 loss 时采用 ones_like 与真实值对应上(感觉也行吧)。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Block(nn.Module):
def __init__(self, in_channles, out_channels, stride):
super().__init__()
self.conv = nn.Sequential(
# 转置卷积
nn.Conv2d(in_channles, out_channels, kernel_size=4, stride=stride, padding=1, padding_mode='reflect'),
# 实例归一化
nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(0.2),
)
def forward(self, x):
return self.conv(x)
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
nn.LeakyReLU(0.2)
)
layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
)
in_channels = feature
layers.append(nn.Conv2d(in_channels, out_channels=1, kernel_size=4, padding=1, padding_mode='reflect'))
self.model = nn.Sequential(*layers)
def forward(self, x):
x = self.initial(x)
return torch.sigmoid(self.model(x))
def test():
x = torch.randn((1, 3, 256, 256))
model = Discriminator(in_channels=3)
preds = model(x)
print(model)
print('-' * 90)
print(preds.shape) # torch.Size([1, 1, 30, 30])
if __name__ == '__main__':
test()
4.4 Dataset Loader module
在介绍数据集加载逻辑之前,先把 config.py 和 utils.py 中的一些功能函数和参数说明贴一下:
utils.py
:
import numpy as np
import os
import random
import torch
def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
print('=> Saving checkpoint')
checkpoint = {
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print('=> Loading checkpoint')
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
# if wen don't do this then it will just have learning rate of old checkpoint and it will lead to many hours of debugging
for param_group in optimizer.param_group:
param_group['lr'] = lr
def seed_everything(seed=42):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
config.py
:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TRAIN_DIR = 'data/train'
VAL_DIR = 'data/val'
BATCH_SIZE = 1
LEARNINGG_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
NUM_EPOCHS = 200
LOAD_MODEL = False # 初始训练时指定为 False
SAVE_MODEL = True
CHECKPOINT_GEN_A = 'checkpoints/gen_A.pth.tar'
CHECKPOINT_GEN_B = 'checkpoints/gen_B.pth.tar'
CHECKPOINT_DISCRIMINATOR_A = 'checkpoints/d_A.pth.tar'
CHECKPOINT_DISCRIMINATOR_B = 'checkpoints/d_B.pth.tar'
transforms = A.Compose(
[
A.Resize(width=256, height=256),
# A.HorizontalFlip(p=0.5),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
ToTensorV2(),
],
additional_targets={'image0': 'image'}
)
数据加载逻辑 dataset.py
:
import torch
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np
class ABDataset(Dataset):
def __init__(self, root_path_domain_a, root_path_domain_b, transform=None):
self.root_path_domain_a = root_path_domain_a
self.root_path_domain_b = root_path_domain_b
self.transform = transform
self.domain_a_images = os.listdir(root_path_domain_a)
self.domain_b_images = os.listdir(root_path_domain_b)
self.length_dataset = max(len(self.domain_a_images), len(self.domain_b_images)) # 1000, 1500
self.domain_a_len = len(self.domain_a_images)
self.domain_b_len = len(self.domain_b_images)
def __len__(self):
return self.length_dataset
def __getitem__(self, index):
"""
示例:
domain_a 对应 horse
domain_b 对应 zebra
:param index:
:return:
"""
domain_a_img = self.domain_a_images[index % self.domain_a_len]
domain_b_img = self.domain_b_images[index % self.domain_b_len]
domain_a_path = os.path.join(self.root_path_domain_a, domain_a_img)
domain_b_path = os.path.join(self.root_path_domain_b, domain_b_img)
domain_a_img = np.array(Image.open(domain_a_path).convert('RGB'))
domain_b_img = np.array(Image.open(domain_b_path).convert('RGB'))
if self.transform:
augmentations = self.transform(image=domain_a_img, image0=domain_b_img)
domain_a_img = augmentations['image']
domain_b_img = augmentations['image0']
return domain_a_img, domain_b_img
4.5 Trian
前置模块写完了,下面定义训练逻辑。在写之前,先梳理一下流程。
- 获取两个域的图片 domain_a_img、domain_b_img
- 优化两个 Discriminator,目标是 \(min_{fake\_img}\)、\(max_{real\_img}\)
a. 利用 gen_A、gen_B 生成 fake_a、fake_b
b. 过 disc_A、disc_B 获得判别概率
c. 计算 disc_A、disc_B 的 loss - 优化两个 Generator,目标是 \(max_{fake\_img}\)、\(max_{rec\_img}\)
a. 利用 gen_A、gen_B 生成 rec_b、rec_a
b. 过 disc_A、disc_B 获得判别概率
c. 计算 gen_A、gen_B 的 loss
代码逻辑如下 train.py
:
import torch
from dataset import ABDataset
import sys
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
import config
def train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
loop = tqdm(loader, leave=True)
for idx, (domain_a_img, domain_b_img) in enumerate(loop):
domain_a_img = domain_a_img.to(config.DEVICE)
domain_b_img = domain_b_img.to(config.DEVICE)
# Train Discriminator A and B
with torch.cuda.amp.autocast():
fake_a = gen_A(domain_b_img)
D_A_real_prob = disc_A(domain_a_img)
D_A_fake_prob = disc_A(fake_a.detach()) # 注意这里使用 detach(),使得更新 disc_A 的时候不更新 gen_A
D_A_real_loss = mse(D_A_real_prob, torch.ones_like(D_A_real_prob))
D_A_fake_loss = mse(D_A_fake_prob, torch.zeros_like(D_A_fake_prob))
D_A_loss = D_A_real_loss + D_A_fake_loss
fake_b = gen_B(domain_a_img)
D_B_real_prob = disc_B(domain_b_img)
D_B_fake_prob = disc_B(fake_b.detach())
D_B_real_loss = mse(D_B_real_prob, torch.ones_like(D_B_real_prob))
D_B_fake_loss = mse(D_B_fake_prob, torch.zeros_like(D_B_fake_prob))
D_B_loss = D_B_real_loss + D_B_fake_loss
# put it togethor
D_loss = (D_A_loss + D_B_loss) / 2
# 注意这里使用了 amp 的话,与往常通用的写法有一点不一样了
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
d_scaler.step(opt_disc)
d_scaler.update()
# Train Generator A and B
with torch.cuda.amp.autocast():
# adversarial loss for both generators
# 下面两句在 Discriminator 中已经执行过了,这里有必要再执行一次吗?
D_A_fake_prob = disc_A(fake_a)
D_B_fake_prob = disc_B(fake_b)
# 对于 Gs,其目标是使得 fake img 在 disc 看来,概率越接近真实越好
loss_G_A = mse(D_A_fake_prob, torch.ones_like(D_A_fake_prob))
loss_G_B = mse(D_B_fake_prob, torch.ones_like(D_B_fake_prob))
# cycle consistency loss
cycle_b = gen_B(fake_a)
cycle_a = gen_A(fake_b)
cycle_b_loss = l1(domain_b_img, cycle_b)
cycle_a_loss = l1(domain_a_img, cycle_a)
# identity loss
# 这个原论文中并有提到这个损失,config 配置文件中配置的权重为 0,所以实际上也并没有使用
identity_b = gen_B(domain_b_img)
identity_a = gen_A(domain_a_img)
identity_b_loss = l1(domain_b_img, identity_b)
identity_a_loss = l1(domain_a_img, identity_a)
# add all togethor
G_loss = (
loss_G_A
+ loss_G_B
+ cycle_a_loss * config.LAMBDA_CYCLE
+ cycle_b_loss * config.LAMBDA_CYCLE
+ identity_a_loss * config.LAMBDA_IDENTITY
+ identity_b_loss * config.LAMBDA_IDENTITY
)
opt_gen.zero_grad()
g_scaler.scale(G_loss).backward()
g_scaler.step(opt_gen)
g_scaler.update()
if idx % 100 == 0:
# 在读入图片时进行过 norm,所以这里保存时需要 denorm
save_image(domain_a_img * 0.5 + 0.5, f"saved_images/a_{idx}.png")
save_image(domain_b_img * 0.5 + 0.5, f"saved_images/b_{idx}.png")
save_image(fake_a * 0.5 + 0.5, f"saved_images/fake_a_{idx}.png") # fake_a 由 gen_B 生成
save_image(fake_b * 0.5 + 0.5, f"saved_images/fake_b_{idx}.png") # fake_b 由 gen_A 生成
def main():
disc_A = Discriminator(in_channels=3).to(config.DEVICE)
disc_B = Discriminator(in_channels=3).to(config.DEVICE)
gen_A = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
gen_B = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
# 将两个 Discriminator 放在一起优化
opt_disc = optim.Adam(
list(disc_A.parameters()) + list(disc_B.parameters()),
lr=config.LEARNINGG_RATE,
betas=(0.5, 0.999)
)
# 将两个 Generator 放在一起优化
opt_gen = optim.Adam(
list(gen_A.parameters()) + list(gen_B.parameters()),
lr=config.LEARNINGG_RATE,
betas=(0.5, 0.999),
)
l1 = nn.L1Loss()
mse = nn.MSELoss()
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN_A, gen_A, opt_gen, config.LEARNINGG_RATE,
)
load_checkpoint(
config.CHECKPOINT_GEN_B, gen_B, opt_gen, config.LEARNINGG_RATE,
)
load_checkpoint(
config.CHECKPOINT_DISCRIMINATOR_A, disc_A, opt_disc, config.LEARNINGG_RATE,
)
load_checkpoint(
config.CHECKPOINT_DISCRIMINATOR_B, disc_B, opt_disc, config.LEARNINGG_RATE,
)
dataset = ABDataset(
root_path_domain_a=config.TRAIN_DIR+'/horses', root_path_domain_b=config.TRAIN_DIR+'/zebras', transform=config.transforms
)
loader = DataLoader(
dataset,
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True
)
# amp
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(config.NUM_EPOCHS):
train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler)
if config.SAVE_MODEL:
save_checkpoint(gen_A, opt_gen, filename=config.CHECKPOINT_DISCRIMINATOR_A)
save_checkpoint(gen_B, opt_gen, filename=config.CHECKPOINT_DISCRIMINATOR_B)
save_checkpoint(disc_A, opt_disc, filename=config.CHECKPOINT_DISCRIMINATOR_A)
save_checkpoint(disc_B, opt_disc, filename=config.CHECKPOINT_DISCRIMINATOR_B)
if __name__ == '__main__':
main()
4.6 复现结果
完整的训练是 200 个 epoch,时间太长了,这里展示下 10 个 epoch 时的结果,尽管效果还不好,但是网络确实学习着去将 horse 和 zebra 互相转换。
标签:domain,nn,img,self,Adversarial,gen,fake,Image,Cycle From: https://www.cnblogs.com/zishu/p/17353899.html