首页 > 编程语言 >零基础学习人工智能—Python—Pytorch学习(八)

零基础学习人工智能—Python—Pytorch学习(八)

时间:2024-08-23 10:27:26浏览次数:19  
标签:窗口 Python 0.5 学习 卷积 Pytorch 特征 池化 图像

前言

本文介绍卷积神经网络的上半部分。
其实,学习还是需要老师的,因为我自己写文章的时候,就会想当然,比如下面的滑动窗口,我就会想当然的认为所有人都能理解,而实际上,我们在学习的过程中之所以卡顿的点多,就是因为学习资源中想当然的地方太多了。

概念

卷积神经网络,简称CNN, 即Convolutional Neural Network的缩写。

滤波器/卷积核(Filter/Kernels)

卷积核是一个小矩阵(通常是3x3、5x5等),它在输入图像上滑动(即移动),并与图像的局部区域进行矩阵乘法(点积)操作。结果是一个单值,这个值代表了该局部区域的某种特征。
点积就是内积,就是np.dot函数,内积是个值,就是两个矩阵对应项相乘,在相加
例如。a=[2,3] 和 b=[4,5],它们的点积是a⋅b=(2×4)+(3×5)=8+15=23
点积的意义是a⋅b=∥a∥∥b∥cosθ,意思是说,点积等于a向量的模乘以b向量的模乘以ab的夹角θ的cos的值
向量的模就是向量的长度,v=[3,4],因为勾股定理,c²=a²+b²,所以∥v∥=c=根号下a²+b²=根号下9+16=根号下25=5
例。rgb图,是3通道,卷积核会在3个通道上都进行卷积操作,最后形成一个特征图。

卷积核的尺寸

如果尺寸是 5×5,那么滑动窗口的大小也是 5×5。
image

特征图(Feature Map)

当一个卷积核(或滤波器)滑动在输入图像上时,它会在每一个位置计算卷积核与输入图像区域的点积,结果是一个标量。通过滑动整个图像,得到一组标量值,这些值构成了一个新的二维矩阵,这个矩阵就是特征图。
在CNN中,使用越多的卷积核,意味着提取的特征图越多,因此卷积核越大就可以得到的越丰富的特征。
更多的卷积核意味着更多的计算和内存消耗。因此,在选择卷积核数量时,也要考虑硬件资源的限制。

最大池化层(Max Pooling Layer)

是卷积神经网络(CNN)中常用的下采样(或降采样)技术。它用于减小特征图的尺寸,从而减少计算量,并有助于控制模型的复杂度(防止过拟合)。
最大池化操作使用一个固定大小的窗口(通常是2x2或3x3),在特征图上滑动。
在窗口覆盖的区域内,最大池化层会选择该区域的最大值作为输出。
步幅决定了池化窗口在特征图上滑动的步长。步幅为2意味着窗口每次移动2个像素。
每次池化操作生成的特征图尺寸会比输入特征图小。池化操作会减少特征图的宽度和高度,但保持深度(通道数)不变。
例,4x4 的输入特征图如下

1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16

使用 2x2 的最大池化窗口和步幅为 2,池化过程如下:
池化窗口覆盖 1 2 5 6,最大值为 6
池化窗口覆盖 3 4 7 8,最大值为 8
池化窗口覆盖 9 10 13 14,最大值为 14
池化窗口覆盖 11 12 15 16,最大值为 16
得到的输出特征图为:

6 8
14 16

结合代码理解

结合下面的代码理解上面的概念。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
batch_size = 4
learning_rate = 0.001
num_epochs = 0

# dataset has PILImage images of range [0, 1].# We transform them to Tensors of normalized range [-1, 1]
# transforms.ToTensor():将PIL图像或numpy数组转换为PyTorch张量,并将值范围从[0,1]变为[0,255]。
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):对图像进行归一化处理,将图像的像素值调整到[-1,1]范围。
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils. data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)


print('每份100个,被分成多了份', len(train_loader))

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.__next__()
# show images
imshow(torchvision.utils.make_grid(images))

# nn.Conv2d 是 PyTorch 用于定义二维卷积层的类
# 三个参数分别为 in_channels、out_channels 和 kernel_size
# in_channels (输入通道数):
# 值=3 这是输入图像的通道数。对于彩色图像,通常有3个通道(对应于RGB),因此这里的值为3。如果输入的是灰度图像,通常只有1个通道。
# 对于彩色图像(RGB),输入图像有3个通道。卷积核对每个通道独立进行操作,然后将这些结果相加,得到输出特征图。
# 输出特征图的数量由卷积核的数量决定。如果你有多个卷积核,它们会捕捉输入图像中的不同特征,每个卷积核生成一个特征图
# out_channels (输出通道数):
# 值=6 这是卷积层输出的通道数,也称为卷积核的数量。这个参数决定了卷积操作后生成多少个不同的特征图。在本例中,卷积层会生成6个特征图,也就是说会有6个不同的卷积核应用于输入图像。
# kernel_size (卷积核大小):
# 值=5 这是卷积核的尺寸,表示卷积核的宽度和高度。这里使用的是 5x5 的卷积核。这意味着每个卷积核会查看输入图像的 5x5 像素区域,并通过滑动窗口方式在整个图像上进行卷积操作。
# conv1:第一个卷积层,将输入的3通道图像(RGB)通过一个5x5的卷积核,生成6个输出通道。这里,卷积层使用 6 个卷积核,每个卷积核会生成一个特征图。因此,该卷积层的输出是 6 个特征图,特征图的深度(通道数)为 6。
conv1 = nn.Conv2d(3, 6, 5)
# pool:最大池化层,将特征图的尺寸缩小一半。
# 第一个参数 (2): 池化窗口的大小。这表示池化操作将应用于一个 2x2 的窗口上。池化窗口决定了在特征图上进行池化操作的区域大小。 
# 第二个参数 (2): 池化的步幅(stride)。步幅决定了池化窗口在特征图上滑动的步长。步幅为 2 意味着池化窗口每次移动 2 个像素。
pool = nn.MaxPool2d(2, 2)
# conv2:第二个卷积层,将6个通道的输入特征图通过一个5x5的卷积核,生成16个输出通道。
conv2 = nn.Conv2d(6, 16, 5)
print(images.shape)


