首页 > 其他分享 >CycleGAN训练教程

CycleGAN训练教程

时间:2024-03-26 20:32:04浏览次数:15  
标签:教程 help 训练 CycleGAN parser argument -- add type

目录

项目下载链接 

CycleGAN概述

CycleGAN原理

CycleGAN的应用场景

 训练过程

       代码内容

        环境配置

        预训练权重下载

        下载训练数据

        参数设置

        训练操作

        训练界面

        训练结果

测试操作

        测试命令

        测试结果

相关链接


项目下载链接 

       基于 PyTorch 的图像生成模型,包含了 CycleGAN 和 pix2pix 两种模型,适合用于实现图像生成和风格迁移等任务。 

论文:https://arxiv.org/pdf/1703.10593.pdf 

代码:junyanz/pytorch-CycleGAN-and-pix2pix首页 - GitCodeicon-default.png?t=N7T8https://gitcode.com/junyanz/pytorch-CycleGAN-and-pix2pix/overview           junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch (github.com)icon-default.png?t=N7T8https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

CycleGAN概述

        在输入图像和输出图像之间,使用对齐图像对的训练集。但是,对于许多任务,配对训练数据将不可用。我们提出一个学习从源翻译图像的方法域 X 到目标域 Y,在没有配对的情况下例子。我们的目标是学习映射 G : X → Y使得来自 G(X) 的图像分布与使用对抗性损失的分布 Y 无法区分。由于此映射的约束非常不足,因此我们将其与逆映射 F : Y → X ,并引入强制执行 F(G(X)) ≈ X 的循环一致性损失(反之亦然)。在几个任务上呈现了定性结果不存在配对训练数据的地方,包括集合风格转移、对象变形、季节转移、照片增强等定量比较。

CycleGAN原理

(a) 我们的模型包含两个映射函数 G : X → Y 和 F : Y → X,以及相关的对抗函数
鉴别器 DY 和 DX。DY 鼓励 G 将 X 转换为与域 Y 无法区分的输出,反之亦然
用于 DX 和 F。为了进一步规范映射,我们引入了两个循环一致性损失函数,即如果
我们从一个领域转换到另一个领域,然后再转换回来:(b)前向损失函数:x → G(x) → F(G(x)) ≈ x,以及 (c) 后向损失函数:y → F(y) → G(F(y)) ≈ y

CycleGAN的应用场景

 

 训练过程

        (以马<——>斑马的训练过程为例)

       代码内容

       

 

        环境配置

#操作命令:
pip install -r requirements.txt

         requirements.txt内容如下:

torch>=1.4.0
torchvision>=0.5.0
dominate>=2.4.0
visdom>=0.1.8.8
wandb

        torch和torchvision安装建议直接官网配置:PyTorch

        预训练权重下载

        进入到根目录pytorch-CycleGAN-and-pix2pix-master/下面:

        运行命令:

bash ./scripts/download_cyclegan_model.sh horse2zebra

        下载到.//checkpoints/horse2zebra_pretrained目录下面

        

        也可以网页直接下载:Index of /cyclegan/pretrained_models (berkeley.edu)

        下载训练数据

        运行命令:

bash ./datasets/download_cyclegan_dataset.sh maps

         下载到:./datasets/horse2zebra目录下面

        

        也可以直接网页下载:打开./datasets路径下的download_cyclegan_dataset文件查看到数据下载路径为:Index of /cyclegan/datasets

        参数设置

        ./options目录下

        基础参数配置:base_options.py文件

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0,5,6,7', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        self.initialized = True
        return parser

        训练参数设置:train_options.py


    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters                     ##总共epoch niter+niter_decay=200轮
        parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        测试参数配置:test_options.py

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm has different behavioir during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        # rewrite devalue values
        parser.set_defaults(model='test')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))

        训练操作

        命令:

python train.py --dataroot ./datasets/horse2zebra --name maps_cyclegan --model cycle_gan

        训练界面

        训练结果

        保存在:./checkpoints/maps_cyclegan文件下面

测试操作

        测试命令

python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

        测试结果

         在路径.results\maps_cyclegan\test_latest_maps下可以看到测试结果

        

         real为原图,fake为具有原图风格的假图,rec为根据假图复原原图的图。

相关链接

用自己的数据集实战CycleGAN_cyclegan训练-CSDN博客

标签:教程,help,训练,CycleGAN,parser,argument,--,add,type
From: https://blog.csdn.net/weixin_45866058/article/details/137049831

