首页 > 其他分享 >超分辨率(2)--基于EDSR网络实现图像超分辨率重建

超分辨率(2)--基于EDSR网络实现图像超分辨率重建

时间:2024-03-16 16:32:13浏览次数:29  
标签:EDSR opt -- 分辨率 input path self dir size

目录

一.项目介绍

二.项目流程详解

2.1.构建网络模型

2.2.数据集处理

2.3.训练模块

2.4.测试模块

三.测试网络


一.项目介绍

EDSR全称Enhanced Deep Residual Networks,是SRResnet的升级版,其对网络结构进行了优化(去除了BN层),省下来的空间可以用于提升模型的size来增强表现力。

为什么要去除BN层:

Batch Norm是深度学习中非常重要的技术,不仅可以使训练更深的网络变容易,加速收敛,还有一定正则化的效果,可以防止模型过拟合。

但对于图像超分辨率来说,网络输出的图像在色彩、对比度、亮度上要求和输入一致,改变的仅仅是分辨率和一些细节,而Batch Norm,对图像来说类似于一种对比度的拉伸,任何图像经过Batch Norm后,其色彩的分布都会被归一化,也就是说,它破坏了图像原本的对比度信息,所以Batch Norm的加入反而影响了网络输出的质量。

网络结构及对比:

移除BN层后,模型更加轻量,BN层所消耗的存储空间等同于上一层CNN层所消耗的,作者指出相比于SRResNet,EDSR去掉BN层之后节约了40%的存储资源。

同时在BN腾出来的空间下插入更多的类似于残差块等CNN-based子网络来增加模型的表现力。

论文地址:

[1707.02921] Enhanced Deep Residual Networks for Single Image Super-Resolution (arxiv.org)icon-default.png?t=N7T8https://arxiv.org/abs/1707.02921源码地址:

developer0hye/EDAR: PyTorch implementation of Deep Convolution Networks based on EDSR for Compression(Jpeg) Artifacts Reduction (github.com)icon-default.png?t=N7T8https://github.com/developer0hye/EDAR

二.项目流程详解

2.1.构建网络模型

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, act=nn.ReLU(True)):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if i == 0: m.append(act)

        # m是设置好的conv层
        # 设置网络内部层次结构为body
        self.body = nn.Sequential(*m)

    def forward(self, x):
        # 获取当前的结果
        res = self.body(x)
        # 当前得到的网络和最初的网络融合
        res += x

        return res


class EDAR(nn.Module):
    def __init__(self, conv=common.default_conv):
        super(EDAR, self).__init__()

        # 参数设置
        n_resblock = 8  # resnet长度
        n_feats = 64
        kernel_size = 3  # 卷积核大小

        #DIV 2K mean
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(rgb_mean, rgb_std)

        # define head module
        # 经过卷积,特征图数由3->n_feats
        m_head = [conv(3, n_feats, kernel_size)]

        # define body module
        # Residual Block设置
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size
            ) for _ in range(n_resblock)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        # 经过卷积,特征图数由n_feats->3
        m_tail = [
            conv(n_feats, 3, kernel_size)
        ]

        self.add_mean = common.MeanShift(rgb_mean, rgb_std, 1)

        # 设置网络的三个层次
        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

前向传播过程:

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)
        
        # 将输入input张量每个元素的范围限制到区间 [min,max],返回结果到一个新张量。
        # 及输出一个新张量值x,并限制他的值在0~1之间
        return torch.clamp(x,0.0,1.0)

2.2.数据集处理

import os
import io
import random
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class Dataset(object):
    def __init__(self, images_dir, patch_size=48, jpeg_quality=40, transforms=None):
        self.images = os.walk(images_dir).__next__()[2]
        self.images_path = []
        for img_file in self.images:
            if img_file.endswith((".ppm")):
                try:
                    #print(os.path.join(images_dir, img_file))
                    label = Image.open(os.path.join(images_dir, img_file))
                    self.images_path.append(os.path.join(images_dir, img_file))
                except:
                    print(f"Image {os.path.join(images_dir, img_file)} didn't get loaded")
        self.patch_size = patch_size
        self.jpeg_quality = jpeg_quality
        self.transforms = transforms
        self.random_rotate = [0, 90, 180, 270]

    def __getitem__(self, idx):
        label = Image.open(self.images_path[idx]).convert('RGB')
        label = label.rotate(self.random_rotate[random.randrange(0,4)])

        # randomly crop patch from training set
        crop_x = random.randint(0, label.width - self.patch_size)
        crop_y = random.randint(0, label.height - self.patch_size)
        # 使用crop函数对图片进行裁剪
        label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))


        # additive jpeg noise
        buffer = io.BytesIO()
        label.save(buffer, format='jpeg', quality=random.randrange(self.jpeg_quality+1))

        input = Image.open(buffer).convert('RGB')

        if self.transforms is not None:
            input = self.transforms(input)
            label = self.transforms(label)
        #print("Image transformed")
        return input, label

    def __len__(self):
        return len(self.images_path)

