import torch
import torch.nn as nn


class BasicConv(nn.Module):
    """Convolution (or transposed convolution) with optional BatchNorm and GELU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Kernel size passed to the convolution.
        stride: Stride passed to the convolution.
        bias: Whether the convolution has a bias term. Forced to ``False``
            when ``norm`` is also requested.
        norm: Append ``nn.BatchNorm2d(out_channel)`` after the convolution.
        relu: Append the activation. Despite the flag's name, the activation
            that is added is ``nn.GELU``.
        transpose: Use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride,
                 bias=True, norm=False, relu=True, transpose=False):
        super().__init__()
        # The conv bias is dropped whenever a BatchNorm layer follows it.
        if bias and norm:
            bias = False

        layers = []
        if transpose:
            # Transposed convs use one less pixel of padding than regular ones.
            padding = kernel_size // 2 - 1
            layers.append(
                nn.ConvTranspose2d(in_channel, out_channel, kernel_size,
                                   padding=padding, stride=stride, bias=bias))
        else:
            padding = kernel_size // 2
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size,
                          padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.GELU())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv -> (norm) -> (activation) stack to ``x``."""
        return self.main(x)


class Network(nn.Module):
    """Residual spatial branch followed by a Fourier amplitude/phase branch.

    The input is lifted to ``out_channel`` features, refined by two 3x3
    convolutions with a residual connection, then transformed with
    ``rfft2``. Phase and amplitude are each passed through ``fftConv2``
    (the same module, so weights are shared), recombined with the untouched
    counterpart, inverse-transformed, fused with the spatial features by a
    1x1 convolution, and projected back to ``in_channel`` channels.

    Args:
        in_channel: Channels of the input (and of the output).
        out_channel: Width of the internal feature maps.
        relu_slope: Negative slope of the ``LeakyReLU`` activations.
    """

    def __init__(self, in_channel=3, out_channel=20, relu_slope=0.2):
        super().__init__()
        # NOTE: no trailing comma after the layer -- a trailing comma would
        # store a 1-tuple, which is neither callable nor registered as a
        # submodule.
        self.preConv = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                                 padding=1, bias=True)
        self.spatialConv = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
        )
        self.fftConv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
        )
        # forward() concatenates three out_channel-wide tensors (spatial
        # features + two inverse-FFT results), so fusion takes 3x channels.
        self.fusion = nn.Conv2d(out_channel * 3, out_channel, 1, 1, 0)
        # padding=1 keeps the output the same spatial size as the input.
        self.proConv = nn.Conv2d(out_channel, in_channel, kernel_size=3,
                                 stride=1, padding=1, bias=True)

    def forward(self, x1):
        """Run the network on ``x1`` and return a tensor of the same shape.

        ``x1`` is fed to a ``Conv2d`` with ``in_channel`` input channels.
        """
        x = self.preConv(x1)
        out = self.spatialConv(x) + x

        # Remember the spatial size: irfft2 cannot recover an odd width on
        # its own because rfft2 only keeps W // 2 + 1 frequency columns.
        spatial_size = tuple(out.shape[-2:])

        x_fft = torch.fft.rfft2(out, norm='backward')
        x_amp = torch.abs(x_fft)
        x_phase = torch.angle(x_fft)

        # The same fftConv2 module (shared weights) processes both components.
        enhanced_phase = self.fftConv2(x_phase)
        enhanced_amp = self.fftConv2(x_amp)

        # Original amplitude with enhanced phase, and vice versa.
        x_fft_out1 = torch.fft.irfft2(x_amp * torch.exp(1j * enhanced_phase),
                                      s=spatial_size, norm='backward')
        x_fft_out2 = torch.fft.irfft2(enhanced_amp * torch.exp(1j * x_phase),
                                      s=spatial_size, norm='backward')

        out = self.fusion(torch.cat([out, x_fft_out1, x_fft_out2], dim=1))
        out = self.proConv(out)
        return out
# Tags: 11, nn, self, torch, bias, channel, out
# From: https://www.cnblogs.com/yyhappy/p/17797234.html