论文:Haar wavelet downsampling: A simple but effective downsampling module
GitHub地址 :https://github.com/apple1986/HWD
论文地址: https://www.sciencedirect.com/science/article/pii/S0031320323005174
这篇论文利用频域的小波变化来进行降采样,图像经过小波变化会得到1个低频特征A,和3个高频特征H、V、D。常规的下采样操作普遍面临的信息损失问题,应用Haar小波变换来降低特征图的空间分辨率,可以更完整的保留图像的信息。
```import torchimport torch.nn as nnfrom pytorch_wavelets
import DWTForward
class Down_wt(nn.Module):
def __init__(self, in_ch, out_ch):
super(Down_wt, self).__init__()
self.wt = DWTForward(J=1, mode='zero', wave='haar')
self.conv_bn_relu = nn.Sequential(
nn.Conv2d(in_ch * 4, out_ch, kernel_size=1, stride=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)
def forward(self, x):
yL, yH = self.wt(x)
y_HL = yH[0][:, :, 0, ::]
y_LH = yH[0][:, :, 1, ::]
y_HH = yH[0][:, :, 2, ::]
x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
x = self.conv_bn_relu(x)
return x
if __name__ == '__main__':
block = Down_wt(64, 64) # 输入通道数,输出通道数
input = torch.rand(3, 64, 64, 64)
# 输入tensor形状B C H W
output = block(input)
print(output.size())
```
代码非常简单,只有几行,我们非常简单的将其应用到我们的模型中,用于替换常规的下采样模块,像是卷积下采样、最大池化下采样、平均池化下采样等。
标签:采样,__,ch,nn,self,小波,wt,即插即用 From: https://blog.csdn.net/Angelina_Jolie/article/details/143380357