首页 > 其他分享 >pytorch(三)

pytorch(三)

时间:2024-11-05 20:18:15浏览次数:3  
标签:torchvision pth vgg16 torch pytorch print import

01 现有网络模型的使用及修改

import torchvision
from torch import nn

#train_data = torchvision.datasets.ImageNet("../data_image_net",split='train',download =True,
   #                                        transform = torchvision.transforms.ToTensor())

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)
train_data = torchvision.datasets.ImageNet("../data_image_net",split='train',download =True,
                                            transform = torchvision.transforms.ToTensor())

vgg16_true.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)

02 网络模型的保存

import torchvision
import torch
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式1   模型结构+模型参数
torch.save(vgg16,"vgg16_method.pth")

#保存方式2   模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")#只保存网络模型中的参数

#存在的陷阱
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui,self).__init__()
        self.conv1 = nn.Conv2d(3,64,kernel_size=3)
        
    def forward(self,x):
        x = self.conv1(x)
        return x
    
tudui = Tudui()
torch.save(tudui,"tudui_method1.pth")

03 网络模型的提取 

import torch
import torchvision

#方式1,加载模型
model = torch.load("vgg16_method1.pth")
#print(model)

#方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
#model = torch.load("vgg16_method2.path")
#print(model)


#陷阱1
model = torch.load('tudui_method1.pth')
print(model)

 

标签:torchvision,pth,vgg16,torch,pytorch,print,import
From: https://blog.csdn.net/weixin_53294261/article/details/143375235

相关文章

  • 【图神经网络】 AM-GCN代码实战(1)【pytorch】代码可运行
    AM-GCN网络系列代码实践部分1.环境设置2.代码运行指令2.1命令行执行代码2.1IDE执行(1)2.2IDE执行(2)3.参数选择4.总结代码实践部分本专栏致力于深入探讨图神经网络模型相关的学术论文,并通过具体的编程实验来深化理解。读者可以根据个人兴趣选择相关内容进行学......
  • 基于卷积神经网络的大豆病虫害识别与防治系统,resnet50,mobilenet模型【pytorch框架+pyt
     更多目标检测和图像分类识别项目可看我主页其他文章功能演示:大豆病虫害识别与防治系统,卷积神经网络,resnet50,mobilenet【pytorch框架,python源码】_哔哩哔哩_bilibili(一)简介基于卷积神经网络的大豆病虫害识别与防治系统是在pytorch框架下实现的,这是一个完整的项目,包括代码,......
  • 基于AMD显卡安装Pytorch(小白攻略)
    安装的时候看了很多博客,踩了一些雷,现在把成功下载的流程汇总下来。假设这个时候已经安装好了ubuntu,我安装的是ubuntu22.04.安装rocmLinux®DriversforAMDRadeon™andRadeonPRO™Graphics可以点击上面这个链接,点击ubuntux8664-bit.我选的是带rocm的版本复制这......
  • 基于yolov8的生猪检测和统计系统,支持图像、视频和摄像实时检测【pytorch框架、python
     更多目标检测和图像分类识别项目可看我主页其他文章功能演示:基于yolov8的生猪检测和统计系统,支持图像、视频和摄像实时检测【pytorch框架、python源码】_哔哩哔哩_bilibili(一)简介基于yolov8的生猪检测和统计系统是在PyTorch框架之下得以实现的。这是一个完备的项目,涵盖......
  • PyTorch(torch.cuda.empty_cache())
    目录1.函数功能2.使用场景3.注意事项4.示例代码5.相关函数torch.cuda.empty_cache()是PyTorch中用于清理GPU上缓存的内存的函数。这个函数不会影响GPU上存储的实际张量数据,只是释放了由缓存机制占用的内存。在深度学习模型的训练过程中,经常需要释放不再使用的GPU......
  • 基于YOLOv8模型的高精度红外行人车辆目标检测(PyTorch+Pyside6+YOLOv8模型)
    摘要:基于YOLOv8模型的高精度红外行人车辆目标检测系统可用于日常生活中检测与定位红外行人车辆目标,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集,使用Pysdie6库......
  • PyTorchStepByStep - Chapter 9: Sequence-to-Sequence
     points,directions=generate_sequences(n=256,seed=13)Andthenlet’svisualizethefirstfivesquares:classEncoder(nn.Module):def__init__(self,n_features,hidden_dim):super().__init__()self.n_features=n_features......
  • 基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
    近年来,大型语言模型(LargeLanguageModels,LLMs)在自然语言处理(NaturalLanguageProcessing,NLP)领域取得了显著进展。这些模型通过在大规模文本数据上进行预训练,能够习得语言的基本特征和语义,从而在各种NLP任务上取得了突破性的表现。为了将预训练的LLM应用于特定领域或......
  • pytorch自动微分
    求导是几乎所有深度学习优化算法的关键步骤,因为在优化损失函数时会用反向传播,即使参数朝着梯度下降的方向调整,求梯度即求偏导。虽然求导的计算很简单,但对于复杂的模型,手动进行更新很容易出错。Pytorch通过自动微分来加快求导。他会先构建一个计算图(computationalgraph),来跟踪计......
  • 关于图神经网络框架Pytorch_geometric实战应用,并给出详细代码实现过程
    大家好,我是微学AI,今天给大家介绍一下关于图神经网络框架Pytorch_geometric实战应用,并给出详细代码实现过程,本文展示了如何利用该框架进行图神经网络的搭建与训练。文章涵盖了从数据预处理、模型构建、参数调优到模型评估等各个环节,旨在帮助读者深入理解并掌握Pytorch_geome......