1. When running the SFNet code twice with nothing changed in between, the two runs produce different results. Commenting out the block below makes consecutive runs match, so I suspect (but am not sure) that the cause is nn.BatchNorm2d computing batch means and variances.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# SFconv is defined elsewhere in the SFNet repository.


class dynamic_filter(nn.Module):
    def __init__(self, inchannels, mode, kernel_size=3, stride=1, group=8):
        super(dynamic_filter, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.group = group

        # Learnable per-channel scaling factors (initialized to zero).
        self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
        self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True)

        # 1x1 conv predicts one kernel_size x kernel_size low-pass filter per group.
        self.conv = nn.Conv2d(inchannels, group * kernel_size ** 2,
                              kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(group * kernel_size ** 2)
        self.act = nn.Softmax(dim=-2)
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.ap = nn.AdaptiveAvgPool2d((1, 1))
        self.modulate = SFconv(inchannels, mode)

    def forward(self, x):
        identity_input = x

        # Predict the dynamic low-pass filter from globally pooled features.
        low_filter = self.ap(x)
        low_filter = self.conv(low_filter)
        low_filter = self.bn(low_filter)

        # Unfold the input into kernel_size**2 shifted copies per group.
        n, c, h, w = x.shape
        x = F.unfold(self.pad(x), kernel_size=self.kernel_size).reshape(
            n, self.group, c // self.group, self.kernel_size ** 2, h * w)

        # Normalize the predicted filter with a softmax over the kernel dimension.
        n, c1, p, q = low_filter.shape
        low_filter = low_filter.reshape(
            n, c1 // self.kernel_size ** 2, self.kernel_size ** 2, p * q).unsqueeze(2)
        low_filter = self.act(low_filter)

        # Low-frequency part: weighted sum over each kernel window;
        # high-frequency part: the residual from the identity input.
        low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)
        out_high = identity_input - low_part

        out = self.modulate(low_part, out_high)
        return out
```
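To test the BatchNorm2d suspicion directly: in train() mode the layer normalizes with per-batch statistics and also updates its running_mean/running_var buffers on every forward pass, while in eval() mode it uses the frozen running statistics and is fully deterministic. A minimal standalone sketch (not SFNet code) illustrating the difference:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
x = torch.randn(2, 4, 8, 8)

bn.train()
_ = bn(x)                   # train mode: uses batch stats AND updates running stats
print(bn.running_mean)      # buffers have moved away from their initial zeros

bn.eval()
y1 = bn(x)                  # eval mode: uses the frozen running statistics
y2 = bn(x)
print(torch.equal(y1, y2))  # True -- eval-mode BatchNorm is deterministic
```

So if dynamic_filter is left in train() mode while comparing outputs, the BatchNorm buffers change between forward passes, and any later computation that consumes the running statistics will see different values.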
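A second plausible source of run-to-run variation is random weight initialization: the nn.init.kaiming_normal_ call inside dynamic_filter draws from the global RNG, so two launches of the same script build different weights unless the seeds are fixed. A reproducibility sketch using standard PyTorch calls (not specific to SFNet):

```python
import random
import numpy as np
import torch

def seed_everything(seed: int = 0) -> None:
    """Pin every RNG that can affect weight initialization and CUDA kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(0)  # call once, before constructing the model
```

If two runs still diverge after seeding, switching the model to eval() before the comparison isolates the BatchNorm running-statistic updates as the remaining suspect.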