首页 > 其他分享 >pytorch 模型加载和保存

pytorch 模型加载和保存

时间:2024-08-06 10:06:11浏览次数:16  
标签:load 模型 torch 保存 pytorch model 加载

模型加载

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)

 

map_location适用于修改模型能在gpu上运行还是cpu上运行。

一般情况下,加载模型,主要用于预测新来的一组样本。预测的主要流程包括:输入数据——预处理——加载模型——预测得返回值(类别或者是属于某一类别的概率)

def predict(test_data, model_path, config):
    '''
    input:
           test_data:测试数据
           model_path:模型的保存路径 model_path = './save/20201104_204451.ckpt'
    output:
           score:模型输出属于某一类别的概率
    '''
    data = process_data_for_predict(test_data)#预处理数据,使得数据格式符合模型输入形式
    model = torch.load(model_path)#加载模型
    score = model(data)#模型预测
    return score #返回得分

 Pytorch模型 .pt, .pth, .pkl的区别

后缀名为.pt, .pth, .pkl的pytorch模型文件,在格式上其实没有区别,只是后缀不同而已

模型的保存和加载有两种方式:

(1) 仅仅保存和加载模型参数

# 保存
torch.save(the_model.state_dict(), PATH='mymodel.pth')        #只保存模型权重参数,不保存模型结构

# 调用
the_model = TheModelClass(*args, **kwargs)                    #这里需要重新模型结构,TheModelClass
the_model.load_state_dict(torch.load('mymodel.pth'))        #这里根据模型结构,调用存储的模型参数

(2) 保存和加载整个模型

# 保存
torch.save(the_model, PATH)                                    #保存整个model的状态

# 调用
the_model = torch.load(PATH)                                #这里已经不需要重构模型结构了,直接load就可以

第一种方式需要自己定义网络,并且其中的参数名称与结构要与保存的模型中的一致(可以是部分网络,比如只使用VGG的前几层),相对灵活,便于对网络进行修改。第二种方式则无需自定义网络,保存时已把网络结构保存,比较死板,不能调整网络结构。

标签:load,模型,torch,保存,pytorch,model,加载
From: https://www.cnblogs.com/arwen-xu/p/18344611

相关文章

  • SemanticKernel/C#:实现接口,接入本地嵌入模型
    前言本文通过Codeblaze.SemanticKernel这个项目,学习如何实现ITextEmbeddingGenerationService接口,接入本地嵌入模型。项目地址:https://github.com/BLaZeKiLL/Codeblaze.SemanticKernel实践SemanticKernel初看以为只支持OpenAI的各种模型,但其实也提供了强大的抽象能力,可以通过......
  • 基于R语言复杂数据回归与混合效应模型【多水平/分层/嵌套】技术与代码
    回归分析是科学研究特别是生态学领域科学研究和数据分析十分重要的统计工具,可以回答众多科学问题,如环境因素对物种、种群、群落及生态系统或气候变化的影响;物种属性和系统发育对物种分布(多度)的影响等。纵观涉及数量统计方法生态学论文中几乎都能看到回归分析的身影。随着现代统......
  • 驱动开发系列09 - Linux设备模型之设备,驱动和总线
    一:概述     Linux设备模型(LDM)是Linux内核中引入的一个概念。用于管理内核对象(那些需要引用计数的对象、例如文件、设备、总线甚至驱动程序),以及描述它们之间的层次结构,以及这些内核对象之间绑定关系。Linux设备模型引入了对象生命周期管理、引用计数、以及面向对象......
  • python图表没有正确显示中文,这通常是因为matplotlib的默认设置不支持中文字符,或者相应
    如果图表没有正确显示中文,这通常是因为matplotlib的默认设置不支持中文字符,或者相应的字体没有正确加载。你可以通过指定支持中文的字体来解决这个问题。下面是如何设置matplotlib以确保能够在图表中显示中文的步骤:方法1:全局设置字体你可以修改matplotlib的全局配置,使......
  • 为什么企业要微调大模型
    一、模型是什么?在人工智能领域,模型是指通过对数据进行分析和学习,建立的一种数学结构或算法,用于预测或分类新数据。简单来说,模型是从数据中提取知识,并应用这些知识对未来进行预测的工具。一个基本的线性模型可以表示为:Y=WX其中,Y是预测值,W是权重矩阵,X是输入数据。通过训......
  • USB通讯架构及数据模型
    注意:(1)一个usb设备由一个或者多个接口组成;(2)每一个接口为usb设备的一个功能,比如上面的usb设备由两个接口,一个可用于鼠标,一个可用于键盘;(3)每个接口占用usb设备的多个端口资源;(4)windows通过一组管道(pipes)与usb设备的某个接口的端点进行数据交互实现某种功能;(5)usb设备最多具有16个......
  • 解锁GraphRag.Net的无限可能:手把手教你集成国产模型和本地模型
        在上次的文章中,我们已经详细介绍了GraphRag的基本功能和使用方式。如果你还不熟悉,建议先阅读前面的文章    通过前两篇文章,相信你已经了解到GraphRag.Net目前只支持OpenAI规范的接口,但许多小伙伴在社区中提议,希望能增加对本地模型(例如:ollama等)的支持。所以这......
  • FLUX.1最强AI绘画开源新模型,本地部署教程!
    原文链接:FLUX.1最强AI绘画开源新模型,本地部署教程!(chinaz.com)Flux最近收到了很多模型爱好者的好评,出图质量超越SD3和MJ,许多人说Flux才是大家心目中的SD3,所以我也是非常好奇FLux的实力在这里把本地部署的过程分享给大家官网参考图:Flux官网首页:https://blackforestlabs.ai......
  • 生成 512x512 照片的模型
    我怎样才能让这个模型生成512x512像素或更大的图像?现在它生成64x64px图像。我尝试更改模型中的一些值,但没有成功。这些卷积层(尤其是Conv2D和Conv2DTranspose)如何工作?我不明白如何在这些层中调整图像的大小。importtensorflowastffromtensorflowimportkerasfrom......
  • 为什么 Langchain HuggingFaceEmbeddings 模型尺寸与 HuggingFace 上所述的不一样
    我使用的是langchainHuggingFaceEmbeddings模型:dunzhang/stella_en_1.5B_v5。当我查看https://huggingface.co/spaces/mteb/leaderboard时,我可以看到型号是8192。但当我这样做时len(embed_model.embed_query("heyyou"))它给了我1024。请问为什么会有这种差......