首页 > 其他分享 >pytorch使用交叉熵训练模型学习笔记

pytorch使用交叉熵训练模型学习笔记

时间:2024-06-17 17:33:05浏览次数:26  
标签:__ nn 交叉 self torch 笔记 pytorch fc SimpleModel

python代码:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 2)  # 输入3维,输出2类

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = SimpleModel()

# 创建损失函数
criterion = nn.CrossEntropyLoss()

# 创建优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 定义固定的训练数据
inputs = torch.tensor([
    [0.3, 0.2, 0.1],
    [0.4, 0.5, 0.6],
    [9, 8, 7],
    [10, 11, 102],
    [3, 2, 1],
    [4, 5, 6],
    [9, 8, 7],
    [1.0, 1.1, 102],
    [0.3, 0.2, 0.1],
    [0.4, 0.5, 0.6],
], dtype=torch.float32)  # 10个样本,每个样本3维

targets = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.long)  # 真实标签

# 训练模型1000次
num_epochs = 1000

for epoch in range(num_epochs):
    # 前向传播
    outputs = model(inputs)

    # 计算损失
    loss = criterion(outputs, targets)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 打印损失值
    if (epoch + 1) % 100 == 0:  # 每100次打印一次损失值
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

print('Training complete.')

# 定义一个待测试的样本
test_sample = torch.tensor([0.5, 5, 500], dtype=torch.float32)

# 预测结果
with torch.no_grad():  # 在测试时不需要计算梯度
    test_output = model(test_sample)
    probabilities = torch.softmax(test_output, dim=0)
    predicted_class = torch.argmax(test_output).item()
    print(f'Test Sample Prediction: Class {predicted_class}')
    print(f'Probabilities: {probabilities.tolist()}')
View Code

运行结果

 

此代码输入10个样本,每个样本3个特征,输出2个类别(升序特征为1,降序特征为0)

下面详细讲解一下 SimpleModel 这个模型的定义和构造

SimpleModel 模型定义

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 2)  # 输入3维,输出2类

    def forward(self, x):
        return self.fc(x)

1.类定义

class SimpleModel(nn.Module):

这行代码定义了一个名为 SimpleModel 的类,该类继承自 nn.Modulenn.Module 是 PyTorch 中所有神经网络模块的基类。通过继承 nn.Module,我们可以利用 PyTorch 提供的许多有用的功能,例如参数管理和自动求导。

2.初始化方法

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 2)  # 输入3维,输出2类
  • __init__ 方法是类的构造函数,当创建 SimpleModel 类的实例时会被调用。
  • super(SimpleModel, self).__init__() 调用父类(即 nn.Module)的构造函数,初始化父类中的一些属性
  • self.fc = nn.Linear(3, 2) 创建一个全连接层(也叫线性层),输入维度为3,输出维度为2。
  • nn.Linear 是 PyTorch 中定义全连接层的类。它接受两个参数:输入特征数和输出特征数。这里输入特征数是3,输出特征数是2。

3.前向传播方法

def forward(self, x):
    return self.fc(x)
  • forward 方法定义了模型的前向传播过程,即当输入数据 x 传递给模型时,数据如何通过各层计算最终输出。
  • self.fc(x) 将输入 x 传递给全连接层 self.fc,并返回输出

模型的具体构造

 

1. 全连接层(Linear Layer)

 

  • 定义:全连接层是指每个输入节点都与输出节点相连,进行线性变换。它的数学表示是:y = xW^T + b,其中 W 是权重矩阵,b 是偏置向量。
  • 输入和输出:在这个模型中,全连接层的输入维度是3,输出维度是2。这意味着输入数据应该有3个特征,输出将是2个类别的分数。

 

2. 模型的参数

 

  • 权重矩阵:大小为 [2, 3],表示从3维输入到2维输出的连接权重。
  • 偏置向量:大小为 [2],每个输出类别一个偏置。

 

数据流动

 

假设输入数据 x 是一个大小为 [batch_size, 3] 的张量,其中 batch_size 是每批数据的样本数量:

 

  1. 输入数据 x 通过全连接层 self.fc
  2. 全连接层计算输出 y = xW^T + b,输出是一个大小为 [batch_size, 2] 的张量,其中每行是一个样本在两个类别上的分数。

 

 

 

 

