import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict

import numpy as np
import torch
import torchvision.utils as tvutils
from torchvision import utils as vutils
from IPython import embed
import lpips

import options as option
from models import create_model

# sys.path.insert(0, "../../")
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr
from utils.metrics import *

#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, help="Path to options YAML file.", default='options/test/refusion.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)

#### mkdir and logger
util.mkdirs(
    (
        path
        for key, path in opt["path"].items()
        if key != "experiments_root"
        and "pretrain_model" not in key
        and "resume" not in key
    )
)

util.setup_logger(
    "base",
    opt["path"]["log"],
    "test_" + opt["name"],
    level=logging.INFO,
    screen=True,
    tofile=True,
)
logger = logging.getLogger("base")
logger.info(option.dict2str(opt))

#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt["datasets"].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    logger.info("Number of test images in [{:s}]: {:d}".format(dataset_opt["name"], len(test_set)))
    test_loaders.append(test_loader)

# load pretrained model by default
model = create_model(opt)
device = model.device

# build the IR-SDE and attach the trained denoising network
sde = util.IRSDE(
    max_sigma=opt["sde"]["max_sigma"],
    T=opt["sde"]["T"],
    schedule=opt["sde"]["schedule"],
    eps=opt["sde"]["eps"],
    device=device,
)
sde.set_model(model.model)

lpips_fn = lpips.LPIPS(net='alex').to(device)

scale = opt['degradation']['scale']

for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt["name"]
    logger.info("\nTesting [{:s}]...".format(test_set_name))
    test_start_time = time.time()
    dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
    util.mkdir(dataset_dir)

    ssims_1, psnrs_1 = [], []

    for i, test_data in enumerate(test_loader):
        need_GT = test_loader.dataset.opt["dataroot_GT"] is not None
        img_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        #### input dataset_LQ
        LQ, GT = test_data["LQ"], test_data["GT"]
        noisy_state = sde.noise_state(LQ)

        model.feed_data(noisy_state, LQ, GT)
        pred_img = model.test(sde, save_states=True)

        visuals = model.get_current_visuals()
        SR_img = visuals["Output"]
        GT = visuals["GT"]

        if opt['save_img']:
            save_dir = opt["savepath"]
            save_img_path = os.path.join(save_dir, img_name + ".png")
            print(save_img_path)
            vutils.save_image(SR_img.float(), save_img_path, normalize=True)

        per_ssim_1 = ssim(SR_img, GT).item()
        per_psnr_1 = psnr(SR_img, GT)
        ssims_1.append(per_ssim_1)
        psnrs_1.append(per_psnr_1)

        print(f'\n{img_name} iter processing:{i + 1} psnr:{per_psnr_1:.4f} ssim:{per_ssim_1:.4f}', end='', flush=True)

    avg_ssim_1 = np.mean(ssims_1)
    avg_psnr_1 = np.mean(psnrs_1)
    print(f'\navg_psnr:{avg_psnr_1:.4f} avg_ssim:{avg_ssim_1:.4f}', end='', flush=True)
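The `psnr` and `ssim` functions come in through the star import from `utils.metrics`, so their definitions are not shown here. As a rough stand-in for what the call sites above expect (`psnr` returning a float, `ssim` returning a tensor that `.item()` is applied to), a minimal sketch could look like the following, assuming NCHW tensors with values in [0, 1]; this is not the repository's actual implementation, and a Gaussian-window SSIM is the more common reference formulation:

import math
import torch
import torch.nn.functional as F

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> float:
    # Peak signal-to-noise ratio over the whole batch; inputs assumed in [0, max_val].
    mse = F.mse_loss(pred, target).item()
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / mse)

def ssim(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0, window: int = 11) -> torch.Tensor:
    # Simplified single-scale SSIM using a uniform averaging window
    # (the reference formulation uses a Gaussian window).
    C1 = (0.01 * max_val) ** 2
    C2 = (0.03 * max_val) ** 2
    pad = window // 2
    mu_x = F.avg_pool2d(pred, window, stride=1, padding=pad)
    mu_y = F.avg_pool2d(target, window, stride=1, padding=pad)
    var_x = F.avg_pool2d(pred * pred, window, stride=1, padding=pad) - mu_x ** 2
    var_y = F.avg_pool2d(target * target, window, stride=1, padding=pad) - mu_y ** 2
    cov_xy = F.avg_pool2d(pred * target, window, stride=1, padding=pad) - mu_x * mu_y
    ssim_map = ((2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)) / (
        (mu_x ** 2 + mu_y ** 2 + C1) * (var_x + var_y + C2)
    )
    return ssim_map.mean()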
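Note that `lpips_fn` is instantiated but never called in the loop above. If a perceptual score is wanted alongside PSNR/SSIM, a small helper along these lines could be added; `lpips_distance` is a hypothetical name, and it assumes both images are [0, 1] tensors moved to the same device as the LPIPS network, which expects inputs in [-1, 1]:

import torch
import lpips

def lpips_distance(lpips_fn: lpips.LPIPS, sr: torch.Tensor, gt: torch.Tensor) -> float:
    # Remap [0, 1] images to [-1, 1] before the LPIPS forward pass and
    # average the per-image distances into a single float.
    with torch.no_grad():
        return lpips_fn(sr * 2 - 1, gt * 2 - 1).mean().item()

Inside the per-image loop this would be called roughly as `lpips_distance(lpips_fn, SR_img.to(device), GT.to(device))` and accumulated the same way as the PSNR/SSIM lists.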