相关文章

  • 2024.3 训练记录
    3.17(模拟赛)100+45+30=175rk12/54A考虑拆点跑二分图最大匹配,左侧连向右侧的一条边就相当于在原图中选择两个有边的点相邻,然后就可以算出答案了。注意特判只有一个孤立点且其他部分能够完全匹配的情况,此时答案需要再\(-1\)。时间复杂度\(O(Tm\sqrtn)\)。B这种简单题咋就差......
  • Docker搭建持续集成平台Jenkins最简教程
    前言持续集成(ContinuousIntegration,CI)是软件开发中的一种实践,它旨在通过频繁地集成代码变更并自动运行测试,以确保代码的质量和稳定性。Jenkins是一个广泛使用的开源持续集成工具,它能够自动化构建、测试和部署软件项目。在本文中,我们将使用Docker搭建一个基于Jenkins的持续......
  • ISC2016训练赛-phrackCTF-FindKey
    ISC2016训练赛——phrackCTFReverse-FindKey:题目描述:FLAG就是你输入的key解题方法:将题目附件下载下来是一个无后缀名的文件,把他放进exeinfope.exe里查看一下它的信息这里我们看到它不是一个EXE文件,但是下面有提示说是用python,然后我们将他的后缀名改成.py文件,用python打开是......
  • 尚硅谷全新JavaWeb教程,企业主流javaweb技术栈【学习笔记】下
    尚硅谷全新JavaWeb教程,企业主流javaweb技术栈【学习笔记】上-CSDN博客 第五章XML_Tomcat10_HTTP 一XMLXML是EXtensibleMarkupLanguage的缩写,翻译过来就是可扩展标i记语言。所以很明显,XML和HTML一样都是标记语言,也就是说它们的基本语法都是标签。可扩展三个字表面上......
  • 训练人形机器人时如何收集人类行为数据 —— 通过人来训练机器人(真人实际演示动作)or
    特斯拉的老马,搞的optimus人形机器人就是通过人来训练机器人(真人实际演示动作),但是未来使用仿真环境自动生成数据是否可行呢,NVIDIA的老黄在2024GTC上是大力推出自家的GROOT平台,该平台的主要数据则是使用仿真器生成的,到底哪种方式更优呢?......
  • 【图论 | 数据结构】用链式前向星存图(保姆级教程,详细图解+完整代码)
    一、概述链式前向星是一种用于存储图的数据结构,特别适合于存储稀疏图,它可以有效地存储图的边和节点信息,以及边的权重。它的主要思想是将每个节点的所有出边存储在一起,通过数组的方式连接(类似静态数组实现链表)。这种方法的优点是存储空间小,查询速度快,尤其适合于处理大规模......
  • 代码随想录算法训练营day34 | leetcode 1005. K 次取反后最大化的数组和、134. 加油站
    目录题目链接:1005.K次取反后最大化的数组和-简单题目链接:134.加油站-中等题目链接:135.分发糖果-困难题目链接:1005.K次取反后最大化的数组和-简单题目描述:给你一个整数数组nums和一个整数k,按以下方法修改该数组:选择某个下标i并将nums[i]替换为-nums[i]。重......
  • python入门教程(非常详细),从零基础入门到精通,看完这一篇就够了
    本文罗列了了python零基础入门到精通的详细教程,内容均以知识目录的形式展开。01.python由来与发展介绍02.项目开发流程第一章markdown编辑器01.markdown基本语法02.Typora简介与安装03.Windows上gitee+Typora设置图床04.macOS上gitee+Typora设置图床第二章计算机......
  • [计算机效率] 文件搜索工具:Listary(附详细使用教程)
    3.5文件搜索工具:ListaryListary是一款实用的搜索工具,它能为我的电脑(资源管理器)增添许多智能命令,提高用户日常收藏和整理文件的效率。它具备多种实用功能,例如收藏文件夹、快速打开最近浏览的文件夹,以及快速显示/隐藏文件扩展名等。此外,Listary还是一个轻量级的软件,不会过多......
  • 07天【代码随想录算法训练营34期】 第三章 哈希表part02(● 454.四数相加II ● 383.
    454.四数相加IIclassSolution:deffourSumCount(self,nums1:List[int],nums2:List[int],nums3:List[int],nums4:List[int])->int:table=dict()foriinnums1:forjinnums2:if(i+j)intable:......