首页 > 其他分享 >残差网络ResNet的深入介绍和实战

残差网络ResNet的深入介绍和实战

时间:2024-10-29 21:17:08浏览次数:3  
标签:实战 训练 nn outchannel self ResNet 残差

ResNet是由Kaiming He等人在2015年提出的深度学习模型,它通过引入残差学习解决了随着网络深度增加而性能下降的问题。ResNet在多个视觉识别任务上取得了显著的成功,在ImageNet的分类比赛上将网络深度直接提高到了152层,前一年夺冠的VGG只有19层。斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名,可以说ResNet的出现对深度神经网络来说具有重大的历史意义。

image.png

一、架构和网络结构

它使用了一种连接方式叫做“shortcut connection”,顾名思义,shortcut就是“抄近道”的意思,下面是这个resnet的网络结构:

image.png

1.残差块(Residual Block)

ResNet的核心是残差块,它允许网络中的信号绕过一个或多个层,直接传递到后面的层。这种设计使得网络能够学习到输入和输出之间的残差,而不是直接学习输出。

残差块的代码实现

残差块由两个主要部分组成:主体部分和跳过连接(skip connection)。主体部分通常包含卷积层、批量归一化层和ReLU激活函数。跳过连接则是一个或多个层的输出直接相加到主体部分的输出上。

参数设置:

inchannel 是输入通道数

outchannel 是输出通道数

stride 是卷积层的步长

class Block(nn.Module):
    def __init__(self, inchannel, outchannel, stride):
        super().__init__()

第一个卷积层后的批量归一化层,有助于网络训练的稳定性,第一个卷积层,用于提取特征,步长为stride,有助于减少特征图的空间维度,第二个卷积层,用于进一步提取特征,步长为1:

self.block = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, (3, 1), (stride, 1), (1, 0)),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(),
            nn.Conv2d(outchannel, outchannel, 1, 1, 0),
            nn.BatchNorm2d(outchannel)
        )

构建跳过连接,如果输入输出通道数不一致或步长不为1,则添加跳过连接以匹配维度:

        self.short = nn.Sequential()
        if (inchannel != outchannel or stride != 1):
            self.short = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, (3, 1), (stride, 1), (1, 0)),
                nn.BatchNorm2d(outchannel)
            )

向传播函数,计算残差块的输出,将残差块的输出与输入相加,形成残差:

def forward(self, x):
        out = self.block(x) + self.short(x)
        return nn.ReLU()(out)

2.网络层构建

ResNet由多个残差块组成,每一层由多个残差块串联而成。每一层的输入通过多个残差块,并通过跳过连接(Skip Connection)与层的输出相加。

网络结构的代码实现

__init__方法中定义了ResNet的整体结构。它包括四个残差层,每个层由make_layers函数构建,以及自适应平均池化层和全连接层。

初始化ResNet网络,定义训练样本的形状和类别数,train_shape 是训练样本的形状,用于确定自适应平均池化层的大小,category 是类别数,用于确定全连接层的输出维度:

class ResNet(nn.Module):
    # ...
    def make_layers(self, inchannel, outchannel, stride, blocks):
        layer = [Block(inchannel, outchannel, stride)]
        for i in range(1, blocks):
            layer.append(Block(outchannel, outchannel, 1))
        return nn.Sequential(*layer)

3.网络结构

ResNet的网络结构由四个主要部分组成:

  1. 输入层:接受输入数据。
  2. 残差层:由多个残差块组成,每个残差块包含两个卷积层和批量归一化层。
  3. 自适应平均池化层(Adaptive Average Pooling):将不同大小的特征图转换为统一大小,以便于全连接层处理。
  4. 全连接层:将特征图展平并输出分类结果。
网络结构的代码实现

__init__方法中定义了ResNet的整体结构。它包括四个残差层,每个层由make_layers函数构建,以及自适应平均池化层和全连接层。

首先网络的每个残差块层,通过make_layers函数创建多个残差块,每个make_layers函数调用定义了该层的输入通道数、输出通道数、步长和残差块的数量:

class ResNet(nn.Module):
    def __init__(self, train_shape, category):
        super().__init__()
        self.layer1 = self.make_layers(1, 64, 2, 1)
        self.layer2 = self.make_layers(64, 128, 2, 1)
        self.layer3 = self.make_layers(128, 256, 2, 1)
        self.layer4 = self.make_layers(256, 512, 2, 1)

