文章目录
前言
图像恢复是计算机视觉领域中的一个重要研究方向,它旨在通过算法修复损坏、模糊或缺失的图像信息,从而恢复图像的原始质量。随着图像处理技术的不断发展,图像恢复在许多领域都发挥着重要作用,如医学影像、遥感图像、文化遗产保护等。
本次我将介绍一个在图像恢复的多个任务中表现都不错的网络Restormer,并介绍其环境配置与代码使用,帮助大家在实际项目中使用Restormer。
一、Restormer介绍
在计算机视觉领域,高分辨率图像恢复是一个重要的挑战。图像在采集、传输或处理过程中,往往因为各种原因受到模糊、噪声等干扰,导致图像质量下降。为了解决这个问题,研究者们提出了各种模型和技术。近年来,Transformer模型在自然语言处理和计算机视觉领域取得了巨大成功。然而,传统的Transformer模型在处理高分辨率图像时,由于其庞大的计算量和参数量,难以实现高效的处理。
为了解决这个问题,作者提出了一种高效的Transformer模型——Restormer。Restormer在构建块(多头注意力和前馈网络)中引入了几个关键设计,使其能够捕获长程像素交互,同时仍然适用于大型图像。这种设计使得Restormer在处理高分辨率图像时,能够更高效地恢复图像质量。
Restormer模型在多个图像恢复任务中取得了最先进的结果。这些任务包括图像去雨、单图像运动去模糊、散焦去模糊(单图像和双像素数据)和图像去噪(高斯灰度/彩色去噪和真实图像去噪)。这些结果证明了Restormer模型在图像恢复任务中的有效性。
模型结构:
二、环境安装与配置
1.下载项目
2.安装虚拟环境:
conda create -n pytorch181 python=3.7
conda activate pytorch181
3.安装依赖
conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch
pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm
pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips
4.下载预训练模型:
三、代码使用与效果
运行下面代码,demo.py:修改def get_weights_and_parameters(task, parameters):函数中的
elif task == 'Single_Image_Defocus_Deblurring':
weights = r"F:\Restormer-main\Defocus_Deblurring\pretrained_models\single_image_defocus_deblurring.pth"
将权重路径修改为自己刚才下载的预训练模型路径,运行即可
## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881
##--------------------------------------------------------------
##------- Demo file to test Restormer on your own images---------
## Example usage on directory containing several images: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/' --result_dir './demo/restored/'
## Example usage on an image directly: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/'
## Example usage with tile option on a large image: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/' --tile 720 --tile_overlap 32
##--------------------------------------------------------------
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import os
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse
from pdb import set_trace as stx
import numpy as np
# Command-line interface: input location, output location, restoration task,
# and optional tiling settings for large images.
_task_names = [
    'Motion_Deblurring',
    'Single_Image_Defocus_Deblurring',
    'Deraining',
    'Real_Denoising',
    'Gaussian_Gray_Denoising',
    'Gaussian_Color_Denoising',
]

