# 判断某个文件是否是图像 # enswith判断是否以指定的.png,.jpg,.jpeg结尾的字符串 # 可以根据情况扩充图像类型,加入.bmp、.tif等 def is_image_file(filename): return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) # 读取图像转为YCbCr模式,得到Y通道 def load_img(filepath): img = Image.open(filepath).convert('YCbCr') y, _, _ = img.split() return y # 裁剪大小,宽高一致为300 # 如果想训练自己的数据集,请根据情况修改裁剪大小 CROP_SIZE = 300 # 封装数据集,适配后面的torch.utils.data.DataLoader中的dataset,定义成类似形式 # 类参数为图像文件夹路径和放大倍数 # __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数) #__getitem__(self) 定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。 #__iter__(self) 定义当迭代容器中的元素的行为 # 返回输入图像和标签,传入DataLoader的dataset参数 class DatasetFromFolder(Dataset): def __init__(self, image_dir, zoom_factor): super(DatasetFromFolder, self).__init__() self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] # 图像路径列表 crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor) # 处理放大倍数,防止用户瞎设置,本例只能设置为2,3,4,大小不变 # 数据集变换 # 还有一些其他的变换操作,如归一化等,遇到一个积累一个 self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # 从图片中心裁剪成300*300 transforms.Resize( crop_size // zoom_factor), # Resize, 输入应该是缩放倍数后的图像,因为先缩小后放大 transforms.Resize( crop_size, interpolation=Image.BICUBIC), # 双三次插值 transforms.ToTensor()]) # 图像转成tensor # label标签,超分不是分类问题,定义成一样的就行 self.target_transform = transforms.Compose( [transforms.CenterCrop(crop_size), transforms.ToTensor()]) def __getitem__(self, index): input = load_img(self.image_filenames[index]) # 输入是图像的Y通道,即亮度通道 target = input.copy() input = self.input_transform(input) target = self.target_transform(target) return input, target def __len__(self): return len(self.image_filenames) # 图像个数
标签:__,self,SRCNN,transforms,image,图像,input,数据,预处理 From: https://www.cnblogs.com/anyview/p/18659155