自适应平均池化层,将特征图的每个特征通道压缩为一个单一的值,池化层的输出尺寸由train_shape[-1]决定,通常是序列长度或特征图的高度/宽度:

self.ada_pool = nn.AdaptiveAvgPool2d((1, train_shape[-1]))

全连接层,将自适应平均池化层的输出展平并映射到类别数,输入特征的数量是512乘以特征图的维度,输出是类别数:

self.fc = nn.Linear(512*train_shape[-1], category)

4.前向传播

在前向传播过程中,输入数据通过每一层的残差块,每一层的输出通过跳过连接与下一层的输入相加,最后通过全连接层输出分类结果。

前向传播的代码实现

forward方法定义了数据通过网络的流动过程。数据首先通过四个残差层,然后通过自适应平均池化层和全连接层,最终输出分类结果。

class ResNet(nn.Module):
    # ...
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.ada_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

5.优点

  • 解决梯度消失问题:通过残差连接,ResNet可以构建更深的网络而不会遭受梯度消失或梯度爆炸的问题。
  • 简化训练过程:残差块的设计使得网络的每个部分可以独立训练,简化了训练过程。
  • 提高性能:ResNet在多个视觉识别任务上取得了当时的最佳性能。

ResNet不仅在图像分类任务上表现出色,还被广泛应用于其他计算机视觉任务,如目标检测、语义分割等。通过引入残差学习框架,成功解决了深层网络训练中的一些关键问题。

二、ResNet训练PAMAP2数据集

ResNet网络结构:

image.png

模型初始化

在模型初始化部分首先码实例化模型类。对于ResNet,将调用resnet.ResNet

net = model_dict[args.model](X_train.shape, category).to(device)

训练过程

训练过程包括前向传播、损失计算、反向传播和参数更新。这里使用了混合精度训练来加速训练过程并减少内存使用。

for i in range(EP):
    net.train()
    ...
    with autocast():
        out = net(data)
        loss = loss_fn(out, label)
    ...
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

评估和推理

在每个epoch结束后,模型在测试集上进行评估,计算准确率、精确率、召回率和F1分数。我们还计算了推理时间,这对于实时应用来说是一个重要的指标。

net.eval()
...
accuracy = accuracy_score(all_labels, all_preds)
...
inference_time = inference_end_time - inference_start_time

训练过程及结果:

image.png

结果分析

单图像推理时间

  • 单图像推理时间非常快,大约在0.0014到0.0021秒之间。这表明ResNet模型能够迅速对单个图像进行推理,这对于需要实时反馈的应用场景来说是一个积极的迹象。

训练损失

  • 在Epoch 98,训练损失非常低,为3.8721980672562495e-05,这表明模型在训练数据上的表现非常好,几乎能够完美地拟合训练数据。
  • 在Epoch 99,训练损失进一步降低到4.326595899328822e-06,这进一步证实了模型的拟合能力。

总推理时间

  • 在Epoch 98和Epoch 99,整个测试集的推理时间分别为1.7669秒和1.7730秒。考虑到测试集可能包含大量的图像,这个推理时间是相当快的,表明模型在实际部署时能够提供快速响应。

ResNet在PAMAP2数据集上的训练结果是令人鼓舞的。模型显示出了极低的训练损失和在测试集上高准确率和高泛化能力。快速的单图像推理时间和整个测试集的推理时间表明,这个模型适合于需要快速响应的应用场景。然而,需要注意的是,尽管F1分数很高,但准确率的轻微下降可能表明模型开始轻微过拟合或训练数据中的一些变化。

总结来说,ResNet由于其残差连接能够训练更深的网络而不会遇到梯度消失问题,这在处理复杂的HAR任务时可能是有益的。使用torch.cuda.amp中的GradScalerautocast可以加速训练过程并减少内存使用,这对于大型模型和数据集尤其重要。在实时HAR应用中,推理时间是一个关键指标。我们计算了整个测试集的推理时间,这有助于评估模型在实际部署时的性能。

标签:实战,训练,nn,outchannel,self,ResNet,残差
From: https://blog.csdn.net/weixin_51390582/article/details/143029672

