首页 > 其他分享 >最简单知识点PyTorch中的nn.Linear(1, 1)

最简单知识点PyTorch中的nn.Linear(1, 1)

时间:2024-04-05 19:32:38浏览次数:14  
标签:知识点 Linear nn 输出 特征 torch PyTorch 输入

一、nn.Linear(1, 1)

nn.Linear(1, 1) 是 PyTorch 中的一个线性层(全连接层)的定义。

nn 是 PyTorch 的神经网络模块(torch.nn)的常用缩写。

nn.Linear(1, 1) 的含义如下:

  • 第一个参数 1:输入特征的数量。这表示该层接受一个长度为 1 的向量作为输入
  • 第二个参数 1:输出特征的数量。这表示该层产生一个长度为 1 的向量作为输出

因此,nn.Linear(1, 1) 定义了一个简单的线性变换,其数学形式为:y=x⋅w+b
其中:

  • x 是输入向量(长度为 1)。
  • w 是权重(也是一个长度为 1 的向量)。
  • b 是偏置项(一个标量)。
  • y 是输出向量(长度为 1)。

在实际应用中,这样的线性层可能不常用,因为对于从长度为 1 的输入到长度为 1 的输出的映射,这实际上就是一个简单的线性变换,但在某些特定场景或作为更复杂模型的一部分时,它仍然可能是有用的。

二、简单举例

假设我们有一个简单的任务,需要预测一个线性关系,比如根据给定的输入值 x 来预测输出值 y,其中 y 是 x 的线性变换。在这种情况下,nn.Linear(1, 1) 可以用来表示这个线性关系。

以下是一个使用 PyTorch 和 nn.Linear(1, 1) 的简单例子:

import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
model = nn.Linear(1, 1)
# 定义损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器
# 假设我们有一些简单的线性数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32) # 假设 y = 2 * x
# 训练模型
for epoch in range(100): # 假设我们训练 100 个 epoch
        # 前向传播
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        # 反向传播和优化
        optimizer.zero_grad() # 清除梯度
        loss.backward() # 反向传播计算梯度
        optimizer.step() # 应用梯度更新权重
        # 打印损失值(可选)
        if (epoch+1) % 10 == 0:
                print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
# 测试模型
with torch.no_grad(): # 不需要计算梯度
        x_test = torch.tensor([[5.0]], dtype=torch.float32)
        y_pred = model(x_test)
        print(f'Predicted output for x=5: {y_pred.item()}')

运行截图:

图1 上述代码运行输出

在这个例子中,我们创建了一个简单的线性模型 nn.Linear(1, 1) 来学习输入 x 和输出 y 之间的线性关系。我们使用均方误差损失函数 nn.MSELoss() 随机梯度下降优化器 optim.SGD() 来训练模型。通过多次迭代(epoch),模型逐渐学习权重和偏置项(w, b)以最小化预测值与实际值之间的误差。最后,我们使用训练好的模型对新的输入值 x=5 进行预测,并打印出预测结果。

三、举一反三——nn.Linear(2, 1) 

nn.Linear(2, 1) 是PyTorch深度学习框架中用于定义一个线性层的语句。在深度学习中,线性层(也被称为全连接层或密集层)是一种非常基础的神经网络层,用于执行线性变换。

含义

nn.Linear(2, 1) 表示一个线性层,它接收一个具有2个特征的输入,并输出一个具有1个特征的结果。具体来说:

  • 第一个参数 2 表示输入特征的数量,即该层期望的输入维度是2。
  • 第二个参数 1 表示输出特征的数量,即该层输出的维度是1。

作用

这个线性层的作用是对输入的2个特征进行线性组合,然后输出一个单一的数值。数学上,这个过程可以表示为:

y = x1 * w1 + x2 * w2 + b

其中:

  • x1 和 x2 是输入特征。
  • w1 和 w2 是权重,它们在训练过程中会被学习。
  • b 是偏置项,也是一个在训练过程中会被学习的参数。
  • y 是该层的输出。

可能的应用场景

nn.Linear(2, 1) 可以应用于多种场景,特别是当需要将两个特征合并为一个单一特征时。以下是一些具体的例子:

  1. 回归问题:在简单的回归问题中,如果你有两个特征并希望预测一个连续的数值输出,你可以使用 nn.Linear(2, 1)。例如,预测房价时,你可能会根据房屋的面积和卧室数量来预测价格。

  2. 特征压缩:在某些情况下,你可能希望将多个特征压缩成一个特征,以便于后续处理或可视化。例如,在降维或特征工程中,nn.Linear(2, 1) 可以用于将两个特征转换为一个新的综合特征。

  3. 神经网络的一部分:在构建更复杂的神经网络时,nn.Linear(2, 1) 可以作为神经网络的一部分。例如,在多层感知机(MLP)中,这样的层可以与其他层(如激活层、dropout层等)结合使用,以构建能够处理复杂任务的模型。