2.3.训练模块

import argparse
import os

from dataset import Dataset
from edar import EDAR

import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision import transforms
from torchvision.models.vgg import vgg16

from utils import AverageMeter
from tqdm import tqdm

if __name__ == '__main__':
    '''
    It enables benchmark mode in cudnn.
    benchmark mode is good whenever your input sizes for your network do not vary. 
    This way, cudnn will look for the optimal set of algorithms for that particular configuration (which takes some time). 
    This usually leads to faster runtime.
    But if your input sizes changes at each iteration, 
    then cudnn will benchmark every time a new size appears, 
    possibly leading to worse runtime performances.
    '''
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 参数设置
    parser = argparse.ArgumentParser()
    # required为true的参数则是必须要设置的参数
    parser.add_argument('--images_dir', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=40)
    parser.add_argument('--patch_size', type=int, default=48)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=400)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--threads', type=int, default=1)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    opt = parser.parse_args()

    # 如果输出文件夹不存在,则自动创建一个文件夹
    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    torch.manual_seed(opt.seed)

    transforms_train = transforms.Compose([transforms.ToTensor()])
    # 模型设置
    model = EDAR().to(device)
    print("Model loaded")

    if opt.resume:
        if os.path.isfile(opt.resume):
            state_dict = model.state_dict()
            for n, p in torch.load(opt.resume, map_location=lambda storage, loc: storage).items():
                if n in state_dict.keys():
                    state_dict[n].copy_(p)
                else:
                    raise KeyError(n)

    # 损失函数设置
    criterion = nn.L1Loss()
    # 优化器设置
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    print("Data processing started")
    # 数据集设置
    dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality,transforms=transforms_train)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.threads,
                            pin_memory=True,
                            drop_last=True)
    print("Data loading completed")
    #vgg = vgg16(pretrained=True).cuda()
    #loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
#     for param in loss_network.parameters():
#         param.requires_grad = False

    # 开始训练
    for epoch in range(opt.num_epochs):
        epoch_losses = AverageMeter()
        print("Length of the dataset is", len(dataset))
        with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))
            # 按照dataloader的格式取出data
            for data in dataloader:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                #print(inputs.size(), labels.size())

                outs = model(inputs)

                # 损失值计算,参数是预测值和实际值
                loss = criterion(outs, labels)
                #perception_loss = criterion(loss_network(outs), loss_network(labels))

                #loss = loss + perception_loss*0.06

                epoch_losses.update(loss.item(), len(inputs))

                # 梯度清零
                optimizer.zero_grad()

                # 反向传播
                loss.backward()
                # 更新参数
                optimizer.step()

                _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                _tqdm.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format("EDAR_", epoch)))

2.4.测试模块

import argparse
import os
import io
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import PIL.Image as pil_image
import glob

from edar import EDAR

cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if __name__ == '__main__':
    # 参数设置
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights_path', type=str, required=True)
    parser.add_argument('--image_path', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=40)
    parser.add_argument('--input_dir', type=str, required=False)
    opt, unknown = parser.parse_known_args()
    model = EDAR()

    state_dict = model.state_dict()
    # 参数获取
    for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model = model.to(device)
    print(device)
    model.eval()
    
    if opt.input_dir:
        filenames = [os.path.join(opt.input_dir, file) for file in os.listdir(opt.input_dir) if file.endswith(("ppm", "jpeg", "png", "jpg"))]
        print(filenames)
    else:
        filenames = opt.image_path
        
    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    # 处理单个测试图片时使用:
    filename = filenames
    print("file is", filename)
    input = pil_image.open(filename).convert('RGB')
    print("Input size:", input.size)

    print("file is", filename)
    input = pil_image.open(filename).convert('RGB')
    print("Input size:", input.size)

    #buffer = io.BytesIO()
    #input.save(buffer, format='jpeg', quality=opt.jpeg_quality)
    #input = pil_image.open(buffer)
    #input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))

    input = transforms.ToTensor()(input).unsqueeze(0).to(device)
    output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))
        
    if not os.path.exists(output_path):
            with torch.no_grad():
                pred = model(input)[-1]

            pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
            output = pil_image.fromarray(pred, mode='RGB')
            print("Output size", output.size)
            print("Output dir is", opt.outputs_dir)
            output.save(output_path)
            #print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))
            #print("Output saved")

    '''
    处理多个测试图片时使用:
    for filename in filenames:
        print("file is", filename)
        input = pil_image.open(filename).convert('RGB')
        print("Input size:", input.size)

        # buffer = io.BytesIO()
        # input.save(buffer, format='jpeg', quality=opt.jpeg_quality)
        # input = pil_image.open(buffer)
        # input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))

        input = transforms.ToTensor()(input).unsqueeze(0).to(device)
        output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))

        if not os.path.exists(output_path):
            with torch.no_grad():
                pred = model(input)[-1]

            pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
            output = pil_image.fromarray(pred, mode='RGB')
            print("Output size", output.size)
            print("Output dir is", opt.outputs_dir)
            output.save(output_path)
            # print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))
            # print("Output saved")
    '''

