首页 > 其他分享 >Pytorch学习--神经网络--线性层及其他层

Pytorch学习--神经网络--线性层及其他层

时间:2024-10-30 21:48:23浏览次数:6  
标签:他层 img nn -- self torch Pytorch import Linear

一、正则化层

torch.nn.BatchNorm2d

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

正则化的意义:

  • 加速训练收敛:在每一层网络的输入上执行批量归一化可以保持数据的分布稳定,从而减小梯度的波动。这种稳定性让模型更快收敛,从而提高训练速度。

  • 减轻梯度消失和梯度爆炸问题:通过调整每一层的输入分布,Batch Normalization可以减轻深层网络中梯度消失和梯度爆炸的现象,使得更深的网络也能够得到有效的训练。

  • 减少对权重初始化的敏感性:Batch Normalization可以减小网络对权重初始化的依赖,使得模型可以在更宽的初始化范围内有效训练。这减少了在不同模型初始化方案间进行调试的时间和精力。

  • 提高模型的泛化能力:Batch Normalization在训练时引入了少量噪声(由于 mini-batch 的不同),这在一定程度上起到了正则化作用,有助于提高模型的泛化能力,降低过拟合的风险。

  • 降低学习率调整的难度:使用Batch Normalization可以让模型在较高的学习率下进行训练,从而进一步加速训练过程。

二、Dropout层

torch.nn.Dropout

torch.nn.Dropout(p=0.5, inplace=False)

在这里插入图片描述
防止过拟合

三、线性层

torch.nn.Linear

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

在这里插入图片描述
代码实现:
CIFAR 中的图片 转换为 一维的数据(1,m),再转换成 (1,n) 的维度

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root="datasets",train=False,transform=torchvision.transforms.ToTensor(),download=True)

dataloader = DataLoader(dataset,batch_size=64)

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.linear1 = Linear(196608,10)
    def forward(self,x):
        x = self.linear1(x)
        return x
Yorelee = Mary()

for data in dataloader:
    img,targets = data
    img = torch.flatten(img)
    print(img.shape)
    output = Yorelee(img)
    print(output.shape)

输出:

torch.Size([196608])
torch.Size([10])

标签:他层,img,nn,--,self,torch,Pytorch,import,Linear
From: https://blog.csdn.net/weixin_68930974/article/details/143364815

相关文章

  • PTA | 六度空间
    “六度空间”理论又称作“六度分隔(SixDegreesofSeparation)”理论。这个理论可以通俗地阐述为:“你和任何一个陌生人之间所间隔的人不会超过六个,也就是说,最多通过五个人你就能够认识任何一个陌生人。”如图1所示。“六度空间”理论虽然得到广泛的认同,并且正在得到越来越......
  • Codeforces Round 981 (Div. 3) ABCDE
    CodeforcesRound981(Div.3)ABCDEA.SakurakoandKosuke藕是看样例直接猜了结论......
  • 分享一下最近清洗CFPS心得,有错误求指正
    目标:得到一个四期面板数据,每期包括家庭库和个人库一、提取变量以2014年为例,2016、2018、2020省略处理过程1.处理个人库keepfid14pidprovcd14urban14cfps2014_agecfps_genderqea0qp201cfps2014eduy_imqz207ku802 替换缺失值forvar_all:replaceX=.ifinl......
  • QT:QThread 使用案例
    问题描述:软件界面打开之前要初始化相机和机械臂,并且在执行扫描点云,配准等操作时,只能单线程运行,导致运行效率低。解决:使用QThread首先写一个类如task,成员函数执行的是需要在子线程运行。task.h:task类需要继承QObject类,startadd()函数内容在子进程运行。#ifndefTASK_H#de......
  • 免费送源码:Java+ssm+MySQL+Ajax ssm第二课堂管理系统 计算机毕业设计原创定制
    摘要随着互联网的高速发展,教育进入了信息化时代,促使了多种混合式教学模式的出现。第二课堂管理系统是这一时期新型混合式教学模式的代表,它的出现改变了传统教学模式,将知识传递置于课前,将学习知识的主动性交给学生,促使学生的素质全面发展。第二课堂管理系统以“以学生为......
  • Python 标准库——argparse模块
    文章目录前言一、主要作用二、基本步骤1.导入模块2.创建解析器对象3.添加参数4.解析参数5.使用解析后的参数6.编写主函数并调用三、函数示例前言argparse是Python标准库中的一个模块,用于编写用户友好的命令行接口。它允许你轻松地定义程序应该接受的命令行参数,并......
  • (C语言)数组
    目录一维数组1>.  定义2>. 数组的下标3>. 数组的初始化4>. 计算数组的大小    1)strlen    2)sizeof二维数组1>. 定义2>. 初始化    1)只有一个{}    2)多个{}变长数组数组定义:为了存放多个相同类型的元素,创建了数组;......
  • 高校联动,创新无限!“2024 深圳国际金融科技大赛”校园行活动圆满结束
    在金融科技蓬勃发展的当下,人才培养成为推动行业前行的关键。为推进深圳市金融科技人才高地建设,向高校学子提供一个展示自身知识、能力和创意的平台,2024FinTechathon深圳国际金融科技大赛——西丽湖金融科技大学生挑战赛重磅开启,并精心筹备了一系列精彩活动。自报名启动后,大......
  • 基于模型内部的检索增强型生成答案归属方法:MIRAGE
    人工智能咨询培训老师叶梓转载标明出处在自然语言处理(NLP)中,确保模型生成答案的可验证性是一个重要挑战。特别是在检索增强型生成(RAG)用于问答(QA)领域时,如何验证模型答案是否忠实于检索到的来源是一个关键问题。近期一种名为自引用提示的方法被提出,以使大型语言模型(LLMs)在生成答......
  • Servlet -个人理解笔记
    Servlet的作用        Servlet主要是为了衔接web应用的前端和后端的,作为它们俩中间数据交换的桥梁,现在很多web项目都是前后端分离的,前端写前端的后端写后端的,但是他俩所用的编程语言是有区别的,怎么实现它们之间的数据交换呢?Servlet就是为了解决这个,它是用java编写的,目......