需要注意的是,虽然 nn.Linear(2, 1) 本身只能执行线性变换,但在实际使用时,通常会与其他非线性层(如ReLU或sigmoid激活函数)结合使用,以构建能够学习非线性关系的模型。

标签:知识点,Linear,nn,输出,特征,torch,PyTorch,输入
From: https://blog.csdn.net/Oxford1151/article/details/137406415

相关文章

  • 机器学习知识点全面总结
    机器学习按照模型类型分为监督学习模型、无监督学习模型两大类。1、有监督学习有监督学习通常是利用带有专家标注的标签的训练数据,学习一个从输入变量X到输入变量Y的函数映射。Y=f(X),训练数据通常是(n×x,y)的形式,其中n代表训练样本的大小,x和y分别是变量X和Y的样本值。......
  • Yann Lecun-纽约大学-深度学习(PyTorch)
    课程介绍    本课程涉及深度学习和表示学习的最新技术,重点是有监督和无监督的深度学习,嵌入方法,度量学习,卷积和递归网络,并应用于计算机视觉,自然语言理解和语音识别。前提条件包括:DS-GA1001数据科学入门或研究生水平的机器学习课程。     免费获取:YannLecun-纽约......
  • 【HTML5+CSS3】HTML知识点+自主练习
    一、W3C标准结构:HTML表现:CSS行为:JavaScript二、HTML常用标签排版标签(标题标签、段落标签、换行标签、分割标签、 文本格式化标签)媒体标签(图片标签、音视频标签)超链接标签(超链接标签)布局标签(div标签、span标签、HTML5新增语义化标签)三、HTML学生示例代码​<!--......
  • 一文了解JVM所有知识点
    文章目录类的加载过程Java虚拟机中有哪些类加载器?什么是双亲委派模型?为什么使用双亲委派模式?有哪些场景破坏了双亲委派模型SPI机制自定义类加载器破坏双亲委派机制线程上下文类加载器破坏双亲委派机制运行时数据区java中常用的常量池class模板类存放在哪里?元空间为什......
  • 趣学前端 | 类,我想好好继承它的知识点
    背景最近睡前习惯翻会书,重温了《JavaScript权威指南》。这本书,文字小,内容多。两年了,我才翻到第十章。因为书太厚,平时都充当电脑支架。JavaScript类话说当年类、原型、继承,差点给我绕晕。在JavaScript中,类使用基于原型的继承。如果两个对象从同一个原型继承属性(通常是以函......
  • [ABC211F] Rectilinear Polygons 题解
    [ABC211F]RectilinearPolygons题解思路什么的上一篇题解已经写的非常明白了,这里只是提供一个补充&另一个实现的方法。思路解析先说结论:扫描线。顾名思义,扫描线的本质就是用一条线沿着\(x\)或\(y\)轴扫过去,每碰到一条边就记录一下加边后是面积是增加还是减少,然后用树状......
  • Java最短路径算法知识点(含面试大厂题和源码)
    最短路径算法是计算机科学和图论中的核心问题之一,它旨在找到从一个顶点到另一个顶点或在所有顶点之间的最短路径。这个问题在多种实际应用中都非常重要,如网络路由、交通规划、社交网络分析等。以下是一些与最短路径算法相关的知识点:Dijkstra算法:由荷兰计算机科学家艾兹......
  • Java归并排序知识点(含面试大厂题和源码)
    归并排序是一种有效的排序算法,采用分治法(DivideandConquer)策略。它将数组分成两半,对每一半递归地进行排序,然后将两个有序的半部分合并成一个整体的有序数组。归并排序在最坏情况、平均情况和最好情况下都保持(O(n\logn))的时间复杂度,是一种稳定的排序算法。由于其分而治......
  • Java快速排序知识点(含面试大厂题含源码)
    快速排序是一种高效的排序算法,由C.A.R.Hoare在1960年提出。它的基本思想是分而治之(DivideandConquer)。快速排序的关键在于选取一个“基准值”(pivot),然后将数组分为两个子数组:一个包含所有小于基准值的元素,另一个包含所有大于基准值的元素。这个过程称为“分区”(partitio......
  • 【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程
    整体流程为:.pth->.onnx->.plan(或.trt,二者等价)需要的工具和包:Docker,Pytorch,ONNX,onnxruntime,TensorRT(trtexec和polygraphy).pth到.onnx这里以SwinIR(https://github.com/JingyunLiang/SwinIR)预训练模型为例init_torch_model()函数主要是对模型初始化,这里是......