首页 > 其他分享 >Pytorch nn.Linear的基本用法与原理详解

Pytorch nn.Linear的基本用法与原理详解

时间:2023-10-03 16:34:40浏览次数:54  
标签:10 样本 Linear nn torch Pytorch model

Pytorch nn.Linear的基本用法与原理详解

原文:Pytorch nn.Linear的基本用法与原理详解_iioSnail的博客-CSDN博客

nn.Linear的基本定义

nn.Linear定义一个神经网络的线性层,方法签名如下:

torch.nn.Linear(in_features, # 输入的神经元个数
           out_features, # 输出神经元个数
           bias=True # 是否包含偏置
           )

Linear其实就是对输入\(X_{n\times i}\)执行了一个线性变换,既:

\[Y_{n\times o}=X_{n\times i}W_{i\times o}+b \]

其中\(W\)是模型要学习的参数,\(W\) 的维度为\(W_{i\times o}\),\(b\) 是o维的向量偏置,\(n\) 为输入向量的行数(例如,你想一次输入10个样本, 即batch_size为10,则\(n=10\)),\(i\)为输入神经元的个数(例如你的样本特征数为5,则\(i=5\)),\(o\)为输出神经元的个数。

使用演示:

from torch import nn
import torch

model = nn.Linear(2, 1) # 输入特征数为2,输出特征数为1

input = torch.Tensor([1, 2]) # 给一个样本,该样本有2个特征(这两个特征的值分别为1和2)
output = model(input)
output

tensor([-1.4166], grad_fn=<AddBackward0>)

我们的输入为[1,2],输出了[-1.4166]。可以查看模型参数验证一下上述的式子:

# 查看模型参数
for param in model.parameters():
    print(param)


Parameter containing:
tensor([[ 0.1098, -0.5404]], requires_grad=True)
Parameter containing:
tensor([-0.4456], requires_grad=True)

可以看到,模型有3个参数,分别为两个权重和一个偏执。计算可得:

\[y=[1,2]*[0.1098,-0.5404]^T-0.4456=-1.4166 \]

实战

假设我们的一次输入三个样本A,B,C(即batch_size为3),每个样本的特征数量为5:

A: [0.1,0.2,0.3,0.3,0.3]
B: [0.4,0.5,0.6,0.6,0.6]
C: [0.7,0.8,0.9,0.9,0.9]

则我们的输入向量 \(X_{3\times5}\) 为:

X = torch.Tensor([
    [0.1,0.2,0.3,0.3,0.3],
    [0.4,0.5,0.6,0.6,0.6],
    [0.7,0.8,0.9,0.9,0.9],
])
X
tensor([[0.1000, 0.2000, 0.3000, 0.3000, 0.3000],
        [0.4000, 0.5000, 0.6000, 0.6000, 0.6000],
        [0.7000, 0.8000, 0.9000, 0.9000, 0.9000]])

定义线性层, 我们的输入特征为5,所以 in_feature=5,我们想让下一层的神经元个数为10,所以 out feature=10, 则模型参数为: \(W_{5\times10}\)

model = nn.Linear(in_features=5, out_features=10, bias=True)

经过线性层,其实就是做了一件事,即:

\[Y_{3\times10}=X_{3\times5}W_{5\times10}+b \]

具体表示则为:

\[\begin{bmatrix}Y_{00}&Y_{01}&\cdots&Y_{08}&Y_{09}\\Y_{10}&Y_{11}&\cdots&Y_{18}&Y_{19}\\Y_{20}&Y_{21}&\cdots&Y_{28}&Y_{29}\end{bmatrix}=\begin{bmatrix}X_{00}&X_{01}&X_{02}&X_{03}&X_{04}\\X_{10}&X_{11}&X_{12}&X_{13}&X_{14}\\X_{20}&X_{21}&X_{22}&X_{23}&X_{24}\end{bmatrix}\begin{bmatrix}W_{00}&W_{01}&\cdots&W_{08}&W_{09}\\W_{10}&W_{11}&\cdots&W_{18}&W_{19}\\W_{20}&W_{21}&\cdots&W_{28}&W_{29}\\W_{30}&W_{31}&\cdots&W_{38}&W_{39}\\W_{40}&W_{41}&\cdots&W_{48}&W_{49}\end{bmatrix}+b \]

其中 \(X_i\).就表示第\(i\)个样本, \(W_{\cdot j}\) 表示所有输入神经元到第\(j\)个输出神经元的权重。
image

注意: 这里图有点问题, 应该是\(W_{00},W_{01},W_{02},...,W_{07},W_{08},W_{09}\)

因为有三个样本,所以相当于依次进行了三次\(Y_{1\times10}=X_{1\times5}W_{5\times10}\),然后再将三个\(Y_{1\times10}\) 叠在一起经过线性层后,我们最终的到了\(3\times10\)维的矩阵,即 输入3个样本,每个样本维度为5,输出为3个样本,将每个样本扩展成了10维

model(X).size()
# torch.Size([3, 10])

Pytorch版本线性回归模型

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt

# 1. 定义数据
x = torch.rand([50,1])
y = x*3 + 0.8

#2 .定义模型
class Lr(nn.Module):
    def __init__(self):
        super(Lr,self).__init__()
        # 因为简单的一维线性回归x的特征只有1,我们要预测的y也只有一个特征
        self.linear = nn.Linear(1,1)
    # 定义前向传播过程
    def forward(self, x):
        out = self.linear(x)
        return out

# 2. 实例化模型,loss,和优化器
model = Lr()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
#3. 训练模型
for i in range(30000):
    out = model(x) #3.1 获取预测值
    loss = criterion(y,out) #3.2 计算损失
    optimizer.zero_grad()  #3.3 梯度归零
    loss.backward() #3.4 计算梯度
    optimizer.step()  # 3.5 更新梯度
    if (i+1) % 20 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))

#4. 模型评估
model.eval() #设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

标签:10,样本,Linear,nn,torch,Pytorch,model
From: https://www.cnblogs.com/jzYe/p/17741248.html

相关文章

  • MySQL学习(2)什么是InnoDB数据页
    前言什么是InnoDB页MySQL服务器中负责读写数据的是存储引擎,InnoDB是一种常用的,将表数据存储在磁盘中的存储引擎。在实际操作中,MySQL将磁盘中的数据加载到内存中,若是需要处理写入或修改,则把内存中的数据刷新到磁盘。什么是行格式数据是以记录为单位在表中存储的,每一......
  • 解决ERROR oslo_messaging.rpc.server UnixHTTPConnectionPool(host=‘localhost‘, p
    zun.common.exception.ZunException:Unexpectederror:UnixHTTPConnectionPool(host=‘localhost’,port=None):Readtimedout.(readtimeout=60)2023-09-1317:26:08.6498468ERRORoslo_messaging.rpc.server[req-6ff62c08-fd25-4df6-8a81-d144956cbbd97db25ffff6314......
  • 【研究生学习】深度学习中几种常用的卷积形式的原理以及其Pytorch调用
    本篇博客主要记录一下在深度学习中几种常用的卷积形式的基本原理、输入输出维度,以及如何在Pytorch中调用这些卷积形式卷积卷积实际上是对图像的不同区域进行特征提取,一般认为输入图像的维度为H×W×C,如下图所示:图像具有颜色通道,一般是RGB,需要理解的是不同通道数的图像和不同的......
  • Java 21 新特性:Unnamed Patterns and Variables
    Java21中除了推出JEP445:UnnamedClassesandInstanceMainMethods之外,还有另外一个预览功能:未命名模式和变量(UnnamedPatternsandVariables)。该新特性的目的是提高代码的可读性和可维护性。下面通过一个例子来理解这个功能,try-catch块相信大家都不陌生,都是这样写的:try{......
  • Scanner()
    Scanner对象next():1、一定要读取有效支付后才可以结束输入。2、对输入有效字符之前遇到的空白,next()方法会自动将其去掉。3、只有输入有效字符后才将其后面输入的空白作为分隔符或者结束符。4、next()不能得到带有空格的字符。示例如下:importjava.util.Scanner;public......
  • Pytorch环境深度学习环境
    Pytorch环境深度学习环境1、安装minicoda下载地址:Miniconda—minicondadocumentation设置环境变量:安装路径\Miniconda3安装路径\Miniconda3\Scripts安装路径\Miniconda3\Library\bin测试:打开cmd,输入conda测试指令是否有效。2、配置base环境国内镜像(1)conda镜像......
  • AtCoder Beginner Contest 178 E
    AtCoderBeginnerContest178EE-DistMax曼哈顿距离最大点对\(ans=max(|x_i-x_j|+|y_i-y_j|)\)考虑去绝对值,4种情况。sort一下取max即可。#include<bits/stdc++.h>usingnamespacestd;typedeflonglongll;constintN=2e5+10;intx[N],y[N];intp[4][N];......
  • Java 21 新特性:Unnamed Classes and Instance Main Methods
    Java21引入了两个语言核心功能:未命名的Java类你说新的启动协议:该协议允许更简单地运行Java类,并且无需太多样板下面一起来看个例子。通常,我们初学Java的时候,都会写类似下面这样的HelloWorld程序:publicclassHelloWorld{publicstaticvoidmain(String[]args){......
  • compattelrunner.exe 是 Windows 操作系统中的一个可执行文件。它是 Microsoft 官方提
    compattelrunner.exe是Windows操作系统中的一个可执行文件。它是Microsoft官方提供的用于收集计算机性能数据和故障排除的工具,这些数据旨在帮助Microsoft监测和改进Windows操作系统。Compattelrunner.exe的主要功能如下:收集数据:它定期扫描计算机上的所有文件和程序,并......
  • AtCoder Beginner Contest 322
    A-FirstABC2解题思路签到Code#include<bits/stdc++.h>usingnamespacestd;typedeflonglongLL;voidsolve(){ intn; cin>>n; strings; cin>>s; intp=s.find("ABC"); if(p==-1)cout<<p<<'\n&......