parser = argparse.ArgumentParser(description='Test Restormer on your own images')
parser.add_argument(
    '--input_dir',
    type=str,
    default='./demo/degraded/',
    help='Directory of input images or path of single image',
)
parser.add_argument(
    '--result_dir',
    type=str,
    default='./demo/restored/',
    help='Directory for restored results',
)
parser.add_argument(
    '--task',
    type=str,
    default="Single_Image_Defocus_Deblurring",
    choices=_task_names,
    help='Task to run',
)
parser.add_argument(
    '--tile',
    type=int,
    default=None,
    help='Tile size (e.g 720). None means testing on the original resolution image',
)
parser.add_argument(
    '--tile_overlap',
    type=int,
    default=32,
    help='Overlapping of different tiles',
)
# Parsed once at import time; the rest of the script reads from `args`.
args = parser.parse_args()
def load_img(filepath):
    """Read a colour image from disk and return it in RGB channel order.

    Args:
        filepath: Path of the image file to read.

    Returns:
        The array produced by ``cv2.imread`` after BGR -> RGB conversion.

    Raises:
        OSError: If OpenCV could not read the file.
    """
    bgr = cv2.imread(filepath)
    # cv2.imread signals failure by returning None rather than raising, which
    # would otherwise surface as an opaque cvtColor error.
    if bgr is None:
        raise OSError(f'Could not read image: {filepath}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def save_img(filepath, img):
    """Write an RGB image to disk.

    Args:
        filepath: Destination path; passed unchanged to ``cv2.imwrite``.
        img: Image in RGB channel order; converted to BGR before writing.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite returns False on failure instead of raising; do not let a
    # failed write pass silently.
    if not cv2.imwrite(filepath, bgr):
        raise OSError(f'Could not write image: {filepath}')
def load_gray_img(filepath):
    """Read an image as grayscale and return it with a trailing channel axis.

    Args:
        filepath: Path of the image file to read.

    Returns:
        The grayscale array from ``cv2.imread`` with a new axis inserted at
        position 2.

    Raises:
        OSError: If OpenCV could not read the file.
    """
    gray = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising.
    if gray is None:
        raise OSError(f'Could not read image: {filepath}')
    return np.expand_dims(gray, axis=2)
def save_gray_img(filepath, img):
    """Write a grayscale image to disk.

    Args:
        filepath: Destination path; passed unchanged to ``cv2.imwrite``.
        img: Image array, written as-is (no colour conversion).

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(filepath, img):
        raise OSError(f'Could not write image: {filepath}')
def get_weights_and_parameters(task, parameters):
    """Select the checkpoint path and adjust model parameters for a task.

    Args:
        task: Name of the restoration task (one of the ``--task`` choices).
        parameters: Dict of model constructor arguments. It is updated in
            place with the task-specific overrides and also returned.

    Returns:
        A ``(weights, parameters)`` tuple: the checkpoint path for the task
        and the (same, possibly modified) parameters dict.

    Raises:
        ValueError: If ``task`` is not a known task name.
    """
    if task == 'Motion_Deblurring':
        weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
    elif task == 'Single_Image_Defocus_Deblurring':
        # NOTE: machine-specific absolute path; change it to wherever the
        # downloaded pretrained model lives on your system.
        weights = r"F:\Restormer-main\Defocus_Deblurring\pretrained_models\single_image_defocus_deblurring.pth"
    elif task == 'Deraining':
        weights = os.path.join('Deraining', 'pretrained_models', 'deraining.pth')
    elif task == 'Real_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'real_denoising.pth')
        parameters['LayerNorm_type'] = 'BiasFree'
    elif task == 'Gaussian_Color_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'gaussian_color_denoising_blind.pth')
        parameters['LayerNorm_type'] = 'BiasFree'
    elif task == 'Gaussian_Gray_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'gaussian_gray_denoising_blind.pth')
        # Grayscale model: single input and output channel.
        parameters['inp_channels'] = 1
        parameters['out_channels'] = 1
        parameters['LayerNorm_type'] = 'BiasFree'
    else:
        # Without this branch an unknown task fell through to the return and
        # failed with UnboundLocalError on `weights`.
        raise ValueError(f'Unknown task: {task}')
    return weights, parameters
# Resolve the task, the input location and the per-task output directory.
task = args.task
inp_dir = args.input_dir
out_dir = os.path.join(args.result_dir, task)
os.makedirs(out_dir, exist_ok=True)

extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']

if inp_dir.endswith(tuple(extensions)):
    # --input_dir ends with an image extension: treat it as a single file.
    files = [inp_dir]
else:
    # Otherwise treat it as a directory and collect every supported image.
    files = []
    for ext in extensions:
        files.extend(glob(os.path.join(inp_dir, '*.' + ext)))
    # On Windows glob matching is case-insensitive, so '*.jpg' and '*.JPG'
    # return the same paths; drop duplicates (keeping first-seen order) so no
    # image is processed twice.
    files = natsorted(dict.fromkeys(files))

if not files:
    raise Exception(f'No files found at {inp_dir}')
# Get model weights and parameters
# Baseline constructor arguments; get_weights_and_parameters overrides some
# of them depending on the task.
parameters = {'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}
weights, parameters = get_weights_and_parameters(task, parameters)

# Execute the architecture file and pull the Restormer class out of its
# resulting globals (path is relative to the current working directory).
load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# map_location makes the checkpoint load onto the selected device; without it
# a checkpoint saved from CUDA tensors cannot be loaded on a CPU-only machine.
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint['params'])
model.eval()

# Inputs are padded so height and width are multiples of this value.
img_multiple_of = 8

print(f"\n ==> Running {task} with weights {weights}\n ")
# Restore every input image and write the result as PNG into out_dir.
with torch.no_grad():
    for file_ in tqdm(files):
        # Release cached GPU memory before handling the next image.
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        is_gray = task == 'Gaussian_Gray_Denoising'
        img = load_gray_img(file_) if is_gray else load_img(file_)

        # (H, W, C) array -> (1, C, H, W) float tensor scaled by 1/255.
        input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).to(device)

        # Reflect-pad bottom/right so both sides are multiples of
        # img_multiple_of; no padding when already a multiple.
        height, width = input_.shape[2], input_.shape[3]
        padh = (-height) % img_multiple_of
        padw = (-width) % img_multiple_of
        input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

        if args.tile is None:
            ## Testing on the original resolution image
            restored = model(input_)
        else:
            # Test the image tile by tile and average overlapping regions.
            b, c, h, w = input_.shape
            tile = min(args.tile, h, w)
            # Explicit raise instead of assert: asserts are stripped under -O.
            if tile % 8 != 0:
                raise ValueError("tile size should be multiple of 8")
            tile_overlap = args.tile_overlap

            stride = tile - tile_overlap
            if stride <= 0:
                if h > tile or w > tile:
                    # A non-positive stride leaves pixels uncovered, which
                    # previously produced NaN output via a 0/0 division.
                    raise ValueError(
                        f"tile_overlap ({tile_overlap}) must be smaller than the tile size ({tile})"
                    )
                # One tile covers the whole image, so the stride is unused;
                # range() just needs a positive step.
                stride = tile

            # Regular grid of tile origins plus a final tile flush with the
            # bottom/right edge.
            h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
            w_idx_list = list(range(0, w-tile, stride)) + [w-tile]

            accum = torch.zeros(b, c, h, w).type_as(input_)   # summed tile outputs
            weight = torch.zeros_like(accum)                  # per-pixel tile count

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = input_[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
                    out_patch = model(in_patch)
                    out_patch_mask = torch.ones_like(out_patch)

                    accum[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch)
                    weight[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch_mask)
            restored = accum.div_(weight)

        restored = torch.clamp(restored, 0, 1)

        # Unpad the output
        restored = restored[:,:,:height,:width]

        # (1, C, H, W) tensor -> (H, W, C) uint8 array.
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])

        # Output keeps the input's base name but is always written as .png.
        stem = os.path.splitext(os.path.split(file_)[-1])[0]
        out_path = os.path.join(out_dir, stem+'.png')
        if is_gray:
            save_gray_img(out_path, restored)
        else:
            save_img(out_path, restored)

    print(f"\nRestored images are saved at {out_dir}")
使用效果:
原图:
修复图: