五、序号5,使用identityConv进行残差连接,最后对增强后的幅值、增强后的相位、空域进行Concat
class YYBlock(nn.Module):
    """Spatial + frequency-domain enhancement block.

    A spatial branch (two 3x3 convs with LeakyReLU) is added to a 1x1
    identity (residual) projection of the input.  The sum is taken to the
    frequency domain with rfft2; its amplitude and phase are each enhanced
    by a shared 1x1-conv stack, reconstructed back to the spatial domain,
    and finally fused with the spatial result via a 1x1 conv over the
    concatenation of the three tensors.

    Args:
        in_channel:  number of input channels.
        out_channel: number of output channels.
        relu_slope:  negative slope for the LeakyReLU activations.
    """

    def __init__(self, in_channel=3, out_channel=20, relu_slope=0.2):
        super(YYBlock, self).__init__()
        # Spatial-domain branch: two 3x3 convs with LeakyReLU.
        self.spatialConv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
        )
        # 1x1 projection so the residual add works when in_channel != out_channel.
        self.identity = nn.Conv2d(in_channel, out_channel, 1, 1, 0)
        # Shared 1x1-conv stack, applied separately to FFT amplitude and phase.
        self.fftConv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
        )
        # Fuses [spatial result, phase-enhanced recon, amplitude-enhanced recon].
        self.fusion = nn.Conv2d(out_channel * 3, out_channel, 1, 1, 0)

    def forward(self, x1):
        """Return the fused (N, out_channel, H, W) enhancement of ``x1``.

        Args:
            x1: input tensor of shape (N, in_channel, H, W).
        """
        spatial_out = self.spatialConv(x1)
        identity = self.identity(x1)
        out = spatial_out + identity

        x_fft = torch.fft.rfft2(out, norm='backward')
        x_amp = torch.abs(x_fft)
        x_phase = torch.angle(x_fft)
        enhanced_phase = self.fftConv2(x_phase)
        enhanced_amp = self.fftConv2(x_amp)

        # BUG FIX: pass s=out.shape[-2:]; without it irfft2 assumes an even
        # last dimension, so odd-width inputs were reconstructed at the wrong
        # spatial size and the cat below failed.
        x_fft_out1 = torch.fft.irfft2(
            x_amp * torch.exp(1j * enhanced_phase),
            s=out.shape[-2:], norm='backward')
        x_fft_out2 = torch.fft.irfft2(
            enhanced_amp * torch.exp(1j * x_phase),
            s=out.shape[-2:], norm='backward')

        # BUG FIX: the scraped source ended in the garbled statement
        # 'return outYYBlock'; the intended return value is the fused tensor.
        out = self.fusion(torch.cat([out, x_fft_out1, x_fft_out2], dim=1))
        return out
标签:fft,torch,SFNet,nn,self,FFTBlock,模块,channel,out From: https://www.cnblogs.com/yyhappy/p/17799542.html