相关文章

  • Java项目-基于springboot框架的高校社团管理系统项目实战(附源码+文档)
    作者:计算机学长阿伟开发技术:SpringBoot、SSM、Vue、MySQL、ElementUI等,“文末源码”。开发运行环境开发语言:Java数据库:MySQL技术:SpringBoot、Vue、MybaitsPlus、ELementUI工具:IDEA/Ecilpse、Navicat、Maven源码下载地址:https://download.csdn.net/download/weixin_53......
  • Java项目-基于springboot框架的民宿管理系统项目实战(附源码+文档)
    作者:计算机学长阿伟开发技术:SpringBoot、SSM、Vue、MySQL、ElementUI等,“文末源码”。开发运行环境开发语言:Java数据库:MySQL技术:SpringBoot、Vue、MybaitsPlus、ELementUI工具:IDEA/Ecilpse、Navicat、Maven源码下载地址:https://download.csdn.net/download/weixin_53......
  • 明星人脸识别基于VGG、MTCNN、RESNET深度学习卷积神经网络应用|附数据代码
    全文链接:https://tecdat.cn/?p=38046原文出处:拓端数据部落公众号分析师:XinzuDu 人脸识别技术作为生物特征识别技术的重要组成部分,在近三十年里得到了广泛的关注和研究,已经成为计算机视觉、模式识别领域的研究热点。然而由于存在光线、背景、人脸遮挡等问题,如何准确识别出人......
  • GraphRAG原理及部署实战(GraphRAG系列第一篇)
        RAG在大模型时代,被寄予了厚望,但在近一年多各大小公司的实施过程中,其效果远没有抖音中宣传的那么振奋人心,其原因是多方面的。这篇文章就RAG中的一个弱项--局部性来展开讨论。一、RAG原理       图1描述了RAG的原理,用户输入了一个指令Instruct,RAG将其与Docu......
  • 基于ResNet50模型的船型识别与分类系统研究
    项目源码获取方式见文章末尾!600多个深度学习项目资料,快来加入社群一起学习吧。《------往期经典推荐------》项目名称1.【LSTM模型实现光伏发电功率的预测】2.【卫星图像道路检测DeepLabV3Plus模型】3.【GAN模型实现二次元头像生成】4.【CNN模型实现mnist手写数字......
  • 购物平台数据抓取实战指南:从API到深度分析
    在当今电商盛行的时代,淘宝、京东、拼多多等购物平台已成为消费者日常购物的主要场所。对于企业、市场分析师及开发者而言,这些平台上的数据无疑是一座宝贵的金矿。本实战指南将带您从API接口出发,一步步实现购物平台数据的抓取、处理到深度分析。一、API接口初探API(Application......
  • 【项目实战】Java中集合Collection 和 Collections入门介绍
    在Java编程语言中,Collection是一个接口,它是集合层次结构中的根接口。Collection接口定义了所有集合类型(如列表、集合和队列)所共有的基本操作方法。而Collections则是一个工具类,它提供了一系列静态方法来操作或返回集合。当你需要存储一组对象并在程序中对其进行操作时,......
  • 【项目实战】网络通信协议Socket和WebSocket入门介绍
    一、Socket1.1文件描述符详解文件描述符是在操作系统层面用来访问文件或I/O资源(如网络套接字)的一个抽象的、非负整数。每个进程在打开一个文件或创建一个套接字时,都会得到一个唯一的文件描述符。在Unix/Linux系统中,标准输入(stdin)、标准输出(stdout)和标准错误(stderr)默认......
  • 【项目实战】分布式日志搜索系统之数据同步方案(Logstash-input-jdbc、go-mysql-elast
    在构建分布式日志搜索系统时,数据同步是一个核心环节。以下是针对您提出的五种数据同步方案的详细分析:一、Logstash-input-jdbcLogstash是ElasticStack的一部分,用于从各种来源收集数据,并将其发送到Elasticsearch。Logstash-input-jdbc插件允许Logstash从关系型数据库(如My......
  • 【项目实战】分布式日志搜索系统之Elastic Stack日志抽取(filebeat、heartbeat、packet
    一、ElasticStack是什么?ElasticStack,以前称为ELKStack,是一套开源的日志分析解决方案。ElasticStack,由Elastic公司开发和维护。ElasticStack,包括了几个核心组件,这些组件协同工作以帮助用户收集、处理、存储、搜索和可视化数据。ElasticStack,因其灵活性和强大的功能......