首页 > 其他分享 >一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】

时间:2022-11-08 17:39:11浏览次数:43  
标签:训练 -- 代码 py PyTorch CVPR SRGAN install test

  • ???? 声明: 作为全网 AI 领域 干货最多的博主之一,❤️ 不负光阴不负卿 ❤️

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_数据集

  • ???? 深度学习: # 超分重建、一文读懂
  • ???? 超分重建经典网络 SRGAN 详尽教程
  • ???? 最近更新:2022年2月28日
  • ???? 点赞 ???? 收藏 ⭐留言 ???? 都是博主坚持写作、更新高质量博文的最大动力!
  • ???? ???? Follow me ????,一起 Get 更多有趣 AI、冲冲冲 ???? ????

???? 基础信息


  • 本博文运行的代码GitHub链接,弱鸡一枚,向各位前辈大佬致敬
  • ​​PyTorch implementation of SRGAN -- 非官方实现​​

遇到疑问、可第一时间评论区交流


???? 环境搭建


Ubuntu 16.04.4 LTS \n \l

2080Ti , cuda10.0

conda create -n torch11 python=3.6.9

conda activate torch11

conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch

pip install pillow==5.2.0

pip install opencv-python

pip install scipy

pip install thop

pip install matplotlib

pip install pandas

pip install tqdm

???? 数据准备


一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_数据_02

我使用的训练数据:DIV2K,如果不知道如何下载,可参考我的这篇博文

​​# 超分辨率重建数据集总结 看这篇就够了​​

  • 测试数据:点击上图 gitHub 提供的下载链接下载
  • 数据集存放位置如下

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_数据集_03

这个代码 训练数据只需要 ​​HR​​ 图片即可:

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_数据集_04

在 ​​train.py​​ 中设置 训练数据集、和 评估集 路径;

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_计算机视觉_05


❤️ 训练和测试


???? Train

python train.py

optional arguments:
--crop_size training images crop size [default value is 88]
--upscale_factor super resolution upscale factor [default value is 4](choices:[2, 4, 8])
--num_epochs train epoch number [default value is 100]

The output val super resolution images are on `training_results` directory.

???? Test Benchmark Datasets

python test_benchmark.py --upscale_factor 2  --model_name netG_epoch_2_10.pth

optional arguments:
--upscale_factor super resolution upscale factor [default value is 4]
--model_name generator model epoch name [default value is netG_epoch_4_100.pth]

The output super resolution images are on benchmark_results directory.

和小伙伴的一些讨论


关于此博文的一些交流

我训练完之后如何将它的生成模型拿出来,然后单独训练我自己的低分辨率图片

  • 该代码训练出的模型位于 srgan_torch/epochs 目录下,做测试时指定 模型 路径即可
  • SRGAN pytorch 这个 代码做训练 需要 HR 原图 ;(它会把HR 下采样 得到 配对的 LR )然后进行训练;
  • test_benchmark.py 是对 Set5 、Set14 等数据集进行测试 并且 计算 psnr ;
  • test_image.py 这个 代码 就是 指定模型,然后针对一张 图片 来 进行 单独 SR 重建的;
  • 如果 你要使用 模型 来 重建自己的 LR 图片 得到 HR ,那么你写一个 for 循环 ,执行 test_image.py(或者把里面的代码 抽取为一个方法进行调用)就可以了;

test_image.py 文件 是怎么运行的?

这个文件是作者项目中的文件,它的代码可以运行,但是如果想大量重建自己某个目录下的图像;该代码尚有一些缺陷,参考该文件代码,重写一份循环调用代码即可;

  • 运行程序,只需传入对应参数即可

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_python_06

  • 运行命令示例
python test_image.py --upscale_factor 2 --model_name netG_epoch_2_100.pth --test_mode CPU --image_name 4x2.png
  • 输出如下
test_image.py:31: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)
cost29.748524s
  • 输入和重建得到的图片位置如下,如果是其它目录,需要修改一下该代码才可行

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_python_07


???? 相关报错


???? 初次训练 会自动下载 pytorch版本的 vgg16 model 用来 计算 loss ,考验网速哈:

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_数据集_08

  • Linux 系统下 VGG16 模型的安放位置如下:

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_人工智能_09

  • Windows 安放到哪里呢??? 哪位大佬看到了,欢迎评论区补充 ,给小白同学安排一下
  • 遇到这个错误,可能是因为 下载的 模型 文件是坏的(Danny 同学补充)

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_人工智能_10

???? CUDA out of memory 报错如下:

RuntimeError: CUDA out of memory. Tried to allocate 1018.00 MiB (GPU 0; 7.79 GiB total capacity; 4.72 GiB already allocated; 853.50 MiB free; 1.52 GiB cached)

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_数据_11

  • 解决方法: 修改 batch_size=4

???? test_benchmark.py 测试 ssim 计算 报错,处理方法如下:

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_人工智能_12


???? 其他实验补充


显卡 2080Ti , cuda10.0 ,8G

训练数据: DIV2K

训练命令: python train.py

train.py 中 参数设置如下:

parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=96, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=2, type=int, choices=[2, 4, 8],
help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')

训练时长: 2小时20分钟 【100 epochs】

测试时长: 57s


???????? 探讨【如何开始跑实验】


没想到PyTorch 版本 SRGAN 的关注的新同学还蛮多哈;看来 PyTorch 是真香啊;就此我说一下个人对此的看法:

对于新同学而言:

  1. 第一步:正确搭建环境
  2. 第二步:正确设置数据路径
  3. 第三步:运行训练和测试
  4. 第四步:主观和客官评估超分重建效果
  5. 第五步:改进代码,循环以上步骤

最重要(基础)的就是前两步骤,如果新同学遇到问题,建议按照博文教程认真检查一下自己是否落实好基础工作;

大部分看了这个博文的同学都能够直接顺利运行和测试,说明这个教程总体上是充分够用的;

SRGAN 虽然经典,但是这3 年 过去,它终将成为过去,还是建议大家学习近两年 最新的 SR 相关论文和代码;

​​超分重建-代码环境搭建--专栏​​ 下有几篇 19、20年的超分代码经典总结相信也一定可以帮助各位新同学参考学习哈;

嗯,最后,感谢您的耐心查阅,博主本人现在已经从 SR 脱坑到图像修复、目标检测(难瘦)了,不过也还在视觉这个深坑里哈,大家一起学习进步啊;

搭合适自己的顺风车,即是真正的高效

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_人工智能_13


墨理学AI


  • 作为全网 AI 领域 干货最多的博主之一,❤️ 不负光阴不负卿 ❤️
  • ❤️ 如果文章对你有帮助、点赞、评论鼓励博主的每一分认真创作

快乐学AI 、深度学习环境搭建 : 一文读懂

  • ???? # ubuntu给当前用户安装cuda11.2 图文教程
  • ???? # linux和window设置 pip 镜像源——最实用的机器学习库下载加速设置
  • ????# anaconda conda 切换为国内源 、windows 和 Linux配置方法、 添加清华源——【一文读懂】
  • ???? # 指定GPU运行和训练python程序 、深度学习单卡、多卡 训练GPU设置【一文读懂】
  • ???? # Linux下cuda10.0安装Pytorch和Torchvision【一文读懂】
  • ???? # 一文读懂SSH密码登录、公钥认证登录
  • ???? # 一文读懂 Centos、Ubuntu 环境 安装JDK 11:配置JAVA_HOME环境变量

一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】_人工智能_14

标签:训练,--,代码,py,PyTorch,CVPR,SRGAN,install,test
From: https://blog.51cto.com/u_15660370/5833861

相关文章

  • 使用PyTorch实现简单的AlphaZero的算法(1):背景和介绍
    在本文中,我们将在PyTorch中为ChainReaction[2]游戏从头开始实现DeepMind的AlphaZero[1]。为了使AlphaZero的学习过程更有效,我们还将使用一个相对较新的改进,称为“Playout......
  • Mathis Petrovich-2021-Action-Conditioned-3D-Human-Motion-Synthesis-with-Transfor
    #Action-Conditioned3DHumanMotionSynthesisWithTransformerVAE#paper1.paper-info1.1MetadataAuthor::[[MathisPetrovich]],[[MichaelJ.Black]],[......
  • PyTorch实现非极大值抑制(NMS)
    NMS即nonmaximumsuppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值。在最近几年常见的物体检测算法(包括rcnn、sppnet、fast-rcnn、faster-rcnn......
  • Pytorch中模型调用
    注意:RNN、LSTM的batch_first参数,对于不同的网络层,输入的维度虽然不同,但是通常输入的第一个维度都是batch_size,比如torch.nn.Linear的输入(batch_size,in_features),torch.nn......
  • PyTorch笔记:hook的作用
    参考自https://zhuanlan.zhihu.com/p/279903361,原始来自:https://towardsdatascience.com/how-to-use-pytorch-hooks-5041d777f904在Module官方文档那片笔记中已经有一部......
  • PyTorch笔记:如何保存与加载checkpoints
    https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html保存和加载checkpoints很有帮助。为了保存checkpoints,必须将它们放在......
  • PyTorch笔记:Python中的state_dict是啥
    来自:https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html在PyTorch中,可学习的参数都被保存在模型的parameters中,可以通过model.parameters()访问......
  • PyTorch笔记:Modules官方文档
    来自https://pytorch.org/docs/stable/notes/modules.htmlASimpleCustomModuleimporttorchfromtorchimportnnclassMyLinear(nn.Module):def__init__(se......
  • 使用LabVIEW实现基于pytorch的DeepLabv3图像语义分割
     前言今天我们一起来看一下如何使用LabVIEW实现语义分割。一、什么是语义分割图像语义分割(semanticsegmentation),从字面意思上理解就是让计算机根据图像的语义来进......
  • 【2022.11.03】pytorch的使用相关
    Pytorch的使用相关,学习来源:https://www.bilibili.com/video/BV1Wv411h7kN/?p=6加载数据有两种方法,一种是torch.utils.data.Dataset,一种是torch.utils.data.DataloaderTe......