首页 > 其他分享 >pytorch深度学习基础 4 (马变斑马)

pytorch深度学习基础 4 (马变斑马)

时间:2024-08-10 21:28:00浏览次数:22  
标签:ngf nn self pytorch mult 图像 model 斑马 马变

今天我们来介绍一个神奇的网络,生成对抗网络GAN,这个模型纯属当做娱乐,供大家消遣娱乐,在这里我只展示一下GAN模型有趣的一个小功能,先来给大家介绍一下GAN模型吧。

GAN 的基本原理

GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习模型,由两个主要的部分组成:生成器和判别器。生成器试图从一个简单的随机噪声分布中生成数据实例,而判别器则尝试区分这些生成的样本和真实的样本 1

生成器和判别器的对抗训练

在 GAN 的训练过程中,生成器和判别器进行着一种零和游戏的对抗训练。生成器试图欺骗判别器,使其认为生成的样本是真实的,而判别器则试图正确地区分出哪些样本是真实的,哪些是生成的 1

权重共享和批标准化

在训练过程中,为了解决模式塌陷问题,GAN 使用了权重共享和批标准化技术。权重共享意味着每个样本使用相同的权重进行生成,而批标准化则确保了每个批次的样本具有相同的均值和方差 1

GAN 的主要应用场景

GAN 由于其生成高质量数据的能力

但是我们今天使用到的是GAN网络中的其中的一个小的分支CycleGAN

一、CycleGAN 简介

CycleGAN(Cycle Generative Adversarial Network)是一种特殊类型的生成对抗网络,旨在解决无配对数据的图像到图像转换问题。

二、工作原理

  1. 两个生成器
    • 一个将源域图像转换为目标域图像。
    • 另一个执行相反的转换。
  2. 两个判别器
    • 分别判断生成的目标域图像和源域图像的真实性。
  3. 循环一致性损失
    • 确保转换后的图像能够再转换回原始图像,保持一定的相似性。

三、特点

  1. 无需配对数据
    • 传统的图像转换方法通常需要源域和目标域一一对应的图像对,而 CycleGAN 打破了这一限制。
  2. 多领域转换
    • 能够实现多种不同领域之间的图像转换,如风格迁移、季节转换等。
  3. 灵活性高
    • 可以根据不同的任务和数据进行调整和优化。

四、应用领域

  1. 艺术创作
    • 实现不同艺术风格之间的转换。
  2. 图像增强
    • 改善图像的质量和效果。
  3. 虚拟现实和增强现实
    • 生成逼真的虚拟场景和增强现实效果。

五、结论

CycleGAN 为图像转换任务提供了一种创新且有效的方法,具有广泛的应用前景和研究价值。

CycleGAN是循环生成式对抗网络的缩写,它可以将一个领域的图像转换为另一个领域的图像。

本篇博客也是利用了这个特点。

一个把马变成斑马的网络,CycleGAN对从IamgeNet数据集中提取的(不相关的)马和斑马的数据集进行了训练。该网络学习获取一匹或多匹马的图像,并将他们全部变成斑马,图像的其余部分尽可能不被修改。代码直接给大家了,我就不进行代码的讲解了。

from PIL import Image
from torchvision import transforms
import torch


import torch
import torch.nn as nn

class ResNetBlock(nn.Module): # <1>

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # <2>
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3>

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        return self.model(input)
netG = ResNetGenerator()
model_path = "C:\\deep learning\\pytorch学习\\horse2zebra_0.4.0.pth"
model_data = torch.load(model_path)
netG.load_state_dict(model_data)
netG.eval()
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])
img = Image.open("C:\\deep learning\\pytorch学习\\horse.jpg")
img.show()
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
batch_out = netG(batch_t)

out_t = (batch_out.data.squeeze() + 1.0) / 2
out_img = transforms.ToPILImage()(out_t)
out_img.save("C:\\deep learning\\pytorch学习\\horse2.jpg")
out_img.show()

上面的ResNetGenerator类将其方便后续的实例化,把路径改为自己的就行,需要权重文件以及horse的图片的评论区留言,看到就发网盘链接

代码调试完成后,运行代码后的效果如下图所示

是不是特别神奇,有趣!!!

