首页 > 其他分享 >详解cycleGAN(生成对抗网络)代码

详解cycleGAN(生成对抗网络)代码

时间:2022-12-04 10:01:32浏览次数:68  
标签:real loss name cycleGAN self 详解 fake tf 对抗


文章目录

  • ​​1.cycleGAN简介​​
  • ​​2.cycle代码详解​​
  • ​​main.py​​
  • ​​module.py​​
  • ​​model.py​​
  • ​​总结​​

1.cycleGAN简介

关于cGAN我主要参考了知乎量子位的分享,理论解释的比较易懂,这里不再赘述。
​​​带你理解CycleGAN,并用TensorFlow轻松实现​​​​git代码​​​​作者fork后带注释的代码​

2.cycle代码详解

上文知乎大神已经将代码解释的差不多,但有一些细节没有解释到,这里学习后记录一下供有需要的参考哈。

main.py

此段主要执行模型的训练和测试,为了支持命令行操作,使用了argparse模块,命令行的使用方式如
​​​tensorboard --logdir=./logs​​ 然后是执行函数,可选train或者test

def main(_):
if not os.path.exists(args.checkpoint_dir):
os.makedirs(args.checkpoint_dir)
if not os.path.exists(args.sample_dir):
os.makedirs(args.sample_dir)
if not os.path.exists(args.test_dir):
os.makedirs(args.test_dir)

tfconfig = tf.ConfigProto(allow_soft_placement=True) # 允许GPU无法计算时放到CPU上
tfconfig.gpu_options.allow_growth = True # 按需分配显存
with tf.Session(config=tfconfig) as sess:
model = cyclegan(sess, args) # 定义模型
model.train(args) if args.phase == 'train' else model.test(args)

if __name__ == '__main__':
tf.app.run()

module.py

模块主要有识别器、两种生成器–generator_unet和generator_resnet
模型默认采用的resnet生成器。
识别器不复杂,这里不做解释,生成器用到了残差网络(residule_block),关于这个我理解的不深,建议参考下
​​Residual Net 详解​​ 关于两种生成器方案,作者代码默认采用的是残差网络的全0填充
代码分为三部分:编码器-转换器-解码器,其中解码器采用上述所说的残差网络

def generator_resnet(image, options, reuse=False, name="generator"):

