EDSR源码笔记
1. common
1. default_conv
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)
仅仅是单纯的Conv,没有什么太特殊的,这边唯一需要注意的是
padding=(kernel_size//2)
这句话使得卷积前后的图像大小不变。
2. MeanShift
class MeanShift(nn.Conv2d):
def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
#Shift卷积块的设置
self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
#梯度取消
for p in self.parameters():
p.requires_grad = False
该模块还不能确定具体的工作流程,但是根据查到的资料显示,是MeanShift算法(看到网上大多是Shift聚类算法,但是似乎跟卷积相关的很少),具体连接如下:
解读模型压缩4:你一定从未见过如此通俗易懂的模型压缩Shift操作解读
这边简单的记录以下目前已知的东西
Shift操作:
Shift表意是将一个图像平移的过程,这边以图为例:
(其中滤波器有颜色的一格为1,白色的是0,可以看出,Shift操作的表现就是将这个通道上的图像向某个方向平移)
那么为什么要这么做呢?
- 这里先确定我们的目标——获取图片空间信息,这在以前我们通过普通CNN的卷积核移动实现
- 我们又知道,参考NiN中1*1的卷积,这种1*1的卷积核能够融合通道的信息。
基于以上两点,我们提出Shift操作来提取空间信息(先po张图):
由上面的图片可以看出,Shift在每个通道上将图像向空间上的不同方向进行平移(这些方向的问题会在文章再后面一点指出),虽然简单的平移操作似乎没有提取到空间信息,但是考虑到,通道域是空间域信息的层次化扩散。因此通过设置不同方向的shift卷积核,可以将输入张量不同通道进行平移,随后配合1x1卷积实现跨通道的信息融合,即可实现空间域和通道域的信息提取。
形象点理解就是卷积就是我们拿放大镜看东西,为了看完整个大图片,我们要么手自己动(普通CNN中卷积核平移),要么我们让图片动(Shift中每个通道移动)。
以上便是Shift操作的原理,那么接下来回答上面关于方向的问题,如何确定Shift的方向来更好的提取图片空间信息。
我先剧透一下,虽然比较扯蛋,但其实Shift操作就好在我们不需要考虑方向的问题。
以下是解释,先po张图:
对于一个图片的每一个Channel,有D2(D是图片的维数)个可以移动的方向.
我们假设存在Shift的最优解,但是显然暴力的搜索出这个Shift方案并不靠谱,于是我们想到近似:对于每一个方向,我们假设需要平移的channel数都是相同的(对应最左边的图,同一颜色代表这些Channel有相同的应该移动的方向),因此我们可以通过将原先的m个通道分组,分为m/D2个组,每个组有自己的一个移动方向,然后想象将方向相同的放到一块(对应中间图),随后再通过一个合理但未知的排序(对应最右边的图),就能获得在我们假设下的最优解。
于是问题就转化为如何获得最优排序顺序,针对这个问题,我们已经知道1*1卷积会将所有的Channel合并,因此我们可喜的发现,在Shift通道融合,也就是1*1卷积的那一步,我们完全不需要考虑通道之间的顺序,因为卷积就是将所有通道中空间上这个位置的信息全部融合,跟通道自身的顺序无关(有点加法交换律的味道我们需要的只是加起来的值,而不关心中间怎么加的)。
因此我们就了解了整个Shift的过程和原理,这里再提议看一下该部分开头连接中的深度可分离卷积,可能会对Channel和空间地位互换有更深的理解。
讲完了Shift,提议下这个类的内容,尽管结构比较简单,但是里面参数的设置还是比较难懂的,我有的疑问是:
- rgb_range,rgb_mean,rgb_std代表什么
- 为什么weight中的参数要这么设置
这边简要提议下目前已经了解的知识:
- torch.view()是视图,可以理解为reshape
- 因为是Shift中卷积核本身不需要改参数,只要完成Shift便能提取信息,因此需要循环将weight设为不需要梯度。
遗落的以后补充
3. ResBlock
class ResBlock(nn.Module):
def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if i == 0:
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x
return res
该类是个比较常见的残差模块,比较值得一提的就是torch.mul是对应位相乘,进行了网络输出的缩放,有这个操作也就要求了输入x经过conv后形状不能变(因此猜测是此处的conv传的是上面定义的Basic_Conv。)
4. Upsampler
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):#重复操作
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
#nn.ReLU(True)表示会修改输入的值
elif act == 'prelu': #具有可学习参数的ReLU
m.append(nn.PReLU(n_feats))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
#nn.ReLU(True)表示会修改输入的值
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
看名字就知道,上采样模块,由于是接触的第一个用卷积核实现的上采样,还是记录以下,这个模块的重点是nn.PixelShuffle这个类,po一个链接:
可以看到,只要我们指定upscale_factor(或者叫做c),我们就可以压缩通道但获得更大的二维图像,那么对于一个普通的图片,如果我们想要得到放大C倍的图像,我们的通道就会被压缩C^2倍,因此在这个该类前,我们还通过一层卷积来获取足够的通道数,代码中的4,9,便对应了放大2倍和3倍。
以上便是EDSRcommon文件的源码解读
2. edsr
从开篇的import可以看出,只跟common文件有关,大大减少工作量。
1. url?
url = {
'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}
这个没啥搞头,要配合demo.sh使用,中间的连接应该是预训练好的模型,也仅仅只是定义了一个字典,不必在意。
2. make_model EDSR
这一块就是网络的主题架构了,但其实只要前面的看懂了,这里只不过是组装而已,先po上论文中的架构。
在po上代码,需要注意的已经在代码中注释了:
class EDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)
url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
if url_name in url:
self.url = url[url_name]
else:
self.url = None
#下面的两个shift由于shift的模块没有看太懂,后面会了补充
self.sub_mean = common.MeanShift(args.rgb_range)
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
#以下开始都挺好懂的,就是按照论文的顺序
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)
]
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
3. load_state-dict
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
这个模块是加载训练状况的,具体……可以参考在option中的参数help,这里的代码纠错比较严谨,有非常多的Error,但是理解上可以理解为一个用来给网络加载参数的方法。
3. trainer
Trainer(是的,这个文件只有这一个类)
1. init
没什么好说的,就是一些参数的设置,其中需要注意的是optimizer的设置,在utility中,这在之后讲utility的时候会提,总之就是获得一个优化器。
2. prepare
该方法目的是将args中的东西加载到device上,其中_prepare函数中的tensor=tensor.half()应该是将数组转为浮点数类型。同时要注意,这里的args中可能有多个tensor,因此要在整个方法中的return中遍历进行转移。
3. terminate
该方法在main中被调用,大致上就是说,如果我args中如果说了test_only,那么将模型设置为test模式,但如果是训练,就检查epoch大小,关于这个大小的判定条件,要具体看optimizer中的源码,了解get_last_epoch是什么意思
更新:get_last_epoch是已经走了多少轮,然后这边检查的其实是有没有训练完,这里的epoch就是已经训练的轮数,然后args中的就是需要训练的轮数。
4. train
参考平常那些训练代码,应该是比较好理解训练部分,下面记录一些新的东西。:
- 除了与平时不同加上了一个计时器。
- 这里的epoch由于要考虑到从上一个checkpoint开始,因此需要从优化器读出已经走了多少个epoch(get_last_epoch()),在这个基础上+1。
- 引入ckp(check point),并且引入了相关的log(日志)
非常难过的是,我只能这么简单讲一下,每一个方法,最好的就是看懂他在干什么,由于传入参数的时候每次都是传不定参数,导致理解起来会非常困难。看懂每一段在干什么我觉得就可以了,下面记录一段完全看不懂不知道在干啥的代码:
if self.args.gclip > 0:
utils.clip_grad_value_(
self.model.parameters(),
self.args.gclip
)
上面这一段是梯度裁减,gclip参数是梯度裁减的门槛,po上option中的解释:
parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
另外,尽管下面的代码很长,但是真正的内容只有写log这一件事,其中涉及到了好多方法,都在utilities中的checkpoint中有,这边就不讲了。
if (batch + 1) % self.args.print_every == 0:
self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
(batch + 1) * self.args.batch_size,
len(self.loader_train.dataset),
self.loss.display_loss(batch),
timer_model.release(),
timer_data.release()))
尽管这一段代码很长,但其实旨在写log,了解一下log写进去了啥即可。
5. test
这里面的大多数内容都是log中内容的添加还有结果的存储,能力问题,主要log也不太理解用法,想先留个坑,以后再说:
def test(self):
torch.set_grad_enabled(False)
epoch = self.optimizer.get_last_epoch()
self.ckp.write_log('\nEvaluation:')
self.ckp.add_log(
torch.zeros(1, len(self.loader_test), len(self.scale))
)
self.model.eval()
timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
for idx_data, d in enumerate(self.loader_test):
for idx_scale, scale in enumerate(self.scale):
d.dataset.set_scale(idx_scale)
for lr, hr, filename in tqdm(d, ncols=80):
lr, hr = self.prepare(lr, hr)
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range)
save_list = [sr]
self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
sr, hr, scale, self.args.rgb_range, dataset=d
)
if self.args.save_gt:
save_list.extend([lr, hr])
if self.args.save_results:
self.ckp.save_results(d, filename[0], save_list, scale)
self.ckp.log[-1, idx_data, idx_scale] /= len(d)
best = self.ckp.log.max(0)
self.ckp.write_log(
'[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
d.dataset.name,
scale,
self.ckp.log[-1, idx_data, idx_scale],
best[0][idx_data, idx_scale],
best[1][idx_data, idx_scale] + 1
)
)
self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
self.ckp.write_log('Saving...')
4. utility
1. timer
class timer():
def __init__(self):
self.acc = 0
self.tic()
def tic(self):
self.t0 = time.time() #time.time()返回当前的时间戳
def toc(self, restart=False): #该函数用来获取经过了多长的时间
diff = time.time() - self.t0
if restart: self.t0 = time.time()
return diff
def hold(self): #该函数用来累加时间
self.acc += self.toc()
def release(self): #与reset差不多,拥有了能够返回累加值的功能
ret = self.acc
self.acc = 0
return ret
def reset(self): #计时器清零
self.acc = 0
顾名思义,是个计时器,里面封装的功能都挺浅显易懂的。可能帮助理解的都注释了。
2.checkpoint
比较长,分方法参数讲解:
1. init
def __init__(self, args):
self.args = args
self.ok = True
self.log = torch.Tensor()
now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if not args.load:#如果不是加载
if not args.save:#如果没有保存
args.save = now #设保存时间
self.dir = os.path.join('..', 'experiment', args.save) #设置保存的文件路径
else: #从上一次的结果加载
self.dir = os.path.join('..', 'experiment', args.load)
if os.path.exists(self.dir): #如果已经存在了
self.log = torch.load(self.get_path('psnr_log.pt'))
print('Continue from epoch {}...'.format(len(self.log))) #从上次结束的epoch开始
else:
args.load = ''
if args.reset: #重新训练
os.system('rm -rf ' + self.dir) #向这种语句等同于在控制行中运行 rm -rf self.dir,清除之前存放的文件夹
args.load = ''
#创建文件夹,第一句新建总的文件夹,第二句新建moudle文件夹,getpath可以理解为在总文件夹下找到model的路径。exist_ok表示如果存在相同的目录,不会触发异常。第三句是创建结果的文件夹.
os.makedirs(self.dir, exist_ok=True)
os.makedirs(self.get_path('model'), exist_ok=True)
for d in args.data_test:
os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
#a为追加,w为重新写,原有内容会删除
open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w'
self.log_file = open(self.get_path('log.txt'), open_type) #打开log.txt
with open(self.get_path('config.txt'), open_type) as f: #这边可以理解为文件句柄处理的一个固定的代码,详情可以查看with的用法,总之就是打开这个文件,然后写入下面的东西。
f.write(now + '\n\n')
for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
self.n_processes = 8 #应该是处理器的个数
更新:这里在init的部分比较多,单独拎出一段说一下:
open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w'
self.log_file = open(self.get_path('log.txt'), open_type)
with open(self.get_path('config.txt'), open_type) as f:
f.write(now + '\n\n')
for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
2. get_path
def get_path(self, *subdir):
return os.path.join(self.dir, *subdir)
顾名思义,取得一个文件的地址,他的参数是多参,os.path.join其实就是拼接括号中的内容,因此get_path参数直接写总文件夹后的分文件夹就可以了。
3. *save
def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
trainer.loss.save(self.dir)
trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch)
trainer.optimizer.save(self.dir)
torch.save(self.log, self.get_path('psnr_log.pt'))
就是保存模型现在的状态,这里面的trainer.loss由于trainer接受的是Mytrainer参数,因此猜测这个save是框架中的功能。同时在log中存储psnr-log.pt的地址,这类保存比较复杂,同时对于log这个参数并不是很理解,以后再来补。
4. *add_log,write_log,done
def add_log(self, log):
self.log = torch.cat([self.log, log])
def write_log(self, log, refresh=False):
print(log)
self.log_file.write(log + '\n')
if refresh:
self.log_file.close()
self.log_file = open(self.get_path('log.txt'), 'a')
def done(self):
self.log_file.close()
上面的两个操作都是针对log的操作,同上由于理解问题,就不一一细讲了。
5. plot_psnr
def plot_psnr(self, epoch):
axis = np.linspace(1, epoch, epoch)
for idx_data, d in enumerate(self.args.data_test):
label = 'SR on {}'.format(d)
fig = plt.figure()
plt.title(label)
for idx_scale, scale in enumerate(self.args.scale):
plt.plot(
axis,
self.log[:, idx_data, idx_scale].numpy(),
label='Scale {}'.format(scale)
)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('PSNR')
plt.grid(True)
plt.savefig(self.get_path('test_{}.pdf'.format(d)))
plt.close(fig)
关于打印从来就没懂过,总之我知道这个是打印psnr(可以理解为衡量图像超分的好坏)。
6. save_results
def save_results(self, dataset, filename, save_list, scale):
if self.args.save_results:
filename = self.get_path(
'results-{}'.format(dataset.dataset.name),
'{}_x{}_'.format(filename, scale)
)
postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix):
normalized = v[0].mul(255 / self.args.rgb_range)
tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
保存结果,看不懂捏。
3. quantize(量化)
def quantize(img, rgb_range):
pixel_range = 255 / rgb_range
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
4. calc_psnr
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
if hr.nelement() == 1: return 0
diff = (sr - hr) / rgb_range
if dataset and dataset.dataset.benchmark:
shave = scale
if diff.size(1) > 1:
gray_coeffs = [65.738, 129.057, 25.064]
convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
diff = diff.mul(convert).sum(dim=1)
else:
shave = scale + 6
valid = diff[..., shave:-shave, shave:-shave]
mse = valid.pow(2).mean()
return -10 * math.log10(mse)
应该就是计算sr跟hr的psnr。不知道与上面的plot_psnr有没有连接。
5. make_optimizer
这个就比较重要了,代码中有标
make optimizer and scheduler together
-
首先我们看optimizer模块:
# optimizer
trainable = filter(lambda x: x.requires_grad, target.parameters())
kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
if args.optimizer == 'SGD':
optimizer_class = optim.SGD
kwargs_optimizer['momentum'] = args.momentum
elif args.optimizer == 'ADAM':
optimizer_class = optim.Adam
kwargs_optimizer['betas'] = args.betas
kwargs_optimizer['eps'] = args.epsilon
elif args.optimizer == 'RMSprop':
optimizer_class = optim.RMSprop
kwargs_optimizer['eps'] = args.epsilon
这里可以理解为通过指定优化器,将不同的参数加入到kwargs_optimizer中
-
再看看Scheduler模块:
milestones = list(map(lambda x: int(x), args.decay.split('-')))
kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
scheduler_class = lrs.MultiStepLR
Scheduler模块是用来动态调整学习率的,下面详解所需要的前件:
-
milestones:是一个数组,里面装着需要调整学习率的epoch位置,例如[50,70]
-
gamma为倍数,假设lr=0.1,如果lr开始为0.01,则epoch为50时变为0.001,70为0.0001
-
last_epoch=-1,当last_epoch=-1,设定为初始lr。老哥说lastepoch表示跑了多少个epoch。是用来从checkpoint中恢复的
-
last_epoch表示已经走了多少个epoch,下一个milestone减去last_epoch就是需要的epoch数。
-
使用scheduler.get_lr(),会在milestone的时候乘以gamma的平方
-
新版的pytorch没有get_lr()函数了,应该用get_last_lr()代替get_lr(),而且 get_last_lr() 也没有 "乘gamma平方" 这个问题了。