标签:ngf,nn,self,pytorch,mult,图像,model,斑马,马变
From: https://blog.csdn.net/2301_76846375/article/details/141095585

相关文章

  • 推理延迟:解决PyTorch模型Inference阶段的RuntimeError ⏳⚡
    推理延迟:解决PyTorch模型Inference阶段的RuntimeError⏳⚡推理延迟:解决PyTorch模型Inference阶段的RuntimeError⏳⚡摘要引言正文内容什么是RuntimeError?⏳RuntimeError的常见成因⚠️数据格式不一致内存不足模型参数不匹配解决RuntimeError的方法......
  • 零基础学习人工智能—Python—Pytorch学习(三)
    前言这篇文章主要两个内容。一,把上一篇关于requires_grad的内容补充一下。二,介绍一下线性回归。关闭张量计算关闭张量计算。这个相对简单,阅读下面代码即可。print("============关闭require_grad==============")x=torch.randn(3,requires_grad=True)print(x)x.requir......
  • Pytorch深度学习入门基础(三):python 加载数据初认识
    目录 一、 导入二、数据集中数据和label的组成形式三、Dataset读入数据四、Dataset类代码实战4.1创建函数4.2  设置初始化函数4.3读取每一个图片4.4设置获取数据长度函数4.5创建实例4.5.1单个图片数据集4.5.2 多个图片数据集    现在来开......
  • Pytorch函数基础:鸢尾花数据集分类
    博客框架引言简要介绍机器学习和分类问题介绍鸢尾花数据集简述PyTorch的作用及其在深度学习中的重要性环境准备安装所需的库(PyTorch、NumPy、Matplotlib、Pandas等)创建并激活Python虚拟环境(可选)数据加载与预处理从CSV文件读取数据数据转换和标准化将数据转换为Py......
  • 【深度学习与NLP】——快速入门Pytorch基本语法
    目录Pytorch基本语法1.1认识Pytorch1.1.1什么是Pytorch1.1.2Pytorch的基本元素操作1.1.3 Pytorch的基本运算操作1.1.4 关于TorchTensor和Numpyarray之间的相互转换1.1.5小节总结1.2Pytorch中的autograd1.2.1关于torch.Tensor1.2.2关于Tensor的操作1.2.3......
  • Nvidia Jetson Xavier NX安装GPU版pytorch与torchvision
    前提是已经安装好了系统,并通过JetPack配置完了cuda、cudnn、conda等库。1.安装GPU版pytorch在base环境上新建环境,python版本3.8,激活并进入。condacreate-npytorch_gpupython=3.8condaactivatepytorch_gpu前往Nvidia论坛,下载JetsonNX专用的pytorch安装包。传送门:ht......
  • 零基础学习人工智能—Python—Pytorch学习(一)
    前言其实学习人工智能不难,就跟学习软件开发一样,只是会的人相对少,而一些会的人写文章,做视频又不好好讲。比如,上来就跟你说要学习张量,或者告诉你张量是向量的多维度等等模式的讲解;目的都是让别人知道他会这个技术,但又不想让你学。对于学习,多年的学习经验,和无数次的回顾学习过程,都......
  • pytorch深度学习分类代码简单示例
    train.py代码如下importtorchimporttorch.nnasnnimporttorch.optimasoptimmodel_save_path="my_model.pth"#定义简单的线性神经网络模型classMyModel(nn.Module):def__init__(self):super(MyModel,self).__init__()self.output=n......
  • pytorch和deep learning技巧和bug解决方法短篇收集
    有一些几句话就可以说明白的观点或者解决的的问题,小虎单独收集到这里。torch.hub.loadhowdoesitwork下载预训练模型再载入,用程序下载链接可能失效。model=torch.hub.load('ultralytics/yolov5','yolov5s')model=torch.hub.load('ultralytics/yolov3','yolov3......
  • pytorch张量运算
    pytorch张量运算2.1数据操作深度学习落实到计算表现为矩阵计算pytorch、tensorflow中,计算的基本组件是Tensor。张量即多维数组,是矩阵计算的基本单元。Tensor:张量,一维张量即向量vector,二维张量即二维数组。张量是n维数组的统称python中有专门进行矩阵计算的库:numpy。pytor......