标签:__,nn,交叉,self,torch,笔记,pytorch,fc,SimpleModel
From: https://www.cnblogs.com/lizhiqiang0204/p/18252876

相关文章

  • 苍穹外卖笔记-15-订单状态定时处理、来单提醒和客户催单
    文章目录苍穹外卖-day101.SpringTask1.1介绍1.2cron表达式1.3入门案例1.3.1SpringTask使用步骤1.3.2代码开发1.3.3功能测试2.订单状态定时处理2.1需求分析2.2代码开发2.3功能测试3.WebSocket3.1介绍3.2入门案例3.2.1案例分析3.2.2代码开发3.2.3功能......
  • 手把手教NLP小白如何用PyTorch构建和训练一个简单的情感分类神经网络
        在当今的深度学习领域,神经网络已经成为解决各种复杂问题的强大工具。本文将通过一个实际案例——对Yelp餐厅评论进行情感分类,来介绍如何使用PyTorch构建和训练一个简单的神经网络模型。我们将逐步讲解神经网络的基础概念,如激活函数、损失函数和优化器,并最终实现一......
  • 复习笔记二(动态规划法)
    工作指派问题(20分)设有n件工作,n个人,每个人只能做一件工作,每件工作只能安排给一个人,已知每个人做每件工作的耗费,请设计分支限界算法求解最少耗费的工作指派。要求:(1)对问题进行分析;(9分)(2)给出分支限界算法的伪代码描述;(8分)......
  • git学习笔记——202406171525
    想将本地仓库代码提交到远程仓库,应注意:如果在新建远程仓库时里面还新建了文件,在本地提交代码时会显示两个分支是冲突的,git认为是两个不相关的仓库代码,会拒绝上传。解决方法是gitpullremotemaster拉取远程代码到本地,然后再gitpushremote-umaster相关链接:https://www.cn......
  • 交叉编译python第三方库
    这里我们以编译androidpython程序为例工具crossenv名词对于交叉编译的各个部分,没有标准的词汇表,不同的资源经常会使用相互矛盾的术语。为了避免混淆,我们只使用GNU术语,这是Python本身使用的。host就是你编译出来的包要运行的平台,比如这里是Androidbuild进......
  • Java速成笔记 2024.6.17版
    变量:可以变化的容器不同变量可以存储不同类型的值变量声明方法:变量类型变量名=初始值;E.G.inta=1;变量类型:整型:intlong浮点数:floatdouble布尔:boolean字符串:String字符:char变量命名注意事项:不能重名不能以数字开头常量:关键字:final语法:finalfl......
  • 从事网络安全领域吃香吗?零基础入门精通就业,附学习笔记
    吃香是真的会吃香?但是很辛苦。在安服这行工作是做的痛并快乐着。工作是没有轻松的,都是付出和回报成正比的,而且还要不停学习提升,丝毫不敢懈怠。不可能不加班,能正常作息很难,网络安全系列的岗位是IT行业里最辛苦的,接触了太多圈内同行朋友,基本上都是007,996真是福报奢侈,有时候......
  • 程序员修炼之道:从小工到专家阅读笔记02
    做程序要及时亡羊补牢修复,这意味着在编程过程中,我们需要时刻关注代码的质量,一旦发现潜在的问题或错误,立即进行修复。遵循编码规范和风格指南,编写易于维护和阅读的代码。这样可以降低出错的可能性,并在出现问题时更容易进行修复。在发现问题时,及时与团队成员沟通,分享自己的发现和解......
  • 程序员修炼之道:从小工到专家阅读笔记01
    程序员要勇于承担错误,这意味着在编程过程中,我们需要敢于面对和解决出现的问题。以下是一些关于勇于承担错误的建议:诚实面对错误:当发现程序中的错误时,不要试图掩盖或忽视它们。诚实地面对问题,承认自己的错误,并寻求解决方案。分析错误原因:在解决问题之前,首先要了解错误发生的原因。......
  • 程序员修炼之道:从小工到专家阅读笔记04
    耦合这个词基本在我的职业生涯中每天都能听到,一个好的程序一定是低耦合的,这本书提出了函数的德墨忒尔法则帮我们更好的界定耦合的边界,怎样编写低耦合的代码,更难能可贵的是这本书不仅仅描述了一般的代码耦合,还花了很大笔墨解释了时间耦合,很多时候一个业务的实现没有必要一定是线性......