I. Motivation
II. Contribution
III. Network
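1. For a low-light image, first compute the SNR map using Equation 2.
(1) Ig' = blur(Ig)
Ig: the low-light input image
Ig': Ig after mean filtering with cv.blur
(2) Convert Ig and Ig' to grayscale and take the absolute difference to get the noise: N = |Ig - Ig'|
(3) SNR (mask): divide the mean-filtered image by the noise to get S = Ig' / N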
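Steps (1) to (3) can be sketched in NumPy as follows. This is a minimal sketch, not the author's code: `box_blur` is a hypothetical stand-in for cv.blur (it replicates edge pixels rather than using OpenCV's default border mode), grayscale is approximated by a plain channel average, and the kernel size `k` and the small `eps` that guards against division by zero are assumptions.

```python
import numpy as np


def box_blur(channel, k=5):
    """Mean filter with a k x k window on a 2-D array (edge-replicated borders)."""
    pad = k // 2
    padded = np.pad(channel, pad, mode="edge")
    h, w = channel.shape
    out = np.zeros((h, w), dtype=np.float64)
    # Sum the k*k shifted copies of the image, then divide by the window size.
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + h, dx:dx + w]
    return out / (k * k)


def snr_map(img, k=5, eps=1e-4):
    """img: float array of shape (H, W, 3) in [0, 1]. Returns S of shape (H, W)."""
    # (1) Ig' = blur(Ig), applied channel by channel
    blurred = np.stack([box_blur(img[..., c], k) for c in range(img.shape[2])], axis=2)
    # (2) grayscale of both images (assumed: channel average), then N = |Ig - Ig'|
    gray = img.mean(axis=2)
    gray_blur = blurred.mean(axis=2)
    noise = np.abs(gray - gray_blur)
    # (3) S = Ig' / N; eps (an assumption) avoids division by zero in flat regions
    return gray_blur / (noise + eps)
```

Flat regions have almost no estimated noise, so their SNR values are very large; textured or noisy regions get small values.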
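2. Next, extract shallow features, giving fea.
3. Extract deep features from fea. fea is fed into two branches: the short branch passes it through convolution blocks with residual connections (better at capturing local information), and the long branch passes it through an SNR-guided transformer (better at capturing global information).
(1) Short-branch structure
(2) Long-branch structure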
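The transformer consists of Attention and FeedForward.
In Attention, the input features are normalized and then assigned to q, k, and v.
q multiplied by the transpose of k measures the similarity between patches.
Mask handling: split mask into s' patches, then take the mean of mask along dim=2. Following the formula, if the mean of a patch is < 0.5, its value is set to 0, which simplifies the attention step that follows.
Wherever mask is 0, the corresponding entries of attn are filled with a very large negative number, so their weights after softmax are approximately 0 and those patches contribute almost nothing to the attention output.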
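The mask handling and the masked softmax described above can be sketched in NumPy. This is a minimal single-head sketch under stated assumptions, not the author's implementation: `patch_mask` and `masked_attention` are hypothetical names, the patches are non-overlapping squares, the scores are scaled by the square root of the feature dimension (standard scaled dot-product attention), and `-1e9` stands in for "a very large negative number".

```python
import numpy as np


def patch_mask(snr, patch, thresh=0.5):
    """Average an (H, W) map over non-overlapping patch x patch blocks.

    Returns one value per patch; patches whose mean is below thresh become 0.
    """
    h, w = snr.shape
    blocks = snr.reshape(h // patch, patch, w // patch, patch)
    means = blocks.mean(axis=(1, 3)).reshape(-1)
    means[means < thresh] = 0.0
    return means


def masked_attention(q, k, v, keep, neg=-1e9):
    """q, k, v: (num_patches, dim); keep: (num_patches,), 0 marks patches to ignore."""
    scores = q @ k.T / np.sqrt(q.shape[1])            # patch-to-patch similarity
    scores = np.where(keep[None, :] == 0, neg, scores)  # hide masked patches as keys
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ v, weights
```

With only one patch kept, every query attends to that patch alone, so each output row equals that patch's value vector.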
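FeedForward: normalize first, then apply a fully connected layer + ReLU, normalize again, and add one skip connection. The final output is fea_unfold, the feature after the transformer.
Then use mask to fuse the features: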
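The fusion formula itself is missing from these notes. One plausible reading, sketched below, is a per-pixel blend in which mask weights the short-branch features and (1 - mask) weights the long-branch features, so high-SNR pixels rely on local information and low-SNR pixels rely on global information. The function name `fuse`, the argument names, and the blend direction are assumptions, not confirmed by the source.

```python
import numpy as np


def fuse(short_feat, long_feat, mask):
    """Blend two feature maps with a per-pixel weight.

    short_feat, long_feat: (C, H, W) features from the short and long branches.
    mask: (H, W) SNR mask with values in [0, 1] (assumed to be normalized).
    """
    m = mask[None, :, :]  # broadcast the mask over the channel axis
    return short_feat * m + long_feat * (1.0 - m)
```

Where mask is 1 the result equals the short-branch feature, and where mask is 0 it equals the long-branch feature.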
IV. Loss
1. CharbonnierLoss2:
Reference post: https://blog.csdn.net/weixin_43135178/article/details/120865709
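```python
import torch
import torch.nn as nn


class CharbonnierLoss2(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-6):
        super(CharbonnierLoss2, self).__init__()
        self.eps = eps  # keeps the square root differentiable where diff == 0

    def forward(self, x, y):
        diff = x - y
        # Smooth approximation of |diff|, averaged over all elements.
        loss = torch.mean(torch.sqrt(diff * diff + self.eps))
        return loss
```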
2. Perceptual loss
Notes on the reduction argument of nn.L1Loss: https://blog.csdn.net/qq_39450134/article/details/121745209
```python
import torch
import torch.nn as nn
import torchvision


class VGG19(torch.nn.Module):
    """Frozen VGG19 feature extractor that returns five intermediate activations."""

    def __init__(self, requires_grad=False):
        super().__init__()
        # Note: newer torchvision releases prefer the `weights=` argument over `pretrained=True`.
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # Split the VGG19 feature stack into five consecutive stages.
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()  # requires a CUDA device
        # self.criterion = nn.L1Loss()
        self.criterion = nn.L1Loss(reduction='sum')
        self.criterion2 = nn.L1Loss()  # not used in forward
        # Deeper feature maps get larger weights.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # print(x_vgg[i].shape, y_vgg[i].shape)
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
```
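VGGLoss uses nn.L1Loss(reduction='sum') rather than the default 'mean'. The difference can be illustrated with a small NumPy sketch that mirrors the two reduction modes; `l1` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import numpy as np


def l1(x, y, reduction="mean"):
    """Absolute error reduced by 'mean' (divide by element count) or 'sum'."""
    diff = np.abs(x - y)
    if reduction == "sum":
        return diff.sum()
    return diff.mean()


# A feature map of shape (1, 64, 8, 8) whose every element is off by 1.
x = np.ones((1, 64, 8, 8))
y = np.zeros((1, 64, 8, 8))
mean_loss = l1(x, y, "mean")  # per-element error, independent of tensor size
sum_loss = l1(x, y, "sum")    # grows with the number of elements
```

With 'sum', the loss scales with the size of each feature map, so the large, shallow VGG feature maps would contribute much more than the deep ones before the per-layer weights are applied.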
From: https://www.cnblogs.com/yyhappy/p/17371090.html