首页 > 其他分享 >图像重建Restormer介绍与使用

图像重建Restormer介绍与使用

时间:2024-11-06 14:46:39浏览次数:3  
标签:task idx img restored Restormer 图像 tile dir 重建

文章目录


前言

图像恢复是计算机视觉领域中的一个重要研究方向,它旨在通过算法修复损坏、模糊或缺失的图像信息,从而恢复图像的原始质量。随着图像处理技术的不断发展,图像恢复在许多领域都发挥着重要作用,如医学影像、遥感图像、文化遗产保护等。
本次我将介绍一款在图像恢复的多个任务中表现都不错的一个网络Restormer,并介绍其环境配置与代码使用,帮助大家在实际项目中使用Restormer。


一、Restormer介绍

在计算机视觉领域,高分辨率图像恢复是一个重要的挑战。图像在采集、传输或处理过程中,往往因为各种原因受到模糊、噪声等干扰,导致图像质量下降。为了解决这个问题,研究者们提出了各种模型和技术。近年来,Transformer模型在自然语言处理和计算机视觉领域取得了巨大成功。然而,传统的Transformer模型在处理高分辨率图像时,由于其庞大的计算量和参数量,难以实现高效的处理。
为了解决这个问题,作者提出了一种高效的Transformer模型——Restormer。Restormer模型通过在构建块中进行了几个关键的设计,包括多头注意和前馈网络,使其能够捕获长程像素交互,同时仍然适用于大型图像。这种设计使得Restormer在处理高分辨率图像时,能够更高效地恢复图像质量。
Restormer模型在多个图像恢复任务中取得了最先进的结果。这些任务包括图像去模糊、单图像运动去模糊(单图像和双像素数据)和图像去噪(高斯灰度/颜色去噪和真实图像去噪)。这些结果证明了Restormer模型在图像恢复任务中的有效性。

模型结构:在这里插入图片描述
在这里插入图片描述

二、环境安装与配置

1.下载项目

2.安装虚拟环境:

conda create -n pytorch181 python=3.7
conda activate pytorch181

3.安装依赖

conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch
pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm
pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips

4.下载预训练模型

三、代码使用与效果

运行下面代码,demo.py:修改def get_weights_and_parameters(task, parameters):函数中的
elif task == ‘Single_Image_Defocus_Deblurring’:
weights = r"F:\Restormer-main\Defocus_Deblurring\pretrained_models\single_image_defocus_deblurring.pth"
将权重路径修改为自己刚才下载的预训练模型路径,运行即可

## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881

##--------------------------------------------------------------
##------- Demo file to test Restormer on your own images---------
## Example usage on directory containing several images:   python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/' --result_dir './demo/restored/'
## Example usage on a image directly: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/'
## Example usage with tile option on a large image: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/' --tile 720 --tile_overlap 32
##--------------------------------------------------------------

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import os
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse
from pdb import set_trace as stx
import numpy as np

parser = argparse.ArgumentParser(description='Test Restormer on your own images')
parser.add_argument('--input_dir', default='./demo/degraded/', type=str, help='Directory of input images or path of single image')
parser.add_argument('--result_dir', default='./demo/restored/', type=str, help='Directory for restored results')
parser.add_argument('--task', default="Single_Image_Defocus_Deblurring", type=str, help='Task to run', choices=['Motion_Deblurring',
                                                                                    'Single_Image_Defocus_Deblurring',
                                                                                    'Deraining',
                                                                                    'Real_Denoising',
                                                                                    'Gaussian_Gray_Denoising',
                                                                                    'Gaussian_Color_Denoising'])
parser.add_argument('--tile', type=int, default=None, help='Tile size (e.g 720). None means testing on the original resolution image')
parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')

args = parser.parse_args()

def load_img(filepath):
    return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)

def save_img(filepath, img):
    cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

def load_gray_img(filepath):
    return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)

def save_gray_img(filepath, img):
    cv2.imwrite(filepath, img)

