"""U-Net style encoder/decoder segmentation network with skip connections."""

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401  (kept: original file imported it)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s:%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)


class iUnet(nn.Module):
    """U-Net: four down-sampling encoder stages, a bottleneck, and four
    up-sampling decoder stages joined to the encoder by skip connections.

    Input height and width must be divisible by 16 (four ``MaxPool2d(2)``
    stages), otherwise the upsampled decoder maps will not align with the
    encoder skips at ``torch.cat`` time.

    Args:
        num_classes: number of output channels (per-pixel class scores).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE: attribute names are kept exactly as in the original so that
        # previously saved state_dicts keep loading.

        # Encoder stage 1: 3 -> 64 channels; pool halves H and W.
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.relu1_1 = nn.ReLU(inplace=True)  # was inplace=1; the parameter is a bool
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2)

        # Encoder stage 2: 64 -> 128 channels.
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(128)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2)

        # Encoder stage 3: 128 -> 256 channels.
        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2)

        # Encoder stage 4: 256 -> 512 channels.
        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck (stage 5): 512 -> 1024 channels, no pooling.
        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.bn5_1 = nn.BatchNorm2d(1024)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1)
        self.bn5_2 = nn.BatchNorm2d(1024)
        self.relu5_2 = nn.ReLU(inplace=True)

        # Decoder stage 1: upsample, 1024 -> 512, concat with encoder
        # stage-4 skip (back to 1024), then two 3x3 convs down to 512.
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)
        self._conv1_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)
        self._bn1_1 = nn.BatchNorm2d(512)
        self._relu1_1 = nn.ReLU(inplace=True)
        self._conv1_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self._bn1_2 = nn.BatchNorm2d(512)
        self._relu1_2 = nn.ReLU(inplace=True)

        # Decoder stage 2: 512 -> 256, concat with stage-3 skip.
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self._conv2_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self._bn2_1 = nn.BatchNorm2d(256)
        self._relu2_1 = nn.ReLU(inplace=True)
        self._conv2_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self._bn2_2 = nn.BatchNorm2d(256)
        self._relu2_2 = nn.ReLU(inplace=True)

        # Decoder stage 3: 256 -> 128, concat with stage-2 skip.
        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv3 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self._conv3_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self._bn3_1 = nn.BatchNorm2d(128)
        self._relu3_1 = nn.ReLU(inplace=True)
        self._conv3_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self._bn3_2 = nn.BatchNorm2d(128)
        self._relu3_2 = nn.ReLU(inplace=True)

        # Decoder stage 4: 128 -> 64, concat with stage-1 skip.
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self._conv4_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self._bn4_1 = nn.BatchNorm2d(64)
        self._relu4_1 = nn.ReLU(inplace=True)
        self._conv4_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self._bn4_2 = nn.BatchNorm2d(64)
        self._relu4_2 = nn.ReLU(inplace=True)

        # 1x1 projection to per-pixel class scores.
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, 3, H, W) with H and W divisible by 16.

        Returns:
            Class-score tensor of shape (N, num_classes, H, W).
        """
        # Encoder: keep the pre-pool activation of each stage as a skip.
        x = self.relu1_1(self.bn1_1(self.conv1_1(x)))
        x1 = self.relu1_2(self.bn1_2(self.conv1_2(x)))
        x = self.pool1(x1)

        x = self.relu2_1(self.bn2_1(self.conv2_1(x)))
        x2 = self.relu2_2(self.bn2_2(self.conv2_2(x)))
        x = self.pool2(x2)

        x = self.relu3_1(self.bn3_1(self.conv3_1(x)))
        x3 = self.relu3_2(self.bn3_2(self.conv3_2(x)))
        x = self.pool3(x3)

        x = self.relu4_1(self.bn4_1(self.conv4_1(x)))
        x4 = self.relu4_2(self.bn4_2(self.conv4_2(x)))
        x = self.pool4(x4)

        # Bottleneck.
        x = self.relu5_1(self.bn5_1(self.conv5_1(x)))
        x = self.relu5_2(self.bn5_2(self.conv5_2(x)))

        # Decoder: upsample, reduce channels, concat skip, two convs.
        x = self.upsample1(x)
        x = self._conv1(x)
        x = torch.cat([x, x4], dim=1)
        x = self._relu1_1(self._bn1_1(self._conv1_1(x)))
        x = self._relu1_2(self._bn1_2(self._conv1_2(x)))

        x = self.upsample2(x)
        x = self._conv2(x)
        x = torch.cat([x, x3], dim=1)
        x = self._relu2_1(self._bn2_1(self._conv2_1(x)))
        x = self._relu2_2(self._bn2_2(self._conv2_2(x)))

        x = self.upsample3(x)
        x = self._conv3(x)
        x = torch.cat([x, x2], dim=1)
        x = self._relu3_1(self._bn3_1(self._conv3_1(x)))
        x = self._relu3_2(self._bn3_2(self._conv3_2(x)))

        x = self.upsample4(x)
        x = self._conv4(x)
        x = torch.cat([x, x1], dim=1)
        x = self._relu4_1(self._bn4_1(self._conv4_1(x)))
        x = self._relu4_2(self._bn4_2(self._conv4_2(x)))

        return self.out(x)


if __name__ == '__main__':
    # Smoke test: a forward pass on a random batch (384 is divisible by 16).
    data = torch.randn(4, 3, 384, 384)
    net = iUnet()
    pred = net(data)