三.测试网络

参数设置:

输入图片:

输出图片:

输入图片:

输出图片:

标签:EDSR,opt,--,分辨率,input,path,self,dir,size
From: https://blog.csdn.net/GodFishhh/article/details/136606410

相关文章

  • 初出茅庐的小李博客之串口屏开发一个音乐控制器UI
    串口屏介绍串口屏通常指的是一种带有串口接口的显示屏,可以通过串口与其他设备进行通信和控制。这种屏幕通常具有独立的控制器和显示功能,可以直接接入主控系统,实现信息的显示和交互。开发步骤准备UI素材准备了100张音量的图标,这里面还遇到了个小问题,这么多图片如何批量......
  • 【2024年5月备考新增】《软考真题分章练习 - 5 项目进度管理(高项)》
    1、()isatechniqueforestimatingthedurationorcostofanactivityoraprojectusinghistoricaldatafromasimilaractivityorproject.A.AnalogousestimatingB.parametricestimatingC.Three-PointestimatingD.Bottomestimating2、下图中(单位:......
  • <网络安全>《68 微课堂<第9课 常见IT系统集成商简介>》
    1什么是集成商集成商是指那些专门提供系统集成服务的公司,他们通过整合不同的技术、产品和服务,为客户提供一个完整、高效的解决方案。常见的集成商主要包括以下几类:IT系统集成商:这类集成商专注于信息技术领域的集成服务,包括硬件、软件、网络、数据中心等方面的集成。他们......
  • 一个命令查看自己的WIFI密码
    一个命令查看自己的WIFI密码忘记wifi密码怎么办?一个命令查看自己的wifi密码。一、打开命令行使用快捷键“Win+R”,打开运行窗口,输入“cmd”后回车即可。二、输入命令networkshell命令输入命令networkshell,简称netsh;再输入wlan,也就是wifi;最后输入showprofile。......
  • Python疑难杂症(13)---Python的几个比较难理解的内置函数,包括range、zip、map、lambda
    1、range()range(start=0, stop[, step=1])构造器的参数必须为整数(可以是内置的 int 或任何实现了 __index__() 特殊方法的对象)。生成一个start到stop的数组,左闭右开, 类型表示不可变的数字序列,通常用于在 for 循环中循环指定的次数。list(range(6))[0,1,2,3......
  • 嵌入(embedding)概念
    摘要:     嵌入(embedding)在数学和相关领域中是指将一个数学对象在保持其某些关键性质不变的前提下,注入到一个更大或更高维的空间中。这个过程不仅仅是简单的映射,而是要求注入的对象在新空间中的表现形式能够完整反映原有对象的内在结构和性质。    嵌入(embeddi......
  • 指针数组、数组指针、函数指针、指针函数
    数组指针:是指向数组的指针,它还是一个指针,只不过指向数组而已行指针定义形式:int(*p)[10]一定要加(),因为[]优先级高于*,所以必须要(*p)指一行,这里10为列的元素个数例1:二维数组数值为1-12,用行指针定义输出8例2:用行指针传参,2*3数组,输出第二行指针数组:实际是一个数组,长度是......
  • 刷题统计
    题目小明决定从下周一开始努力刷题准备蓝桥杯竞赛。他计划周一至周五每天做a道题目,周六和周日每天做b道题目。请你帮小明计算,按照计划他将在第几天实现做题数大于等于n题?题目描述:小明决定从下周一开始努力刷题准备蓝桥杯竞赛。他计划周一至周五每天做a道题目,周六......
  • 顺子日期
    题目本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。小明特别喜欢顺子。顺子指的就是连续的三个数字:123、456等。顺子日期指的就是在日期的yyyymmdd表示法中,存在任意连续的三位数是一个顺子的日期。例如20220123就是一个顺子日期,因为它出现了......
  • pyCharm oj 习题 列表合并、去重、排序
    列表合并、去重、排序ProblemDescription从键盘输入两个数列,构成两个列表list1、list2,合并这两个列表为list3,将list3去掉重复元素、降序排序后生成list4.InputDescription输入两个数列,以英文逗号分隔OutputDescription输出列表list1、list2、list3、list4SampleInpu......