import os
import glob
import h5py
import random
from PIL import Image
from matplotlib import pyplot as plt
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF
from Utils.metrics import *
from Utils.option import *


class train_Dataset(data.Dataset):
    """Paired training dataset read from ``<path>/low`` and ``<path>/high``.

    Every file in ``low`` is a degraded ("haze") input. Its ground-truth
    partner is looked up in ``high`` under the part of the input file name
    that precedes the first underscore (the whole name when there is none).

    Args:
        path: Root directory containing the ``low`` and ``high`` folders.
        mode: Stored on the instance as ``self.mode``; not read by this class.
        size: Side length of the random square crop when it is an ``int``.
            When it is a ``str`` no crop is taken. Defaults to
            ``opt.crop_size``.
        **kwargs: Optional ``restored_mask``; it is converted with ``.long()``
            and indexed per sample (presumably a 1-D torch tensor with one
            entry per image -- TODO confirm against the caller).
    """

    def __init__(self, path, mode='train', size=opt.crop_size, **kwargs):
        super(train_Dataset, self).__init__()
        self.size = size
        self.mode = mode
        # NOTE(review): `format` is whatever that name resolves to at module
        # scope (the builtin unless one of the star imports overrides it).
        # Kept as-is so the attribute stays available -- TODO confirm intent.
        self.format = format

        self.haze_imgs_dir = os.path.join(path, 'low')  # haze_img dir path
        self.haze_imgs_list = os.listdir(self.haze_imgs_dir)  # haze_img name list
        self.haze_imgs = [
            os.path.join(self.haze_imgs_dir, img) for img in self.haze_imgs_list
        ]  # haze_img path list
        self.clear_dir = os.path.join(path, 'high')  # clean_img dir path

        self.mask = False
        if 'restored_mask' in kwargs:
            self.restored_mask = kwargs['restored_mask'].long()
            self.mask = True

        self.length = len(self.haze_imgs_list)

    def __getitem__(self, index):
        """Return ``(haze, clear, index, mask_flag)`` for one training sample.

        ``haze`` and ``clear`` are the outputs of ``ToTensor`` after the same
        crop, flip and rotation were applied to both images. ``index`` is the
        index of the image actually used, which differs from the requested
        one when the requested image was too small for the crop.
        ``mask_flag`` is ``restored_mask[index]`` when a mask was supplied,
        otherwise ``-1``.
        """
        haze = Image.open(self.haze_imgs[index])
        if isinstance(self.size, int):
            # Images smaller than the crop window cannot be cropped, so draw
            # another sample at random until one is large enough.
            while haze.size[0] < self.size or haze.size[1] < self.size:
                haze.close()
                # randint is inclusive on both ends; the upper bound must be
                # length - 1 or the lookup below can raise IndexError.
                index = random.randint(0, self.length - 1)
                haze = Image.open(self.haze_imgs[index])

        # basename handles whichever separator os.path.join produced.
        haze_name = os.path.basename(self.haze_imgs[index])
        clear_name = haze_name.split('_')[0]
        clear = Image.open(os.path.join(self.clear_dir, clear_name))
        # PIL reports (width, height); CenterCrop expects (height, width).
        clear = tfs.CenterCrop(haze.size[::-1])(clear)

        if not isinstance(self.size, str):
            # One set of crop parameters keeps the pair spatially aligned.
            i, j, h, w = tfs.RandomCrop.get_params(
                haze, output_size=(self.size, self.size)
            )
            haze = FF.crop(haze, i, j, h, w)
            clear = FF.crop(clear, i, j, h, w)

        # Draw the augmentation once and apply it to both images.
        rand_hor = random.randint(0, 1)
        rand_rot = random.randint(0, 3)
        haze = self.augData_haze(haze.convert("RGB"), rand_hor, rand_rot)
        clear = self.augData_clear(clear.convert("RGB"), rand_hor, rand_rot)

        mask_flag = self.restored_mask[index] if self.mask else -1
        return haze, clear, index, mask_flag

    def augData_haze(self, haze, rand_hor, rand_rot):
        """Flip/rotate the degraded image as requested and convert to a tensor.

        ``rand_hor`` is passed as the flip probability (the caller supplies
        0 or 1); a non-zero ``rand_rot`` rotates by ``90 * rand_rot`` degrees.
        """
        haze = tfs.RandomHorizontalFlip(rand_hor)(haze)
        if rand_rot:
            haze = FF.rotate(haze, 90 * rand_rot)
        haze = tfs.ToTensor()(haze)
        # Optional normalisation, left disabled as in the original:
        # haze=tfs.Normalize(mean=[0.64, 0.6, 0.58],std=[0.14,0.15, 0.152])(haze)
        return haze

    def augData_clear(self, clear, rand_hor, rand_rot):
        """Apply the same flip/rotation to the clean image and convert to a tensor."""
        clear = tfs.RandomHorizontalFlip(rand_hor)(clear)
        if rand_rot:
            clear = FF.rotate(clear, 90 * rand_rot)
        clear = tfs.ToTensor()(clear)
        return clear

    def __len__(self):
        """Return the number of files found in the ``low`` directory."""
        return self.length