with tf.variable_scope(name):
# image is 256 x 256 x input_c_dim
if reuse:
tf.get_variable_scope().reuse_variables()
else:
assert tf.get_variable_scope().reuse is False
def residule_block(x, dim, ks=3, s=1, name='res'):
p = int((ks - 1) / 2)
y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c1'), name+'_bn1')
# tf.pad(input, paddings, name=None) 以下表示第一个维度不加,第二个维度前加p个0后加p个0,相当于加了2p个维度
y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c2'), name+'_bn2')
return y + x
# 编码器
# [batch_size,262,262,dim]
c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
# [batch_size,256,256,dim]
c1 = tf.nn.relu(instance_norm(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
# [batch_size,128,128,dim*2]
# 注意slim框架的卷积维度计算不是四舍五入,而是直接去掉
c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim*2, 3, 2, name='g_e2_c'), 'g_e2_bn'))
# [batch_size,64,64,dim*4]
c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim*4, 3, 2, name='g_e3_c'), 'g_e3_bn'))
# define G network with 9 resnet blocks
# [batch_size,64,64,dim*4]
r1 = residule_block(c3, options.gf_dim*4, name='g_r1')
# [batch_size,64,64,dim]
r2 = residule_block(r1, options.gf_dim*4, name='g_r2')
r3 = residule_block(r2, options.gf_dim*4, name='g_r3')
r4 = residule_block(r3, options.gf_dim*4, name='g_r4')
r5 = residule_block(r4, options.gf_dim*4, name='g_r5')
r6 = residule_block(r5, options.gf_dim*4, name='g_r6')
r7 = residule_block(r6, options.gf_dim*4, name='g_r7')
r8 = residule_block(r7, options.gf_dim*4, name='g_r8')
# [batch_size, 64, 64, dim]
r9 = residule_block(r8, options.gf_dim*4, name='g_r9')
# 解码器
# 反卷积
d1 = deconv2d(r9, options.gf_dim*2, 3, 2, name='g_d1_dc')
d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn'))
d2 = deconv2d(d1, options.gf_dim, 3, 2, name='g_d2_dc')
d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn'))
d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
pred = tf.nn.tanh(conv2d(d2, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c'))

return pred

model.py

下面进入本文的正菜,模型构建和训练、测试、保存
首先初始化参数变量和模型,利用namedtuple的_make方法将生成器的参数变量添加到tuple中
模型函数:

def _build_model(self):
self.real_data = tf.placeholder(tf.float32,
[None, self.image_size, self.image_size,
self.input_c_dim + self.output_c_dim],
name='real_A_and_B_images')
# 数据预处理将A B耦合在一个数组
self.real_A = self.real_data[:, :, :, :self.input_c_dim]
self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
# 生成fb fba
self.fake_B = self.generator(self.real_A, self.options, False, name="generatorA2B")
self.fake_A_ = self.generator(self.fake_B, self.options, False, name="generatorB2A")
# 生成fa fab
self.fake_A = self.generator(self.real_B, self.options, True, name="generatorB2A")
self.fake_B_ = self.generator(self.fake_A, self.options, True, name="generatorA2B")
# 识别器识别fa fb
self.DB_fake = self.discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB")
self.DA_fake = self.discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA")
# g_loss_A = g_loss_A_1 + 10*cyc_loss
# cyc_loss = tf.reduce_mean(tf.abs(input_A-cyc_A)) + tf.reduce_mean(tf.abs(input_B-cyc_B))
# 生成器a的loss
self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
+ self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
+ self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
# 生成器b的loss
self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
+ self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
+ self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
# 生成器的总loss
self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
+ self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
+ self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
+ self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)

# fake ab的初始样本
self.fake_A_sample = tf.placeholder(tf.float32,
[None, self.image_size, self.image_size,
self.input_c_dim], name='fake_A_sample')
self.fake_B_sample = tf.placeholder(tf.float32,
[None, self.image_size, self.image_size,
self.output_c_dim], name='fake_B_sample')
# 真实样本a b 的识别结果
self.DB_real = self.discriminator(self.real_B, self.options, reuse=True, name="discriminatorB")
self.DA_real = self.discriminator(self.real_A, self.options, reuse=True, name="discriminatorA")
# 假样本a b 的识别结果
self.DB_fake_sample = self.discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB")
self.DA_fake_sample = self.discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA")

# d b 的loss = loss_real + loss_fake
self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2

self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
self.d_loss = self.da_loss + self.db_loss

# 将生成器的loss添加到log日志
self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
# 将summary保存到磁盘 可以使用tf.summary.merge_all

self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])

self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
self.d_sum = tf.summary.merge(
[self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
self.d_loss_sum]
)

self.test_A = tf.placeholder(tf.float32,
[None, self.image_size, self.image_size,
self.input_c_dim], name='test_A')
self.test_B = tf.placeholder(tf.float32,
[None, self.image_size, self.image_size,
self.output_c_dim], name='test_B')
self.testB = self.generator(self.test_A, self.options, True, name="generatorA2B")
self.testA = self.generator(self.test_B, self.options, True, name="generatorB2A")
# 提取生成器和识别器要训练的变量,并打印全部变量
t_vars = tf.trainable_variables()
self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
self.g_vars = [var for var in t_vars if 'generator' in var.name]
for var in t_vars: print(var.name)

训练函数

def train(self, args):
"""Train cyclegan"""
self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
# 定义生成器和识别器的梯度更新函数
self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
.minimize(self.d_loss, var_list=self.d_vars)
self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
.minimize(self.g_loss, var_list=self.g_vars)

init_op = tf.global_variables_initializer()
self.sess.run(init_op)
# 指定一个文件用来保存图
self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

counter = 1
start_time = time.time()

# 如果继续训练,加载最新的模型
if args.continue_train:
if self.load(args.checkpoint_dir):
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")

for epoch in range(args.epoch):
# 获取通过正则得到的文件列表
dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainA'))
dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainB'))
# 随机打乱数据
np.random.shuffle(dataA)
np.random.shuffle(dataB)
# batch的数量
batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
lr = args.lr if epoch < args.epoch_step else args.lr*(args.epoch-epoch)/(args.epoch-args.epoch_step)

for idx in range(0, batch_idxs):
# 数据预处理
# 抽取每一个batch的AB的数据组成一个对应的tuple的列表
batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
# 加载数据
batch_images = [load_train_data(batch_file, args.load_size, args.fine_size) for batch_file in batch_files]
batch_images = np.array(batch_images).astype(np.float32)