x = conv1(images)
# print(x.shape)
x = pool(x)
# print(x.shape)
x = conv2(x)
# print(x.shape)
x = pool(x)  # print(x.shape)

传送门:
零基础学习人工智能—Python—Pytorch学习—全集


注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!



若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!

https://www.cnblogs.com/kiba/p/18375380

标签:窗口,Python,0.5,学习,卷积,Pytorch,特征,池化,图像
From: https://www.cnblogs.com/kiba/p/18375380

相关文章

  • 直面程序员的AI焦虑:学习大语言模型开发是关键
    随着AIGC热潮的兴起,越来越多的AI工具应用开始出现,包括OpenAIChatGPT、GithubCopilot等,这些工具正在改变着传统的工作生产方式。在2023年3月的一次发布会中,OpenAI甚至展示了直接通过识别原型草图,智能生成网站代码的案例。一时间,“程序员要失业了”类似言论甚......
  • 今日份笔记奉上,前两天扎针灸手肿了打字不太方便学习进程搁置了,今天简单学了写Dos指令
    基本Dos指令打开CMD1.开始+系统+命令提示符2.Windows+R输入cmd3.在任意文件下按shift+右键打开powershell4.资源管理器地址栏+cmd+空格+回车以管理员方式运行开始中命令提示符可以管理员身份运行常用的Dos指令#盘符切换#查看当前目录下的所有文件dir#切换目录cd#......
  • Python-批量统计MySQL中表的数据量
    背景在数据中台中,有时为了核对数据,需要每天批量统计MySQL数据库中表的数据量,但是DMS中没有周期调度功能。MySQL创建表--统计的表清单CREATETABLE`dws_table_list`(`table_name`varchar(255)DEFAULTNULL,`flag`varchar(255)DEFAULTNULL);--每天的数据量CRE......
  • python socket编辑示例
    服务端代码:fromsocketimportsocket,AF_INET,SOCK_STREAM#1.创建socket对象AF_INET:用于internet之间的进程通信,SOCK_STREAM:表示TCP协议server_socket=socket(AF_INET,SOCK_STREAM)#2.绑定ip和端口号ip='127.0.0.1'port=8888server_socket.bind((ip,p......
  • 数论学习笔记
    积性函数一般我们只需要考虑定义域在\(\mathbb{Z}\)就够了,什么实数,复数都不用管。如果函数\(f(x)\)满足对于任意的\(a,b\)且\(\gcd(a,b)=1\),都有\(f(ab)=f(a)f(b)\)。欧拉函数\(\varphi(i)\)\(\varphi(n)\)定义为大于等于\(1\)且小于\(n\)且与它互质的数的个数......
  • 基于Python flask的图书借阅管理系统的设计与实现
    基于PythonFlask的图书借阅管理系统旨在为图书馆或类似机构提供一个高效、便捷的管理平台,覆盖图书借阅的各个环节,帮助管理员和读者更好地管理和使用图书资源。该系统采用Python编程语言和Flask框架进行开发,结合了数据库管理、用户认证、数据可视化等技术,确保系统的功能完备和......
  • 亦菲喊你来学机器学习(9) --逻辑回归实现手写数字识别
    文章目录逻辑回归实现手写数字识别训练模型测试模型总结逻辑回归逻辑回归(LogisticRegression)虽然是一种广泛使用的分类算法,但它通常更适用于二分类问题。然而,通过一些策略(如一对多分类,也称为OvR或One-vs-Rest),逻辑回归也可以被扩展到多分类问题,如手写数字识别(通常是......
  • 学习分享:如何学习 API 中的数据格式
    以下是学习API中数据格式的要点:一、了解常见数据格式JSON(JavaScriptObjectNotation):结构特点:它是一种轻量级的数据交换格式,易于人阅读和编写,也易于机器解析和生成。JSON数据格式由键值对组成,类似于Python中的字典或者JavaScript中的对象。例如:{"name":"John",......
  • MyBatis 源码解读:专栏导读与学习路线
    前言MyBatis是Java开发中广泛使用的持久层框架,其简洁的配置和强大的功能使得它在开发人员中备受欢迎。然而,MyBatis的背后隐藏着许多设计巧妙的架构和复杂的实现逻辑。通过源码解读,我们可以更深入地理解MyBatis的设计思想和工作原理,从而更好地应用它。本专栏将以源码......
  • python-jose 实现fastapi登录验证
    JWT和Session的区别:JWT:JWT是一种无状态的认证机制。由于JWT令牌包含了用户的身份信息以及相关的元数据,服务端不需要存储任何用户状态信息,只需要验证JWT令牌的真实性和有效性即可。这使得JWT非常适合于构建无状态的分布式系统,因为JWT令牌可以在不同的服务之间轻松共享。Sessio......