class test_Dataset(data.Dataset):
    """Paired evaluation dataset read from ``<path>/low`` and ``<path>/high``.

    Uses the same file-name pairing rule as the training dataset but applies
    no random crop, flip or rotation.

    Args:
        path: Root directory containing the ``low`` and ``high`` folders.
        mode: Stored on the instance as ``self.mode``; not read by this class.
    """

    def __init__(self, path, mode='test'):
        super(test_Dataset, self).__init__()
        self.mode = mode
        # NOTE(review): see the same assignment in train_Dataset; `format` is
        # resolved from module scope -- TODO confirm intent.
        self.format = format

        self.haze_imgs_dir = os.path.join(path, 'low')  # haze_img dir path
        self.haze_imgs_list = os.listdir(self.haze_imgs_dir)  # haze_img name list
        self.haze_imgs = [
            os.path.join(self.haze_imgs_dir, img) for img in self.haze_imgs_list
        ]  # haze_img path list
        self.clear_dir = os.path.join(path, 'high')  # clean_img dir path

        self.length = len(self.haze_imgs_list)

    def __getitem__(self, index):
        """Return ``(haze, clear, haze_name)`` for the image at ``index``.

        ``haze`` and ``clear`` are ``ToTensor`` outputs; ``haze_name`` is the
        file name of the degraded image without its directory.
        """
        haze = Image.open(self.haze_imgs[index])
        # basename handles whichever separator os.path.join produced.
        haze_name = os.path.basename(self.haze_imgs[index])
        clear_name = haze_name.split('_')[0]
        clear = Image.open(os.path.join(self.clear_dir, clear_name))
        # PIL reports (width, height); CenterCrop expects (height, width).
        clear = tfs.CenterCrop(haze.size[::-1])(clear)

        haze = self.augData_haze(haze.convert("RGB"))
        clear = self.augData_clear(clear.convert("RGB"))
        return haze, clear, haze_name

    def augData_haze(self, haze):
        """Convert the degraded image to a tensor."""
        haze = tfs.ToTensor()(haze)
        # Optional normalisation, left disabled as in the original:
        # haze=tfs.Normalize(mean=[0.64, 0.6, 0.58],std=[0.14,0.15, 0.152])(haze)
        return haze

    def augData_clear(self, clear):
        """Convert the clean image to a tensor."""
        clear = tfs.ToTensor()(clear)
        return clear

    def __len__(self):
        """Return the number of files found in the ``low`` directory."""
        return self.length
# Tags: 11, self, haze, path, imgs, import, size
# From: https://www.cnblogs.com/yyhappy/p/17558596.html