def get_weights_and_parameters(task, parameters):
    if task == 'Motion_Deblurring':
        weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
    elif task == 'Single_Image_Defocus_Deblurring':
        weights = r"F:\Restormer-main\Defocus_Deblurring\pretrained_models\single_image_defocus_deblurring.pth"
    elif task == 'Deraining':
        weights = os.path.join('Deraining', 'pretrained_models', 'deraining.pth')
    elif task == 'Real_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'real_denoising.pth')
        parameters['LayerNorm_type'] =  'BiasFree'
    elif task == 'Gaussian_Color_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'gaussian_color_denoising_blind.pth')
        parameters['LayerNorm_type'] =  'BiasFree'
    elif task == 'Gaussian_Gray_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'gaussian_gray_denoising_blind.pth')
        parameters['inp_channels'] =  1
        parameters['out_channels'] =  1
        parameters['LayerNorm_type'] =  'BiasFree'
    return weights, parameters

task    = args.task
inp_dir = args.input_dir
out_dir = os.path.join(args.result_dir, task)

os.makedirs(out_dir, exist_ok=True)

extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']

if any([inp_dir.endswith(ext) for ext in extensions]):
    files = [inp_dir]
else:
    files = []
    for ext in extensions:
        files.extend(glob(os.path.join(inp_dir, '*.'+ext)))
    files = natsorted(files)

if len(files) == 0:
    raise Exception(f'No files found at {inp_dir}')

# Get model weights and parameters
parameters = {'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}
weights, parameters = get_weights_and_parameters(task, parameters)

load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

checkpoint = torch.load(weights)
model.load_state_dict(checkpoint['params'])
model.eval()

img_multiple_of = 8

print(f"\n ==> Running {task} with weights {weights}\n ")

with torch.no_grad():
    for file_ in tqdm(files):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        if task == 'Gaussian_Gray_Denoising':
            img = load_gray_img(file_)
        else:
            img = load_img(file_)

        input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).to(device)

        # Pad the input if not_multiple_of 8
        height,width = input_.shape[2], input_.shape[3]
        H,W = ((height+img_multiple_of)//img_multiple_of)*img_multiple_of, ((width+img_multiple_of)//img_multiple_of)*img_multiple_of
        padh = H-height if height%img_multiple_of!=0 else 0
        padw = W-width if width%img_multiple_of!=0 else 0
        input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

        if args.tile is None:
            ## Testing on the original resolution image
            restored = model(input_)
        else:
            # test the image tile by tile
            b, c, h, w = input_.shape
            tile = min(args.tile, h, w)
            assert tile % 8 == 0, "tile size should be multiple of 8"
            tile_overlap = args.tile_overlap

            stride = tile - tile_overlap
            h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
            w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
            E = torch.zeros(b, c, h, w).type_as(input_)
            W = torch.zeros_like(E)

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = input_[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
                    out_patch = model(in_patch)
                    out_patch_mask = torch.ones_like(out_patch)

                    E[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch)
                    W[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch_mask)
            restored = E.div_(W)

        restored = torch.clamp(restored, 0, 1)

        # Unpad the output
        restored = restored[:,:,:height,:width]

        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])

        f = os.path.splitext(os.path.split(file_)[-1])[0]
        # stx()
        if task == 'Gaussian_Gray_Denoising':
            save_gray_img((os.path.join(out_dir, f+'.png')), restored)
        else:
            save_img((os.path.join(out_dir, f+'.png')), restored)

    print(f"\nRestored images are saved at {out_dir}")

使用效果:
原图:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
修复图:
请添加图片描述
在这里插入图片描述

在这里插入图片描述

标签:task,idx,img,restored,Restormer,图像,tile,dir,重建
From: https://blog.csdn.net/HanWenKing/article/details/143558048

