首页 > 其他分享 >cycleGAN代码实现(附详细代码注释)

cycleGAN代码实现(附详细代码注释)

时间:2022-08-18 16:35:12浏览次数:41  
标签:channels features nn cycleGAN 代码 注释 num self size

最近刚刚入门深度学习,试着复现cycleGAN代码。看了一个YouTube博主的cycleGAN代码,自己跟着写了一遍,同时加上了代码注释,希望能帮到同样的入门伙伴

下面的github地址

RRRRRBL/CycleGAN-Detailed-notes-: 内含cycleGAN代码,且有详细代码注释 (github.com)
在这里给出一个生成器的代码

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
def init(self, in_channels, out_channels, down=True, use_act=True, kwargs): # down:下采样,act:激活,kwargs字典参数
super().init()
self.conv = nn.Sequential( # 卷积块,可以完成下采样卷积或者保持原size卷积
nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels), # 标准化
nn.ReLU(inplace=True) if use_act else nn.Identity() # identity不会做任何操作
)

def forward(self, x):
    return self.conv(x)

class ResidualBlock(nn.Module): # 残差块,不改变size
def init(self, channels):
super().init()
self.block = nn.Sequential(
ConvBlock(channels, channels, kernel_size=3, padding=1),
ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
)

def forward(self, x):
    return x + self.block(x)  # 残差块儿

class Generator(nn.Module):
def init(self, img_channels, num_features=64, num_residuals=9, ): # num_features是通道数的一个公约数,num_residuals残差层数
super(Generator, self).init()
self.initial = nn.Sequential( # 初始化
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
nn.InstanceNorm2d(num_features),
nn.ReLU(inplace=True), # 原地激活
)
self.down_blocks = nn.ModuleList( # 下采样(增加通道数,减小img尺寸
[
ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),

        ]
    )
    self.residual_block = nn.Sequential(  # 残差块儿(不改变大小
        *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        # *4是因为之前的各类操作得到的变量channel已经是4
        # 是4*num_featurs了,这里调用了九次残差块儿,进行训练,大小一直不变
    )
    self.up_blocks = nn.ModuleList(  # 上采样block channels减小,img变大
        [
            ConvBlock(num_features * 4, num_features * 2, down=False, kernel_size=3, stride=2, padding=1,
                      output_padding=1),
            ConvBlock(num_features * 2, num_features * 1, down=False, kernel_size=3, stride=2, padding=1,
                      output_padding=1),

        ]
    )
    self.last = nn.Conv2d(num_features * 1, img_channels, kernel_size=7, stride=1, padding=3,
                          padding_mode='reflect')

def forward(self, x):
    x = self.initial(x)  # 初始化
    for layer in self.down_blocks:
        x = layer(x)
    x = self.residual_block(x)
    for layer in self.up_blocks:
        x = layer(x)
    return torch.tanh(self.last(x))

'''
观察代码不难发现,在整个生成器的生成过程中,用到的还是简单基础的知识,只是在一些处理选择上比较特殊
代码利用了残差神经网络 和卷积神经网络集合的方式进行训练
def test():
img_channels = 3
img_size = 256
x = torch.randn((2, img_channels, img_size, img_size))
gen = Generator(img_channels, 9)
print(gen(x).shape)

if name == "main":
test()
'''

标签:channels,features,nn,cycleGAN,代码,注释,num,self,size
From: https://www.cnblogs.com/RBLstudying/p/16599150.html

相关文章

  • 手机网页限制用户缩放代码 (2014-03-25 18:16:52)
    网页手机wap2.0网页的head里加入下面这条元标签,在iPhone的浏览器中页面将以原始大小显示,并不允许缩放。    width-viewport的宽度height-viewport的高度  initi......
  • Dynamics CRM 365 通过代码的方式,移除实体窗体里面的JS脚本
    在某些场景,您想把所有实体的某个JS脚本移除,或者您想大量实体上追加某个JS脚本的时候,那这篇博客就能给你最好的启示。 1.我们分析一下,JS脚本是挂在窗体上的,那是否在窗体表......
  • 浏览器端代码获取资源
    如何获取本地国标文件很少把自己的代码发上来,因为一方面都是旁门左道,不是真正能够独当一面的成果。另一方面只是一时使用,过了这个阶段或者环境变了就没法用了。这次贴上来......
  • 针对`Code View`友好的代码重构方法
    针对CodeView友好的代码重构方法本文记录在开发过程中,写出对CodeReView友好代码的若干方法。抽取函数将较为独立的语句抽取为函数,是一种很常见的重构手段,本文在此基......
  • Unity中报不能启用不安全代码的错误
    今天下载了网上的一个Demo,打开的时候报了一个不能加载不安全代码的错误这个问题我以前遇见过,但太久不用一时间没想起来在哪里设置,这里记录一下打开unity的 编辑==>项......
  • 代码实现斐波那契数列
    #定义函数deffab(n):#判断n的有效性ifn<=0:return'传递的参数必须大于0的正整数'#当n为1时返回斐波那契数的第1个数0elifn==1:......
  • 代码考核
    实现数组的flat实现数组的flat方法,支持传入递归深度代码模板:constreadline=require('readline');constrl=readline.createInterface({input:process.std......
  • 代码实现 打印九九乘法口诀
    #for循环,其中range(1,10)取1-9之间的整数,不会取到10#range(1,10)相当于数学中的[1,10),取值范围是前闭后开foriinrange(1,10):#for循环,取1到i的整数......
  • python 代码测试(pytest)
    前话代码测试用于检验代码运行结果是否符合预期。优势一:编写测试函数,更规范,高效的核对代码运行结果,当被测试对象进行了调整和重构的时候,可以节省大量人工排查问题的时间......
  • Unity 代码调用重新生成csproj文件
    结论先放结论:editor代码中直接调用Unity.CodeEditor.CodeEditor.CurrentEditor.SyncAll();原因在一些操作后,比如修改csc.rsp的内容之后,需要重新生成csproj文件方......