#dataset.py
import logging
import os
#处理c语言类型和二进制的数据结构
import struct
#regular expression module
import re
import time
import numpy as np
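# Installing this module with pip failed at first; following the error hint and
# using the package name "scikit-image" fixed it.
# skimage docs: mainly image reading, display, and some image-processing algorithms.
# Reference: converting RGB to HSV.
import skimage.io as skio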
import torch as th
#不知道为什么作者要注释掉这个代码,后面也没用到它
#from scipy.stats import ortho_group
#hsv是一种格式,hue色调, saturation饱和度, value亮度
from skimage.color import rgb2hsv, hsv2rgb
from torch.utils.data import Dataset
from torchlib.image import read_pfm
from torchlib.utils import Timer
# Module-level logger for this dataset module.
log = logging.getLogger("demosaic_data")
class SRGB2Linear(object):
    """Callable that maps sRGB-encoded values to linear values.

    The mapping is applied element-wise to a NumPy array:

    * values ``<= thresh`` are divided by ``scale``;
    * larger values become ``((x + a) / (1 + a)) ** gamma``.
    """

    def __init__(self):
        super(SRGB2Linear, self).__init__()
        self.a = 0.055
        self.thresh = 0.04045
        self.scale = 12.92
        self.gamma = 2.4

    def __call__(self, im):
        """Return the linearized version of ``im``.

        Args:
            im: NumPy array (or array-like) of sRGB-encoded values.

        Returns:
            NumPy array with the same shape as ``im``.
        """
        # np.where evaluates both branches for every element, so the base of
        # the power is clamped to `thresh` to keep it positive even where the
        # linear branch is the one that ends up being selected.
        curved = np.power(
            (np.maximum(im, self.thresh) + self.a) / (1 + self.a), self.gamma)
        return np.where(im <= self.thresh, im / self.scale, curved)
# Inverse of SRGB2Linear: re-applies the gamma encoding.
class Linear2SRGB(object):
    """Callable that maps linear values to sRGB-encoded values.

    The mapping is applied element-wise to a torch tensor:

    * values ``<= thresh`` are multiplied by ``scale``;
    * larger values become ``(1 + a) * x ** (1 / gamma) - a``.
    """

    def __init__(self):
        super(Linear2SRGB, self).__init__()
        self.a = 0.055
        # Differs from the sRGB-side threshold (0.04045): this value is
        # approximately 0.04045 / 12.92, i.e. that threshold expressed in
        # linear units.
        self.thresh = 0.0031308049535603713
        self.scale = 12.92
        self.gamma = 2.4

    def __call__(self, im):
        """Return the sRGB-encoded version of ``im``.

        Args:
            im: torch tensor of linear values.

        Returns:
            torch tensor with the same shape as ``im``.
        """
        # th.where evaluates both branches; clamping from below at `thresh`
        # keeps the base of the fractional power positive for every element.
        curved = (1 + self.a) * th.pow(
            th.clamp(im, min=self.thresh), 1.0 / self.gamma) - self.a
        return th.where(im <= self.thresh, im * self.scale, curved)
def bayer_mosaic(im):
    """GRBG Bayer mosaic.

    Args:
        im: array indexed as ``[channel, y, x]`` where channels 0, 1, 2 are
            red, green and blue.

    Returns:
        Tuple ``(mosaic, mask)``: ``mask`` has the same shape as ``im`` and is
        1 where a channel is sampled and 0 elsewhere; ``mosaic`` is the
        element-wise product ``im * mask``. ``im`` itself is not modified.
    """
    mask = np.ones_like(im)

    # red: kept only on even rows / odd columns (top-right of each 2x2 cell)
    mask[0, ::2, 0::2] = 0
    mask[0, 1::2, :] = 0

    # green: kept on the top-left and bottom-right of each 2x2 cell
    mask[1, ::2, 1::2] = 0
    mask[1, 1::2, ::2] = 0

    # blue: kept only on odd rows / even columns (bottom-left of each 2x2 cell)
    mask[2, 0::2, :] = 0
    mask[2, 1::2, 1::2] = 0

    # Element-wise product allocates a new array, so no explicit copy of `im`
    # is needed to leave the input untouched.
    return im * mask, mask
def xtrans_mosaic(im):
    """XTrans Mosaick.

    The 6x6 tile that is repeated over the image is:

       G b G G r G
       r G r b G b
       G b G G r G
       G r G G b G
       b G b r G r
       G r G G b G

    Args:
        im: array indexed as ``[channel, y, x]`` where channels 0, 1, 2 are
            red, green and blue.

    Returns:
        Tuple ``(mosaic, mask)``: ``mask`` is a float32 array of the same
        shape as ``im`` with 1 where a channel is sampled and 0 elsewhere;
        ``mosaic`` is the element-wise product ``mask * im``.
    """
    # One string per tile row; the letter selects which channel is sampled.
    tile_rows = ("GBGGRG",
                 "RGRBGB",
                 "GBGGRG",
                 "GRGGBG",
                 "BGBRGR",
                 "GRGGBG")
    channel_of = {"R": 0, "G": 1, "B": 2}

    tile = np.zeros((3, 6, 6), dtype=np.float32)
    for y, row in enumerate(tile_rows):
        for x, letter in enumerate(row):
            tile[channel_of[letter], y, x] = 1

    im = np.asarray(im)
    _, h, w = im.shape

    # Repeat the tile enough times to cover the image (integer ceil division),
    # then trim the excess rows/columns.
    reps_y = -(-h // 6)
    reps_x = -(-w // 6)
    mask = np.tile(tile, [1, reps_y, reps_x])[:, :h, :w]

    return mask * im, mask
class DemosaicDataset(Dataset):
    """Base dataset producing mosaic / mask / target samples from image files.

    Subclasses choose the colour-filter layout by overriding ``make_mosaic``.

    Args:
        filelist: path to a ``.txt`` file listing one image path per line,
            relative to the directory containing the list file.
        add_noise: stored on the instance; not used by ``__getitem__`` yet.
        max_noise: stored on the instance; not used by ``__getitem__`` yet.
        transform: optional callable applied to the sample dict before it is
            returned.
        augment: enable random flips, rotation, pixel shift, hue/saturation
            jitter and exposure scaling.
        linearize: when true, images are passed through ``SRGB2Linear``.

    Raises:
        ValueError: if ``filelist`` does not have a ``.txt`` extension.
    """

    def __init__(self, filelist, add_noise=False, max_noise=0.1, transform=None,
                 augment=False, linearize=False):
        super(DemosaicDataset, self).__init__()
        self.transform = transform
        self.add_noise = add_noise
        self.max_noise = max_noise
        self.augment = augment

        # Only build the linearizer when requested; None means "leave as is".
        self.linearizer = SRGB2Linear() if linearize else None

        if os.path.splitext(filelist)[-1] != ".txt":
            raise ValueError("Dataset should be speficied as a .txt file")

        self.root = os.path.dirname(filelist)
        self.images = []
        with open(filelist) as fid:
            for line in fid:
                name = line.strip()
                # Skip blank lines (e.g. a trailing newline) so they do not
                # turn into a bogus entry pointing at the root directory.
                if not name:
                    continue
                self.images.append(os.path.join(self.root, name))
        self.count = len(self.images)

    def __len__(self):
        """Return the number of images listed in the file list."""
        return self.count

    def make_mosaic(self, im):
        """Return ``(mosaic, mask)`` for ``im``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "make_mosaic must be implemented by a subclass")

    def _augment(self, im):
        """Apply the random geometric and colour augmentations to ``im``.

        ``im`` is indexed as ``[y, x, channel]`` here. The result is clipped
        to ``[0, 1]``.
        """
        if np.random.uniform() < 0.5:
            im = np.fliplr(im)
        if np.random.uniform() < 0.5:
            im = np.flipud(im)
        im = np.rot90(im, k=np.random.randint(0, 4))

        # Pixel shift: roll by the drawn amount so that every phase of the
        # mosaic pattern is seen.
        if np.random.uniform() < 0.5:
            shift_y = np.random.randint(0, 6)  # cover both xtrans and bayer
            im = np.roll(im, shift_y, 0)
        if np.random.uniform() < 0.5:
            shift_x = np.random.randint(0, 6)
            im = np.roll(im, shift_x, 1)

        # Random Hue/Sat
        if np.random.uniform() < 0.5:
            shift = np.random.uniform(-0.1, 0.1)
            sat = np.random.uniform(0.8, 1.2)
            im = rgb2hsv(im)
            im[:, :, 0] = np.mod(im[:, :, 0] + shift, 1)
            im[:, :, 1] *= sat
            im = hsv2rgb(im)

        return np.clip(im, 0, 1)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its sample dict.

        Returns:
            dict with keys ``"mosaic"``, ``"mask"`` and ``"target"`` (or
            whatever ``transform`` returns when a transform is set).
        """
        impath = self.images[idx]

        # read image
        # NOTE(review): the code below indexes three axes and treats the last
        # one as colour, so it presumably expects 8-bit RGB files - confirm.
        im = skio.imread(impath).astype(np.float32) / 255.0

        if self.augment:
            im = self._augment(im)

        if self.linearizer is not None:
            im = self.linearizer(im)

        # Randomize exposure
        if self.augment:
            if np.random.uniform() < 0.5:
                im = im * np.random.uniform(0.5, 1.2)
            im = np.clip(im, 0, 1)

        im = np.ascontiguousarray(im).astype(np.float32)

        # Move the colour axis first.
        # NOTE(review): [2, 1, 0] also swaps the two spatial axes (the result
        # is [channel, x, y]); kept as is so existing consumers see the same
        # layout.
        im = np.transpose(im, [2, 1, 0])

        # crop boundaries to ignore shift
        c = 8
        im = im[:, c:-c, c:-c]

        # apply mosaic
        mosaic, mask = self.make_mosaic(im)

        # TODO: separate GT/noisy. Noise injection (add_noise / max_noise) is
        # not applied yet, so no "noise_variance" entry is produced.
        sample = {
            "mosaic": mosaic,
            "mask": mask,
            "target": im,
        }

        # Augment
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def __repr__(self):
        s = "Dataset\n"
        s += " . {} images\n".format(len(self.images))
        return s


class ToBatch(object):
    """Sample transform that prepends a batch axis of size 1 to every ndarray."""

    def __call__(self, sample):
        """Expand each NumPy array value of ``sample`` in place; return it."""
        for k in sample.keys():
            if isinstance(sample[k], np.ndarray):
                sample[k] = np.expand_dims(sample[k], 0)
        return sample
# Converts NumPy arrays to torch Tensors.
# (Reference: introductory guide to Python NumPy ndarray.)
class ToTensor(object):
    """Sample transform that turns every ndarray value into a torch tensor."""

    def __call__(self, sample):
        """Convert each NumPy array value of ``sample`` in place; return it.

        Non-array values are left untouched. ``th.from_numpy`` is used, so the
        resulting tensors share memory with the source arrays.
        """
        for k in sample.keys():
            if isinstance(sample[k], np.ndarray):
                sample[k] = th.from_numpy(sample[k])
        return sample
class GreenOnly(object):
    """Sample transform that zeroes channels 0 and 2 of ``sample["target"]``.

    The target is modified in place, leaving only index 1 along its first
    axis (the green channel in this file's channel order) non-zero.
    """

    def __call__(self, sample):
        sample["target"][0] = 0
        sample["target"][2] = 0
        return sample


class BayerDataset(DemosaicDataset):
    """Demosaicking dataset whose mosaic is produced by ``bayer_mosaic``."""

    def make_mosaic(self, im):
        return bayer_mosaic(im)


class XtransDataset(DemosaicDataset):
    """Demosaicking dataset whose mosaic is produced by ``xtrans_mosaic``."""

    def make_mosaic(self, im):
        return xtrans_mosaic(im)