首页 > 其他分享 >笨方法实现resnet18

笨方法实现resnet18

时间:2024-10-14 15:50:08浏览次数:3  
标签:resnet18 kernel nn 实现 self torch stride 方法 size

import torch


class myResNet(torch.nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(myResNet, self).__init__()
        # 第1层
        self.conv0_1 = torch.nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn0_1 = torch.nn.BatchNorm2d(64)
        self.relu0_1 = torch.nn.ReLU()
        self.dmp = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 第2 3 层
        self.conv1_1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1_1 = torch.nn.BatchNorm2d(64)
        self.relu1_1 = torch.nn.ReLU()
        self.conv1_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1_2 = torch.nn.BatchNorm2d(64)
        self.relu1_2 = torch.nn.ReLU()

        # 第4 5层
        self.conv2_1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2_1 = torch.nn.BatchNorm2d(64)
        self.relu2_1 = torch.nn.ReLU()
        self.conv2_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2_2 = torch.nn.BatchNorm2d(64)
        self.relu2_2 = torch.nn.ReLU()

        # 第6 7层
        self.conv3_0 = torch.nn.Conv2d(64, 128, kernel_size=1, stride=2)
        self.conv3_1 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3_1 = torch.nn.BatchNorm2d(128)
        self.relu3_1 = torch.nn.ReLU()
        self.conv3_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn3_2 = torch.nn.BatchNorm2d(128)
        self.relu3_2 = torch.nn.ReLU()

        # 第8 9层
        self.conv4_1 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4_1 = torch.nn.BatchNorm2d(128)
        self.relu4_1 = torch.nn.ReLU()
        self.conv4_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4_2 = torch.nn.BatchNorm2d(128)
        self.relu4_2 = torch.nn.ReLU()

        # 第10 11层
        self.conv5_0 = torch.nn.Conv2d(128, 256, kernel_size=1, stride=2)
        self.conv5_1 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.bn5_1 = torch.nn.BatchNorm2d(256)
        self.relu5_1 = torch.nn.ReLU()
        self.conv5_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn5_2 = torch.nn.BatchNorm2d(256)
        self.relu5_2 = torch.nn.ReLU()

        # 第12 13层
        self.conv6_1 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6_1 = torch.nn.BatchNorm2d(256)
        self.relu6_1 = torch.nn.ReLU()
        self.conv6_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6_2 = torch.nn.BatchNorm2d(256)
        self.relu6_2 = torch.nn.ReLU()

        # 第14 15层
        self.conv7_0 = torch.nn.Conv2d(256, 512, kernel_size=1, stride=2)
        self.conv7_1 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.bn7_1 = torch.nn.BatchNorm2d(512)
        self.relu7_1 = torch.nn.ReLU()
        self.conv7_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn7_2 = torch.nn.BatchNorm2d(512)
        self.relu7_2 = torch.nn.ReLU()

        # 第16 17层
        self.conv8_1 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn8_1 = torch.nn.BatchNorm2d(512)
        self.relu8_1 = torch.nn.ReLU()
        self.conv8_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn8_2 = torch.nn.BatchNorm2d(512)
        self.relu8_2 = torch.nn.ReLU()

        # 第18层
        self.fc = torch.nn.Linear(512, num_classes)

    def forward(self, x):  # batch_size, 3, 224, 224
        x = self.conv0_1(x)   # bs, 64, 112, 112
        x = self.bn0_1(x)
        x = self.relu0_1(x)
        x1 = self.dmp(x)  # bs, 64, 56, 56

        x = self.conv1_1(x1)  # bs, 64, 56, 56
        x = self.bn1_1(x)
        x = self.relu1_1(x)
        x = self.conv1_2(x)
        x = self.bn1_2(x)
        x = x + x1
        x2 = self.relu1_2(x)

        x = self.conv2_1(x2)
        x = self.bn2_1(x)
        x = self.relu2_1(x)
        x = self.conv2_2(x)
        x = self.bn2_2(x)
        x = x + x2
        x = self.relu2_2(x)  # bs, 64, 56, 56

        x3 = self.conv3_0(x)  # bs, 128, 28, 28
        x = self.conv3_1(x)
        x = self.bn3_1(x)
        x = self.relu3_1(x)
        x = self.conv3_2(x)
        x = self.bn3_2(x)
        x = x + x3
        x4 = self.relu3_2(x)

        x = self.conv4_1(x4)
        x = self.bn4_1(x)
        x = self.relu4_1(x)
        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = x + x4
        x = self.relu4_2(x)  # bs, 128, 28, 28

        x5 = self.conv5_0(x)  # bs, 256, 14, 14
        x = self.conv5_1(x)
        x = self.bn5_1(x)
        x = self.relu5_1(x)
        x = self.conv5_2(x)
        x = self.bn5_2(x)
        x = x + x5
        x6 = self.relu5_2(x)

        x = self.conv6_1(x6)
        x = self.bn6_1(x)
        x = self.relu6_1(x)
        x = self.conv6_2(x)
        x = self.bn6_2(x)
        x = x + x6
        x = self.relu6_2(x)  # bs, 256, 14, 14

        x7 = self.conv7_0(x)  # bs, 512, 7, 7
        x = self.conv7_1(x)
        x = self.bn7_1(x)
        x = self.relu7_1(x)
        x = self.conv7_2(x)
        x = self.bn7_2(x)
        x = x + x7
        x8 = self.relu7_2(x)

        x = self.conv8_1(x8)
        x = self.bn8_1(x)
        x = self.relu8_1(x)
        x = self.conv8_2(x)
        x = self.bn8_2(x)
        x = x + x8
        x = self.relu8_2(x)  # bs, 512, 7, 7

        x = torch.nn.functional.avg_pool2d(x, (x.shape[-2], x.shape[-1]))
        x = torch.flatten(x, 1, -1)
        x = self.fc(x)
        return x


if __name__ == "__main__":
    tx = torch.randn((4, 3, 224, 224))
    algo = myResNet()
    pred = algo(tx)
    print(pred.shape)

参考地址:https://mp.weixin.qq.com/s/eWeVWcEMLC9FIiFqKy5wqA

标签:resnet18,kernel,nn,实现,self,torch,stride,方法,size
From: https://www.cnblogs.com/ddzhen/p/18464372

相关文章

  • android开发修复第三方库生成的so库名称不是以so结尾的解决方法
    需要ubuntu安装patchelf软件:sudoapt-getinstallpatchelf1.先使用readelf-d查看so内容结构先使用readelf-dlibpsl.so.5.3.5查看libpsl.so.5.3.5库类型是NEEDED和SONAME的对应的名称是不是以.so结尾的,比如下面的图,libc.so的名称是以.so结尾的我们就不用管,libpsl.so.5不......
  • 实现基于UDS诊断协议的CAN本地OTA升级
    一、目标在上篇文章实现基于UDSLIN诊断协议的本地OTA升级-CSDN博客博客中已经基于LINUDS诊断协议实现了通过PC端上位机对MCU进行本地的OTA升级。本篇将在上篇文章的基础上实现基于UDS诊断协议的CAN本地OTA升级。本篇文章对实现的目的、需要用到的第三方工具请查看之前的博客相......
  • 基于redis实现验证码、Token的存储
    多台tomcat服务器之间session信息不能共享(早期tomcat为解决这个问题可以在tomcat服务器之间拷贝session信息但拷贝时有时间延迟故淘汰)1.使用redis替代session1.使用String数据类型存储验证码 每一个手机号作为key2.使用Hash数据结构存储用户信息  随机token作为k......
  • 如何在Java中实现对象和Map之间的转换
    在Java中,对象和Map之间的转换是一个常见的需求,特别是在处理JSON数据、配置参数或需要将对象序列化为易于存储和传输的格式时。以下是详细讲解如何在Java中实现对象和Map之间转换的方法。1.引入必要的库Java标准库本身不提供对象和Map之间自动转换的功能,但我们可以使用第......
  • LIN诊断实现MCU本地OTA升级
    一、目标通过PC端上位机实现MCU本地的OTA升级,本篇文章对实现的目的、需要用到的第三方工具、LIN诊断帧、升级协议、MCU端升级过程以及PC端升级过程做详细说明。二、目的最近在做MCU项目时需要将样机寄给客户进行验证,在客户的验证过程中要求参数可调试,如果需要修改软件升级MCU就......
  • 实现基于UDS LIN诊断协议的本地OTA升级
    一、目标在上篇文章LIN诊断实现MCU本地OTA升级_linota-CSDN博客中已经基于LIN诊断协议实现了通过PC端上位机对MCU进行本地的OTA升级,但是没有完全按照UDS协议实现。本篇将在上篇文章的基础上进行改进,实现基于UDSLIN诊断协议的本地OTA升级。本篇文章对实现的目的、需要用到的第三......
  • Neo4j——安装jdk和neo4j过程中的注意事项、流程、安装包版本链接、个人建议和解决方
    后附安装jdk和neo4j过程中的注意事项、流程、安装包版本链接、个人建议和解决方法在安装jdk中,即使之前安装过jdk也要重装,因为之前安装的jdk版本太低或者与neo4j不兼容,这里我安装的jdk为14.0.2版本,neo4j安装的版本为4.1.1版本安装jdk版本的网址链接为:JavaArchiveDownloads......
  • 整数反转(C实现)
    题目:力扣第七题是“整数反转”(ReverseInteger)。题目要求我们给定一个32位有符号整数,反转其数字。如果反转后的整数超过了32位有符号整数的范围[-2^31,2^31-1],则返回0。解题思路:处理正负号:我们首先需要记录输入整数的符号,如果是负数,则最终结果也应该是负数。逐......
  • 基于Java+Jsp+Ssm+Mysql实现的在线乡村风景美食景点旅游平台功能设计与实现一
    一、前言介绍:1.1项目摘要乡村风景美食旅游平台的课题背景主要基于我国旅游产业的现状与发展需求。当前,我国旅游产业虽然发展迅速,但仍然存在基础薄弱、管理手段滞后、信息化程度低等问题。旅游行政管理部门的管理方式相对落后,缺乏有效的信息化管理手段,信息沟通渠道不畅,这......
  • 基于Java+Jsp+Ssm+Mysql实现的在线乡村风景美食景点旅游平台功能设计与实现二
    一、前言介绍:1.1项目摘要乡村风景美食旅游平台的课题背景主要基于我国旅游产业的现状与发展需求。当前,我国旅游产业虽然发展迅速,但仍然存在基础薄弱、管理手段滞后、信息化程度低等问题。旅游行政管理部门的管理方式相对落后,缺乏有效的信息化管理手段,信息沟通渠道不畅,这......