import torch


class myResNet(torch.nn.Module):
    """ResNet-18 image classifier written out layer by layer.

    Structure:

    * A stem: a 7x7 stride-2 conv, BN, ReLU, then a 3x3 stride-2 max-pool.
    * Eight residual blocks, each made of two 3x3 conv + BN layers. They come
      in pairs at 64, 128, 256 and 512 channels.
    * Global average pooling and a fully connected classifier.

    That gives 1 + 8 * 2 + 1 = 18 weighted layers. The first block of the
    128-, 256- and 512-channel pairs halves the spatial size. It also sends
    its shortcut through a 1x1 stride-2 convolution (``conv3_0``, ``conv5_0``,
    ``conv7_0``) so the shapes match for the residual addition.

    The shape comments in ``forward`` assume a ``(bs, in_channels, 224, 224)``
    input. Other input sizes also work, because the final average pool spans
    the whole remaining feature map.

    NOTE: every convolution keeps its bias even though the BatchNorm that
    follows makes it redundant. This keeps the original parameter set, so
    ``state_dict`` stays compatible.

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits.
        shortcut_bn: if True, apply BatchNorm after each 1x1 projection
            shortcut (``bn3_0``, ``bn5_0``, ``bn7_0``). This matches the
            reference ResNet design (e.g. torchvision's ``downsample``). The
            default is False, which keeps the original architecture: those
            attributes are then parameter-free ``torch.nn.Identity`` modules,
            so outputs and ``state_dict`` keys are unchanged.
    """

    def __init__(self, in_channels=3, num_classes=10, *, shortcut_bn=False):
        super().__init__()

        def shortcut_norm(channels):
            # Use BN on the projection shortcut only when requested; otherwise
            # use a stateless pass-through, so the default model is identical.
            if shortcut_bn:
                return torch.nn.BatchNorm2d(channels)
            return torch.nn.Identity()

        # Layer 1 (stem)
        self.conv0_1 = torch.nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn0_1 = torch.nn.BatchNorm2d(64)
        self.relu0_1 = torch.nn.ReLU()
        self.dmp = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Layers 2-3 (block 1, 64 channels, identity shortcut)
        self.conv1_1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1_1 = torch.nn.BatchNorm2d(64)
        self.relu1_1 = torch.nn.ReLU()
        self.conv1_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1_2 = torch.nn.BatchNorm2d(64)
        self.relu1_2 = torch.nn.ReLU()

        # Layers 4-5 (block 2, 64 channels, identity shortcut)
        self.conv2_1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2_1 = torch.nn.BatchNorm2d(64)
        self.relu2_1 = torch.nn.ReLU()
        self.conv2_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2_2 = torch.nn.BatchNorm2d(64)
        self.relu2_2 = torch.nn.ReLU()

        # Layers 6-7 (block 3, 64 -> 128 channels, stride 2, projection shortcut)
        self.conv3_0 = torch.nn.Conv2d(64, 128, kernel_size=1, stride=2)
        self.bn3_0 = shortcut_norm(128)
        self.conv3_1 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3_1 = torch.nn.BatchNorm2d(128)
        self.relu3_1 = torch.nn.ReLU()
        self.conv3_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn3_2 = torch.nn.BatchNorm2d(128)
        self.relu3_2 = torch.nn.ReLU()

        # Layers 8-9 (block 4, 128 channels, identity shortcut)
        self.conv4_1 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4_1 = torch.nn.BatchNorm2d(128)
        self.relu4_1 = torch.nn.ReLU()
        self.conv4_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4_2 = torch.nn.BatchNorm2d(128)
        self.relu4_2 = torch.nn.ReLU()

        # Layers 10-11 (block 5, 128 -> 256 channels, stride 2, projection shortcut)
        self.conv5_0 = torch.nn.Conv2d(128, 256, kernel_size=1, stride=2)
        self.bn5_0 = shortcut_norm(256)
        self.conv5_1 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.bn5_1 = torch.nn.BatchNorm2d(256)
        self.relu5_1 = torch.nn.ReLU()
        self.conv5_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn5_2 = torch.nn.BatchNorm2d(256)
        self.relu5_2 = torch.nn.ReLU()

        # Layers 12-13 (block 6, 256 channels, identity shortcut)
        self.conv6_1 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6_1 = torch.nn.BatchNorm2d(256)
        self.relu6_1 = torch.nn.ReLU()
        self.conv6_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6_2 = torch.nn.BatchNorm2d(256)
        self.relu6_2 = torch.nn.ReLU()

        # Layers 14-15 (block 7, 256 -> 512 channels, stride 2, projection shortcut)
        self.conv7_0 = torch.nn.Conv2d(256, 512, kernel_size=1, stride=2)
        self.bn7_0 = shortcut_norm(512)
        self.conv7_1 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.bn7_1 = torch.nn.BatchNorm2d(512)
        self.relu7_1 = torch.nn.ReLU()
        self.conv7_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn7_2 = torch.nn.BatchNorm2d(512)
        self.relu7_2 = torch.nn.ReLU()

        # Layers 16-17 (block 8, 512 channels, identity shortcut)
        self.conv8_1 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn8_1 = torch.nn.BatchNorm2d(512)
        self.relu8_1 = torch.nn.ReLU()
        self.conv8_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn8_2 = torch.nn.BatchNorm2d(512)
        self.relu8_2 = torch.nn.ReLU()

        # Layer 18 (classifier on the pooled 512-d feature)
        self.fc = torch.nn.Linear(512, num_classes)

    def forward(self, x):
        """Map an image batch ``x`` to logits of shape ``(bs, num_classes)``."""
        # Stem: (bs, C, 224, 224) -> (bs, 64, 112, 112) -> (bs, 64, 56, 56)
        x = self.relu0_1(self.bn0_1(self.conv0_1(x)))
        x = self.dmp(x)

        # Blocks 1-2: (bs, 64, 56, 56), identity shortcuts
        identity = x
        x = self.relu1_1(self.bn1_1(self.conv1_1(x)))
        x = self.relu1_2(self.bn1_2(self.conv1_2(x)) + identity)

        identity = x
        x = self.relu2_1(self.bn2_1(self.conv2_1(x)))
        x = self.relu2_2(self.bn2_2(self.conv2_2(x)) + identity)

        # Blocks 3-4: -> (bs, 128, 28, 28); block 3 projects its shortcut
        identity = self.bn3_0(self.conv3_0(x))
        x = self.relu3_1(self.bn3_1(self.conv3_1(x)))
        x = self.relu3_2(self.bn3_2(self.conv3_2(x)) + identity)

        identity = x
        x = self.relu4_1(self.bn4_1(self.conv4_1(x)))
        x = self.relu4_2(self.bn4_2(self.conv4_2(x)) + identity)

        # Blocks 5-6: -> (bs, 256, 14, 14); block 5 projects its shortcut
        identity = self.bn5_0(self.conv5_0(x))
        x = self.relu5_1(self.bn5_1(self.conv5_1(x)))
        x = self.relu5_2(self.bn5_2(self.conv5_2(x)) + identity)

        identity = x
        x = self.relu6_1(self.bn6_1(self.conv6_1(x)))
        x = self.relu6_2(self.bn6_2(self.conv6_2(x)) + identity)

        # Blocks 7-8: -> (bs, 512, 7, 7); block 7 projects its shortcut
        identity = self.bn7_0(self.conv7_0(x))
        x = self.relu7_1(self.bn7_1(self.conv7_1(x)))
        x = self.relu7_2(self.bn7_2(self.conv7_2(x)) + identity)

        identity = x
        x = self.relu8_1(self.bn8_1(self.conv8_1(x)))
        x = self.relu8_2(self.bn8_2(self.conv8_2(x)) + identity)

        # Global average pool over the whole remaining map: (bs, 512, 1, 1) -> (bs, 512)
        x = torch.nn.functional.avg_pool2d(x, (x.shape[-2], x.shape[-1]))
        x = torch.flatten(x, 1, -1)
        return self.fc(x)


if __name__ == "__main__":
    # Smoke test: a random 224x224 RGB batch should yield (4, num_classes) logits.
    tx = torch.randn((4, 3, 224, 224))
    algo = myResNet()
    pred = algo(tx)
    print(pred.shape)
# Reference (参考地址): https://mp.weixin.qq.com/s/eWeVWcEMLC9FIiFqKy5wqA
# Tags (标签): resnet18, kernel, nn, 实现, self, torch, stride, 方法, size
# From: https://www.cnblogs.com/ddzhen/p/18464372