# 更新生成器并记录生成的假样本数据
fake_A, fake_B, _, summary_str = self.sess.run(
[self.fake_A, self.fake_B, self.g_optim, self.g_sum],
feed_dict={self.real_data: batch_images, self.lr: lr})
self.writer.add_summary(summary_str, counter)
[fake_A, fake_B] = self.pool([fake_A, fake_B])

# 更新识别器
_, summary_str = self.sess.run(
[self.d_optim, self.d_sum],
feed_dict={self.real_data: batch_images,
self.fake_A_sample: fake_A,
self.fake_B_sample: fake_B,
self.lr: lr})
self.writer.add_summary(summary_str, counter)

counter += 1
print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
epoch, idx, batch_idxs, time.time() - start_time)))
# 100次保存一个生成器的输出
if np.mod(counter, args.print_freq) == 1:
self.sample_model(args.sample_dir, epoch, idx)
# 1000次 保存一次模型
if np.mod(counter, args.save_freq) == 2:
self.save(args.checkpoint_dir, counter)

注意1代码需要在命令行中运行,因为训练所使用的变量都位于argparse模块定义的ArgumentParser类中,如果直接运行系统会报错
​​​SystemExit: 2 error when calling parse_args()​​​ 可以使用以下命令在ipython运行
​!CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=horse2zebra​注意2如果有小伙伴跑该程序终端,再次continue_train,会发现模型保存不太正常,建议接着上次训练的话把counter参数改成上次最新保存的模型的数字,这样保存的模型是连续的~

总结

以上是本次的主要内容,后续还会再更新一下,补充下细节,有问题欢迎提问,一起讨论下哈~


标签:real,loss,name,cycleGAN,self,详解,fake,tf,对抗
From: https://blog.51cto.com/u_15899958/5909919

相关文章

  • 详解支持向量机-SVC真实数据案例:预测明天是否会下雨-处理困难特征:地点【菜菜的sklearn
    视频作者:菜菜TsaiTsai链接:【技术干货】菜菜的机器学习sklearn【全85集】Python进阶_哔哩哔哩_bilibili常识上来说,我们认为地点肯定是对明天是否会下雨存在影响的。比如......
  • Day30:ArrayList详解
    ArrayList1.1集合概述当我们要存储多个数据时,固定长度的数组存储格式已经满足不了我们的需要了,且不能满足变化的需求;Java中集合类则可以解决我们的需求特点:提供一种存......
  • Spring MVC请求地址映射详解:HandlerMapping
    1HandlerMapping介绍HandlerMapping是SpringMVC的核心组件之一,用来保存request-handler之间的映射。简单来说,request指的是请求地址(还包括请求方法等),handler指的是Cont......
  • Nginx map 使用详解
    map指令介绍:map指令是由ngx_http_map_module模块提供的,默认情况下安装nginx都会安装该模块。map的主要作用是创建自定义变量,通过使用nginx的内置变量,去匹配某......
  • 【SpringBoot】对于yaml的详细学习和三种属性赋值的实战详解
    一.yaml详细讲解1.1什么是yaml?YAML是一种数据序列化语言,通常用于编写配置文件。业界对YAML有不同的看法。有些人会说YAML代表另一种标记语言。其他人认为“YAML不是标记......
  • k8s篇-k8s集群架构及组件详解【史上最详细】
    Okubernetes简介k8s是什么k8s是一个可移植的、可扩展的开源平台,用于管理容器化的工作负载和服务,可以促进声明式配置和自动化。k8s能做什么1)服务发现和负载......
  • 对抗搜索
    对抗搜索(博弈搜索)主要内容:最小最大搜索Alpha-Beta剪枝搜索蒙特卡洛树搜索最小最大搜索:max就是利益最大化   复杂度:O(b^m)m是树的最大深度,在每个节点存......
  • Day29:StringBuilder详解
    StringBuilder1.1StringBuilder概述我们先对普通的String字符串对象建立进行内存分析;publicclassDemo{publicstaticvoidmain(String[]args){Strin......
  • asp教程:ASP开发中存储过程应用详解
    ASP开发中存储过程应用详解|调用,参数,存储,数据库,输出,编译,mycomm,输入,userid,代码ASP与存储过程(StoredProcedures)的文章不少,但是我怀疑作者们是否真正实践过。......
  • Python+NumPy绘制常见曲线的方法详解_python
    一、利萨茹曲线二、计算斐波那契数列 三、方波方波可以近似表示为多个正弦波的叠加。任意一个方波信号都可以用无穷傅里叶级数来表示。需要累加很多项级数,且级数越......