首页 > 编程语言 >SRGAN图像超分重建算法Python实现(含数据集代码)

SRGAN图像超分重建算法Python实现(含数据集代码)

时间:2023-07-06 22:35:38浏览次数:68  
标签:Python 模型 超分 channels 卷积 SRGAN 图像 size

摘要:本文介绍深度学习的SRGAN图像超分重建算法,使用Python以及Pytorch框架实现,包含完整训练、测试代码,以及训练数据集文件。博文介绍图像超分算法的原理,包括生成对抗网络和SRGAN模型原理和实现的代码,同时结合具体内容进行
解释说明,完整代码资源文件请转至文末的下载链接。

完整代码下载地址https://download.csdn.net/download/qq_32892383/87953641

COCO训练数据集https://pan.baidu.com/s/18xiqkK2m34TKo1FcKo0RJw?pwd=y5gf 提取码:y5gf

➷点击跳转至文末所有涉及的完整代码文件下载页☇


前言

        一张低分辨率的图像要想放大为更高尺寸的图像,需要对确实的细节进行插值,常见的线性插值方法利用相邻像素的信息进行补充,但放大后图像模糊、质量低下的问题仍然存在。图像的超分辨率重建技术指的是将给定的低分辨率图像通过特定的算法恢复成相应的高分辨率图像。简单来理解超分辨率重建就是将小尺寸图像变为大尺寸图像,使图像更加“清晰”,但放大时通过了深度学习的技术补充了更多细节。

        这种细节的补充并不是简单插值,而是在经过大量现实数据的训练后,针对细节的“推理”填充,你可以简单理解为和人在见过大量的高清照片后,对于模糊的相似部分能够“脑补”画面,根据画面大体轮廓将具体细节勾画出来。图像超分效果如下图所示:

        可以看到,通过特定的超分辨率重建算法,使得原本模糊的图像变得清晰了,直至今日,依托深度学习技术,图像的超分辨率重建已经取得了非凡的成绩,在效果上愈发真实和清晰。至于应用就更加广泛,如医学成像、遥感、公共安防、视频感知等。在影视素材画质的增强恢复中,许多基于深度学习的超分重建技术得到了实际应用,比如Topaz Video Enhance AI软件。有了图像超分技术,从此再也不用忍受渣渣画质了,老司机萌觉得还可以怎么用呢。


1. 实现原理

        图像的超分重建算法按照时间和效果,可以分为传统算法和深度学习算法两类。传统的超分辨率重建算法主要依靠基本的数字图像处理技术进行重建,常见的有基于插值的超分辨率重建、基于退化模型的超分辨率重建、基于学习的超分辨率重建等。

1.1 超分重建流程

        基于深度学习进行超分辨率重建的算法,较早的要属SRCNN(Super-Resolution Convolutional Neural Network)算法了,作为开山之作,其原理简单。SRCNN利用深度学习模型和大批量样本数据的训练,在超分性能上超越了一大批传统图像处理算法,从此深度学习开始向超分辨率领域研究迈进。SRCNN的网络结构如下图所示:

        以上模型(博主已添加中文注释)来自Chao Dong等人的论文"Image Super-Resolution Using Deep Convolutional Networks",主要由一个三层结构的卷积神经网络(CNN)构成。对于一张低分辨率图像,首先使用双立方插值将其放大至目标尺寸,使用以上的CNN模型去拟合低分辨率图像与高分辨率图像之间的非线性映射,最后通过重构将网络输出的结果作为高分辨率图像。

        SRCNN的流程可以简单理解为两步:图像放大和修复,如下图所示。其中,放大是采用某种方式(SRCNN采用插值上采样)将图像放大到指定倍数,再利用大数据的学习模型结合图像修复原理,将放大后的图像映射为最终输出目标。可以看出,超分辨率重建相比简单的插值放大,其在此基础上又具备了图像修复的作用,因此在超分性能上无疑大大增强。因此,超分辨率重建的很多算法也被学者迁移到图像修复领域中,完成一些诸如jpep压缩去燥、去模糊等任务。

        除此之外,对于模型的训练其流程也具有参考意义:(1)寻找大量真实场景下的高清图像样本,对每张图片进行下采样处理以降低图像分辨率(如2倍下采样、4倍下采样等),这样经过下采样图像长宽均得到等比例缩小;(2)将采样后的图像作为低分辨率图像用于输入,采样前的图像作为高分辨率图像作为真实值,以此构成有效的训练样本集;(3)利用深度学习模型对低分辨率图像进行放大重建为高分辨率的输出结果,将其与原始高分辨率图像进行比较计算误差,调整模型参数并不断迭代,使得误差下降至最低;(4)训练完的模型可以用于对新的低分辨率图像进行重建,得到高分辨率图像。

