首页 > 其他分享 >pytorch深度学习分类代码简单示例

pytorch深度学习分类代码简单示例

时间:2024-08-07 10:39:23浏览次数:9  
标签:示例 pred 代码 torch pytorch new print output model

train.py代码如下

import torch
import torch.nn as nn
import torch.optim as optim

model_save_path = "my_model.pth"

# 定义简单的线性神经网络模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.output = nn.Linear(2, 4)   # 输入2个特征,输出4个类别

    def forward(self, x):
        x = self.output(x)
        return x

def main():
    # 数据点
    x = torch.tensor([[0, 0], [0, 10], [10, 0], [10, 10]], dtype=torch.float32)
    y = torch.tensor([0, 1, 2, 3], dtype=torch.long)

    # 初始化模型
    model = MyModel()

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # 训练模型
    num_iterations = 10000  # 迭代次数
    for i in range(num_iterations):
        model.train()

        # 前向传播:计算预测输出
        y_pred = model(x)

        # 计算损失
        loss = criterion(y_pred, y)

        # 输出每1000次迭代的损失值
        if i % 1000 == 0:
            print(f"迭代 {i},损失:{loss.item():.4f}")

        # 反向传播与梯度更新
        optimizer.zero_grad()  # 清除梯度
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新参数

    # 打印优化后的权重和偏置
    print("\n优化后的权重和偏置:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} = {param.data.numpy()}")

    # 保存模型
    torch.save(model.state_dict(), model_save_path)
    print(f"模型已保存到 {model_save_path}")

if __name__ == "__main__":
    main()
View Code

运行结果

test.py代码如下

import numpy as np
import torch
from torch import nn

from train import MyModel, model_save_path

# 加载模型
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load(model_save_path))
loaded_model.eval()  # 切换到评估模式

# 定义预测数据
input_data = [0, 9]

# 使用加载的模型进行预测
x_new = torch.tensor([input_data], dtype=torch.float32)  # 新数据点
y_new_pred = loaded_model(x_new)  # 计算预测值

# 使用softmax计算每个类别的概率
softmax = nn.Softmax(dim=1)
y_new_pred_probs = softmax(y_new_pred)

# 找到预测的类别
predicted_class = torch.argmax(y_new_pred_probs, dim=1)

# 将概率分布四舍五入到三位小数
y_new_pred_probs_rounded = np.round(y_new_pred_probs.detach().numpy(), 3)

print(f"\n对x = {input_data}的预测类别:{predicted_class.item()}")
print(f"预测类别的概率分布:{y_new_pred_probs_rounded}")

# 打印权重和偏置
weights = loaded_model.output.weight  # 获取输出层权重
bias = loaded_model.output.bias  # 获取输出层偏置

print(f"\n模型权重:\n{weights}")
print(f"\n模型偏置:\n{bias}")

# 计算input_data * 模型权重 + 模型偏置
with torch.no_grad():
    linear_output = x_new @ weights.t() + bias

print(f"\ninput_data * weights + bias ={linear_output.numpy()}")

# 手动计算Softmax概率分布
linear_output_np = linear_output.numpy()
exp_output = np.exp(linear_output_np)
softmax_output_manual = exp_output / np.sum(exp_output)

print(f"\n手动计算的Softmax概率分布:{softmax_output_manual}")
print(f"手动计算的预测类别:{np.argmax(softmax_output_manual)}")
View Code

运行结果

 

 

标签:示例,pred,代码,torch,pytorch,new,print,output,model
From: https://www.cnblogs.com/lizhiqiang0204/p/18346540

相关文章

  • 代码随想录算法训练营第61天 | 图论part08:拓扑排序+迪杰斯特拉朴素法
    117.软件构建https://kamacoder.com/problempage.php?pid=1191拓扑排序精讲https://www.programmercarl.com/kamacoder/0117.软件构建.html#拓扑排序的背景47.参加科学大会https://kamacoder.com/problempage.php?pid=1047dijkstra(朴素版)精讲https://www.programmercarl.c......
  • 企业为什么需要对源代码进行加密,12款源代码加密软件推荐
    在信息技术快速发展的今天,源代码是企业最为核心的知识产权之一。对源代码进行加密是保护企业竞争优势和知识产权的关键措施。1.保护知识产权:源代码是软件和技术的核心组成部分,未经授权的访问和泄露可能导致知识产权的损失。2.防止逆向工程:加密可以有效防止黑客通过逆向工......
  • 图片增加文本水印(右下角)--Java代码实现
    一.效果展示水印前                                               水印后        二.代码实现 /***在给定的图片上添加文本水印。**@paramsourceImgPath源图片路径*......
  • 神经网络之卷积篇:详解边缘检测示例(Edge detection example)
    详解边缘检测示例卷积运算是卷积神经网络最基本的组成部分,使用边缘检测作为入门样例。在这个博客中,会看到卷积是如何进行运算的。在之前的博客中,说过神经网络的前几层是如何检测边缘的,然后,后面的层有可能检测到物体的部分区域,更靠后的一些层可能检测到完整的物体,这个例子中就是......
  • 深入解析:23种软件设计模式详解及其分类(创建型、结构型、行为型)附代码示例DEMO
    目录引言一、创建型模式1.简单工厂模式(SimpleFactoryPattern)2.抽象工厂模式(AbstractFactoryPattern)3.单例模式(SingletonPattern)4.建造者模式(BuilderPattern)5.原型模式(PrototypePattern)二、结构型模式1.适配器模式(AdapterPattern)2.桥接模式(BridgePatt......
  • pytorch和deep learning技巧和bug解决方法短篇收集
    有一些几句话就可以说明白的观点或者解决的的问题,小虎单独收集到这里。torch.hub.loadhowdoesitwork下载预训练模型再载入,用程序下载链接可能失效。model=torch.hub.load('ultralytics/yolov5','yolov5s')model=torch.hub.load('ultralytics/yolov3','yolov3......
  • Apache 中的新零日漏洞允许远程代码执行
    ApacheOFBiz开源企业资源规划(ERP)系统中披露了一个新的零日预认证远程代码执行漏洞,该漏洞可能允许威胁行为者在受影响的实例上实现远程代码执行。该漏洞编号为CVE-2024-38856,CVSS评分为9.8(满分10.0)。该漏洞会影响18.12.15之前的ApacheOFBiz版本。发现并报告该漏洞......
  • 7 Python之代码类型提示(Type Hint)
     欢迎来到@一夜看尽长安花博客,您的点赞和收藏是我持续发文的动力对于文章中出现的任何错误请大家批评指出,一定及时修改。有任何想要讨论的问题可联系我:[email protected]。发布文章的风格因专栏而异,均自成体系,不足之处请大家指正。   专栏:java全栈C&C++PythonAIP......
  • c#12 实验特性Interceptor如何使用的一个简单但完整的示例
    一直有很多转载dotnet对Interceptor说明文档的,但鲜有说明Interceptor如何使用的,这里写一篇简单示例来展示一下c#12实验特性Interceptor是什么?官方解释如下(其实简单说就是语言特性中内置的静态编织方式的aop功能,不同于其他il修改代码的方式,使用上得结合sourcegenerater来生......
  • 代码随想录 day 47 回文子串 | 最长回文子序列
    回文子串回文子串解题思路dp数组的状态是判断以i结尾,j开始的字符串是否为回文,用bool类型存储,之后当i和j的字符串相等时,通过计算它们之间的距离和判断它们之间是否为回文串来进行递归。知识点回文,动态规划心得如果不看题解根本想不到怎么做最长回文子序列最长回文子序列......