相关文章

  • cv2.threshold利用OSTU方法分割图像的前景和背景
    OSTU方法,又称大津法或最大类间方差法,是一种在图像处理中广泛应用的自动阈值选择方法。该方法由日本学者大津(NobuyukiOtsu)于1979年提出,旨在通过最大化前景与背景之间的类间方差来自动确定一个最佳阈值,从而将图像分割成前景和背景两部分。OSTU方法的核心思想是寻找一个阈值T,使......
  • GS-Blur数据集:首个基于3D场景合成的156,209对多样化真实感模糊图像数据集。
    2024-10-31,由韩国首尔国立大学的研究团队创建的GS-Blur数据集,通过3D场景重建和相机视角移动合成了多样化的真实感模糊图像,为图像去模糊领域提供了一个大规模、高覆盖度的新工具,显著提升了去模糊算法在真实世界场景中的泛化能力。数据集地址:GS-Blur|图像去模糊数据集|图像处理......
  • Python socket传输图像文件
    客户端发送图像文件importsocketdata=numpy.frombuffer(stringData,numpy.uint8)#将获取到的字符流数据转换成1维数组#decimg=cv2.imdecode(data,cv2.COLOR_BGR2GRAY)#将数组解码成图像#cv2.imwrite("./test.jpg",decimg)#imencode()将图片格式转换(编码)成流数据,......
  • opencv保姆级讲解——图像预处理(3)
    图像滤波所为图像滤波通过滤波器得到另一个图像什么是滤波器在深度学习中,滤波器又称为卷积核,滤波的过程成为卷积卷积核概念卷积核大小,一般为奇数,如3*35*57*7为什么卷积核大小是奇数?原因是:保证锚点在中间,防止位置发生偏移的原因卷积核大小的影响在深度学习中,卷积......
  • 《DNK210使用指南 -CanMV版 V1.0》第三十五章 image图像特征检测实验
    第三十五章image图像特征检测实验1)实验平台:正点原子DNK210开发板2)章节摘自【正点原子】DNK210使用指南-CanMV版V1.03)购买链接:https://detail.tmall.com/item.htm?&id=7828013987504)全套实验源码+手册+视频下载地址:http://www.openedv.com/docs/boards/k210/ATK-DNK210.htm......
  • 云渲染与汽车CGI图像技术优势和劣势
    在数字时代,云渲染技术以其独特的优势在汽车CGI图像制作中占据了重要地位。云渲染通过利用云计算的分布式处理能力,将渲染任务分配给云端的服务器集群进行计算,从而实现高效、高质量的渲染效果。这种技术的优势主要体现在以下几个方面:高性能和可扩展性:云渲染能够利用云计算平台的......
  • 推荐一款图像批量转换软件:ReaConverter Pro
    ReaConverterPro是一款图像批量转换软件,旨在帮助您轻松转换图像文件以及执行其他操作(例如调整大小)的应用程序。一个优雅而强大的实用程序,可帮助您批量编辑图像,然后将它们转换为多种其他格式,例如PNG、JPG、TIF或BMP。您可以使用基于资源管理器的布局来定位和访问图片。支......
  • 基于Open-CV的多四边形检测方案(一):图像预处理与霍夫变换
    目录一、设计目标二、工作流程三、图像预处理与霍夫变换一、设计目标对于一个含有多个相邻四边形的图片,可以定位出其中每一个四边形的顶点。典型的案例如一个围棋棋盘,可以定位出所有的格子的点。软件工具:Python==3.9opencv-python==4.10.0.84numpy==1.22.4二......
  • 揭秘OpenAI推出革命性sCM模型,0.1秒内出图?50倍速AI图像生成
    sCM是什么?sCM(连续时间一致性模型,Simplifying,Stabilizing,andScalingContinuous-timeConsistencyModels)是OpenAI推出的一种新型生成模型。它基于扩散模型的原理进行改进,通过简化理论框架和优化采样过程,实现了图像生成速度和质量的显著提升。与传统的扩散模型相比,sCM在生......
  • 基于yolov8的生猪检测和统计系统,支持图像、视频和摄像实时检测【pytorch框架、python
     更多目标检测和图像分类识别项目可看我主页其他文章功能演示:基于yolov8的生猪检测和统计系统,支持图像、视频和摄像实时检测【pytorch框架、python源码】_哔哩哔哩_bilibili(一)简介基于yolov8的生猪检测和统计系统是在PyTorch框架之下得以实现的。这是一个完备的项目,涵盖......