1.2 SRResNet的深度网络

        相比只有3个卷积层的SRCNN,SRResNet采用更深的网络结构模型,抽取出更高级的图像特征,深层模型对图像可以更好的进行表达,实现超分重构的性能也得到加强。深度残差网络(ResNet)的提出,很好解决了深层模型不能很好收敛的问题,其在图像分类、图像分割、目标检测等领域有着广泛应用。

        ResNet成功的重要一点,是在传统网络中引入了残差学习(Residual Learning),从而有效解决深层网络中梯度消失和精度下降的问题,使得网络层数能够大大加深。残差网络的原理图如下图所示,从图中可以看出原始数据x不仅有直接进入下一层的链接,还有一条跨越两层网络(跳链)的链接,将x带入到输出中,此时输出改为F(x)+x,使得整个模型训练时不容易发散。这里我绘制了一个残差模块,如下图所示:

        至此可以借住ResNet的特性,在SRCNN的基础上我们就可以构建更加强大的网络结构,用于超分重建的深度神经网络。SRResNet模型的主干网络其实采用了这种网络结构,如下图所示:

        SRResNet模型中采用了多个深度残差模块(16个残差模块)对图像特征进行提取,保证整个网络稳定的同时,采用深度模型提升性能。以上模型中的卷积层仅仅改变了图像的通道数,并未修改图像尺寸,由此可见目前为止的模型仍然可以看出是SRCNN类似的修复模型。

        SRResNet模型利用子像素卷积来放大图像,即在以上模型后继续添加两个子像素卷积模块,每个子像素卷积模块使得输入图像放大2倍,因此这个模型最终可以将图像放大4倍。SRResNet模型主要包含两部分:深度残差模型子像素卷积模型。深度残差模型用来进行高效的特征提取,可以在一定程度上削弱图像噪点;子像素卷积模型主要用来放大图像尺寸,其结构如下图所示:

        以上模型中,k表示卷积核大小,s为步长,n表示通道数。最后模型在输出前增加了一个卷积层用于数据调整和增强。为了训练模型SRResNet算法采用了MSE作为目标函数,即最小化模型输出的高分辨率图像(F(X)与原始分辨率图像(Y)的均方误差,其目标函数公式如下:

\[L=\frac{1}{n}\sum_{i=1}^{n}\left \| F(X_{i};\theta )-Y_i\right \|^2 \]

        MSE被广泛应用于超分重建算法的目标函数,但使用该目标函数重建的超分图像可能出现不能很好符合人眼主观感受的问题,SRGAN算法则针对该问题进行了改进。


2. SRGAN 原理与代码实现

        SRResNet算法通过深层的卷积模块完成特征映射,但也存在重建出的图像过于平滑,纹理细节信息丢失的缺陷。究其原因是采用MSE的目标函数,纹理细节处理难以满足人眼主观感受,为此如何“无中生有”重建纹理细节,那就需要利用生成对抗网络(Generative Adversarial Network, GAN)。

2.1 生成对抗网络简介

        生成对抗网络(GAN)的灵感与博弈论中博弈的思想相契合,对于深度学习而言,不再是简单的单一模型(如SRResNet),而是构造两个深度学习模型:生成网络(Generator)和判别网络(Discriminator),两个模型相互博弈,即生成网络Generator产生以假乱真的图像,而判别网络Discriminator具备辨别图像真伪的能力,彼此在相互竞争对抗中达到更好效果。GAN的模型结构如下图所示:

        上图中生成网络和判别网络的主要功能:(1)生成网络(Generator),它通过某种特定的网络结构以及目标函数来生成图像;(2)判别网络(Discriminator),判别一张图片是不是“真实的”,即判断输入的照片是不是由Generator生成;Generator的作用就是尽可能的生成逼真的图像来迷惑Discriminator,使得Discriminator判断失败;而Discriminator的作用就是尽可能的挖掘Generator的破绽,来判断图像到底是不是由Generator生成的“假冒伪劣”。

        GAN已经应用于图像补全、去噪,风格迁移,超分重建等图像领域,这里运用GAN能够减少损失函数的设计成本,从功能上看利用一定的基准,直接加上判别器,对抗训练会帮助我们解决很多问题。相比之前的简单模型,GAN可以产生更加清晰、真实的效果。

2.2 感知损失函数

        在SRGAN中重新设计了新的损失函数——感知损失(Perceptual Loss),它由内容损失和对抗损失构成:1. 对抗损失:与一般GAN定义类似,即重建出的图像被判别器正确判断的损失;2. 内容损失:内容损失更加关注重建图片与真实高清图像的语义特征差异,而不是逐个像素之间的颜色亮度差异;SRGAN的作者考虑计算图像的固有特征差异,而固有特征提取其实早有专门模型被提出用于分类等任务。因此截取这些模型的特征提取模块,用于计算重建图像和真实图像的特征(语义特征)提取,然后在提取的特征层上再进行MSE计算。

        值得一说的是,SRGAN在进行语义特征提取时,选取了VGG19模型,截取模型的有用部分后,截取的模型被称为truncated_vgg19模型。至此内容损失的计算总结如下:

  1. 根据SRResNet模型重建出超分图像(Super-Resolution,SR);
  2. 对于原始高清图像H和重建出的超分图像SR,分别应用truncated_vgg19模型,计算得到两幅图像的特征图H_fea和SR_fea;
  3. 计算推理后的特征图H_fea和SR_fea的MSE值;

2.3 SRGAN网络结构

        SRGAN 是由 Christian Ledig 和他的团队在 2017 年的论文 "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" 中提出的。在这篇论文中,他们提出了一种新的超分辨率方法,不仅可以恢复高分辨率图像的细节,还能使得生成的图像在视觉上更接近于真实图像。这种方法结合了深度学习中的生成对抗网络(GAN)和残差网络(ResNet),两者的结合提高了超分辨率的效果。SRGAN的网络结构由两部分组成,分别为生成器模型(Generator)和判别器模型(Discriminator)。

        (1)生成器:生成器的目标是将低分辨率的输入图像变换为高分辨率图像。在 SRGAN 中,生成器是一个深度残差网络(Deep Residual Network)。其核心是一系列的残差块,每个残差块中包含两个卷积层,每个卷积层后面都跟有批量归一化(Batch Normalization)和参数化ReLU(PReLU)激活函数。在所有的残差块后,通过两个卷积层和一个像素级卷积层(PixelShuffle)将特征映射转换回高分辨率图像。这个结构允许模型学习低分辨率和高分辨率图像之间的残差映射,从而使得网络能够有效地重建高分辨率图像的细节。生成器的网络结构如下图所示:

        (2)判别器:判别器是一个卷积神经网络,其目标是区分生成的图像是否来自真实的高分辨率图像。在 SRGAN 中,判别器的网络结构是一个深度卷积神经网络,其中包括一系列的卷积层、批量归一化层和LeakyReLU激活函数。最后通过全连接层和sigmoid激活函数输出图像的真实性概率。判别器的网络结构如下图所示:

        这两个网络相互对抗:生成器尝试生成越来越真实的图像以欺骗判别器,而判别器则努力提高其区分真实图像和生成图像的能力。通过这种对抗过程,模型最终可以生成出具有高质量细节的超分辨率图像。在实际操作中,SRGAN 需要大量的训练数据和计算资源,且训练过程需要一定的技巧和经验。尽管如此,SRGAN 仍然是图像超分辨率领域的一种重要技术,为生成逼真的高分辨率图像提供了一种有效的方法。

2.4 SRGAN网络训练

        SRGAN 的训练主要分为两个阶段:预训练阶段和对抗训练阶段。

        预训练阶段:这个阶段主要是为了训练生成器。SRGAN 中的生成器是一个深度残差网络,其目标是学习一个从低分辨率图像到高分辨率图像的映射。在预训练阶段,我们主要使用均方误差(MSE)作为损失函数,这样可以确保网络可以学习到一个相对精确的映射。这个阶段的训练可以使用高分辨率图像和对应的低分辨率图像作为训练数据。

        对抗训练阶段:在预训练阶段结束后,我们得到了一个可以生成相对准确的高分辨率图像的生成器。然后,我们进入对抗训练阶段,这个阶段的目标是训练生成器和判别器进行对抗。在这个阶段,生成器的目标是生成尽可能真实的高分辨率图像以欺骗判别器,而判别器的目标是尽可能准确地区分真实的高分辨率图像和生成器生成的高分辨率图像。对抗训练的损失函数通常包括对抗损失和内容损失两部分。对抗损失来自判别器对生成图像的判别结果,内容损失则是生成图像和真实高分辨率图像在特征空间上的差异。

        SRGAN 的损失函数主要包括两部分:对抗损失(Adversarial Loss)和感知损失(Perceptual Loss)。

        对抗损失(Adversarial Loss):对抗损失主要用于衡量生成器生成的图像和真实图像在判别器中的判别结果的差距。对抗损失的目标是鼓励生成器生成能够欺骗判别器的图像。在 SRGAN 中,使用了交叉熵损失(Cross-Entropy Loss)作为对抗损失,其公式如下:

\[L_{\text{adv}}(G,D) = \mathbb{E}_{I_{hr}\sim p_{\text{train}}(I_{hr})}[\log D(I_{hr})] + \mathbb{E}_{I_{lr}\sim p_{I_{lr}}(I_{lr})}[\log(1-D(G(I_{lr})))] \]

其中,G 是生成器,D 是判别器,\(I_{hr}\)是真实的高分辨率图像,\(I_{lr}\) 是低分辨率图像。

        感知损失(Perceptual Loss):感知损失则用于衡量生成图像和真实图像在特征空间上的差距。在 SRGAN 中,感知损失包括内容损失(Content Loss)和纹理损失(Texture Loss)。内容损失是通过预训练的 VGG19 网络提取出的特征图之间的欧几里得距离,纹理损失则是生成图像和真实图像的 Gram 矩阵之间的差距。感知损失的公式如下:

\[L_{\text{perc}}(G) = \mathbb{E}_{I_{lr}\sim p_{I_{lr}}(I_{lr}), I_{hr}\sim p_{\text{train}}(I_{hr})}[\| \phi(I_{hr}) - \phi(G(I_{lr})) \|_1 + \lambda \| \phi_{\text{gram}}(I_{hr}) - \phi_{\text{gram}}(G(I_{lr})) \|_1] \]

其中,

标签:Python,模型,超分,channels,卷积,SRGAN,图像,size
From: https://www.cnblogs.com/sixuwuxian/p/17532444.html

相关文章

  • python: using pdfplumber Lib read pdf file
     fromopenpyxlimportWorkbookfromopenpyxl.stylesimportPatternFill,Side,Borderimportpdfplumberl=[]defvisitDir(path):ifnotos.path.isdir(path):print('Error:"',path,'"isnotadirectoryordoesnotexi......
  • python列表(一)
    列表由一系列按特定顺序排列的元素组成。bicycles=['trek','cannondale','redline','specialized']print(bicycles)1.访问列表元素#索引print(bicycles[0])#最后一个元素print(bicycles[-1])#倒数第二个元素print(bicycles[-2])2.修改、添加和删除元素2.1......
  • python基础day39 生产者消费者模型和线程相关
    如何查看进程的id号进程都有几个属性:进程名、进程id号(pid--->processid)每个进程都有一个唯一的id号,通过这个id号就能找到这个进程importosimporttimedeftask():print("task中的子进程号:",os.getpid())print("主进程中的进程号:",os.getppid())#parent......
  • 多线程python
    如何开启进程使用的是内置的模块:multiprocessfrommultiprocessingimportProcessdeftask():withopen('a.txt','w',encoding="utf8")asf:f.write('helloworld')#开一个进程来执行task这个任务#如何开进程"""在Wind......
  • Logistic回归模型,python
    代码参考https://blog.csdn.net/DL11007/article/details/129204192?ops_request_misc=&request_id=&biz_id=102&utm_term=logistic%E6%A8%A1%E5%9E%8Bpython&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-1-129204192.142^v......
  • Python中标准输入(stdin)、标准输出(stdout)、标准错误(stdout)的用法
    1.标准输入input()、raw_input()Python3.x中input()函数可以实现提示输入,python2.x中要使用raw_input(),例如:foo=input("Enter:")#python2.x要用raw_input()print("Youinput:[%s]"%(foo))#测试执行Enter:abcdeYouinput:[abcde]#读取一行(不......
  • Python中os.system()、subprocess.run()、call()、check_output()的用法
    1.os.system()os.system()是对C语言中system()系统函数的封装,允许执行一条命令,并返回退出码(exitcode),命令输出的内容会直接打印到屏幕上,无法直接获取。示例:#test.pyimportosos.system("ls-l|greptest")#允许管道符#测试执行$ll<=======......
  • Python中startswith()和endswith()方法
    startswith()方法startswith()方法用于检索字符串是否以指定字符串开头,如果是返回True;反之返回False。endswith()方法endswith()方法用于检索字符串是否以指定字符串结尾,如果是则返回True;反之则返回Falses='helloword'print("s.startswith('wor'):",s.startswith('wor......
  • 【Python】多维列表变为一维列表的方法--numpy
    转载:(18条消息)【Python】多维列表变为一维列表的方法_四维列表变一维_Vincent__Lai的博客-CSDN博客题目给定一个多维列表,怎么让其变为一维?例如,输入:[[1,4],[2],[3,5,6]],输出:[1,4,2,3,5,6]常规一行做法a=[[1,4],[2],[3,5,6]]a=[jforiinaforjini......
  • 将PYTHON包环境从一个电脑拷贝到另外一个电脑
    将PYTHON包环境从一个电脑拷贝到另外一个电脑1、在当前电脑复制D:\ProgramFiles\Python\Python311\Lib中的所有文件生成myrequirement.txt文件pipfreeze>myrequirement.txtmyrequirement.txt文件如下:colorama==0.4.6constantly==15.1.0cpca==0.5.5cryptography=......