首页 > 其他分享 >Pytorch模型文件`*.pt`与`*.pth` 的保存与加载

Pytorch模型文件`*.pt`与`*.pth` 的保存与加载

时间:2024-07-19 10:20:42浏览次数:10  
标签:pt pth self torch state Pytorch dict model 模型

1.*.pt文件

.pt文件保存的是模型的全部,在加载时可以直接赋值给新变量model = torch.load("filename.pt")

具体操作:

  • (1). 模型的保存
torch.save(model,"Path/filename.pt")
  • (2). 模型的加载
model = torch.load("filename.pt")

注意:torch.load()的参数使用字符串参数。

2. .pth文件

.pth保存的是模型参数,通过字符字典进行保存,在加载该类文件时应该先实例化一个具体的模型,然后对新建立的空模型,进行参数赋予。

具体操作:

  • (1). 模型的保存
torch.save(model.state_dict(), PATH)
  • (2). 模型的加载
model = nn.Module() # 这里要先实例化模型
model.load_state_dict(torch.load("filename.pth"))

操作实例

  1. 首先定义一个模型作为例子
# Define model
class TheModelClass(nn.Module):
    # 类的初始化
    def __init__(self):
        # 继承父类 nn.Module 的属性和方法
        super(TheModelClass, self).__init__()
        # Inputs_channel, Outputs_channel, kernel_size
        self.conv1 = nn.Conv2d(3, 6, 5)
        # 最大池化层,池化核的大小
        self.pool = nn.MaxPool2d(2, 2)
        # 卷积层,池化层,卷积层
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 最后接一个线性全连接层
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 卷积作用后,使用relu进行非线性化,最后使用池化操作进行特征个数,参数量的降低
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

2. 现在开始进行模型的保存与加载

PATH = "/home/深度学习/model"
#  第一种模型保存和加载方式
torch.save(model.state_dict(), PATH+"/TheModuleClass.pth")
model = TheModelClass()
model.load_state_dict(torch.load("/home/深度学习/model/TheModuleClass.pth"))

for param_tensor in model.state_dict():
    print(f"{param_tensor}<<<{model.state_dict()[param_tensor].size()}")
print(model)

# 输出结果
'''
conv1.weight<<<torch.Size([6, 3, 5, 5])
conv1.bias<<<torch.Size([6])
conv2.weight<<<torch.Size([16, 6, 5, 5])
conv2.bias<<<torch.Size([16])
fc1.weight<<<torch.Size([120, 400])
fc1.bias<<<torch.Size([120])
fc2.weight<<<torch.Size([84, 120])
fc2.bias<<<torch.Size([84])
fc3.weight<<<torch.Size([10, 84])
fc3.bias<<<torch.Size([10])
TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
'''
#  第二种模型保存和加载方式

torch.save(model, PATH + "/the_module_class.pt")
model = torch.load(PATH + "/the_module_class.pt")

for param_tensor in model.state_dict():
    print(f"{param_tensor} <<< {model.state_dict()[param_tensor].size()}")
print(model)

#  输出结果
'''
conv1.weight<<<torch.Size([6, 3, 5, 5])
conv1.bias<<<torch.Size([6])
conv2.weight<<<torch.Size([16, 6, 5, 5])
conv2.bias<<<torch.Size([16])
fc1.weight<<<torch.Size([120, 400])
fc1.bias<<<torch.Size([120])
fc2.weight<<<torch.Size([84, 120])
fc2.bias<<<torch.Size([84])
fc3.weight<<<torch.Size([10, 84])
fc3.bias<<<torch.Size([10])
TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
'''

总结

  • 这里推荐使用第二种方法,因为保存和加载文件简单,而且生成的二进制文件区分程度高。
  • torch.save() 保存模型的参数,为以后模型推理核模型恢复提供了更加方便更加灵活的方法。
  • 一定要在模型评估时, 关闭批量规范化和丢弃法, 仅仅在模型训练时有用,模型推理时一定要关闭(所谓模型推理,指是使用模型进行的实际应用)
  • 加载.pth 要先实例化,再进行参数的承接。

标签:pt,pth,self,torch,state,Pytorch,dict,model,模型
From: https://www.cnblogs.com/conpi/p/18310901

相关文章

  • springboot+vue+mybatis销售评价系统+PPT+论文+讲解+售后
    随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,销售评价系统当然也不能排除在外。销售评价系统是以实际运用为开发背景,运用软件工程开发方法,采用Java技术构建的一个管理系统。整个开发过程首先对软件系统进行需求分......
  • 论文《AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning》
    在大模型微调的理论中,AdaLoRA方法是一个绕不开的部分。 这篇论文主要提出了一种新的自适应预算分配方法AdaLoRA,用于提高参数高效的微调性能。AdaLoRA方法有效地解决了现有参数高效微调方法在预算分配上的不足,提高了在资源有限情况下的模型性能,为NLP领域的实际应用提供了新的......
  • datagrip启动报错Exception Type:EXC_BAD_ACCESS (SIGABRT)
    本人电脑背景:mac10.15安装datagrip2024版本,根据官方描述,这个版本是不支持的,但是本着试试的态度安装,毕竟也想用新版本。结果遇到了问题。启动打不开,由于错误信息较多,大概整理出来描述如下:ExceptionType:EXC_BAD_ACCESS(SIGABRT)ExceptionCodes:KERN_INVALID_......
  • SQL 按照dept_no进行汇总
    系列文章目录文章目录系列文章目录前言前言前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。描述按照dept_no进行汇总,属于同一个部门的emp_no按照逗号进行连接,结果给......
  • 【笔记】【THM】Introduction to Cryptography(密码学简介)
    【THM】IntroductiontoCryptography(密码学简介)-学习本文相关的TryHackMe实验房间链接:https://tryhackme.com/r/room/cryptographyintro本文相关内容:了解AES、Diffie-Hellman密钥交换、哈希、PKI和TLS等加密算法。(大部分为机翻,若有错误请指出)介绍这个房间的目的是向......
  • 在Python中doc转docx,xls转xlsx,ppt转pptx(Windows)
    在Python中doc转docx,xls转xlsx,ppt转pptx(Windows)说明:首次发表日期:2024-07-18参考pypi包:doc2docx缘起我们一般使用Python开发RAG应用,或者使用基于Python开发的开源RAG工具,比如Dify。然而由于Python中对.doc和.ppt格式的文件支持不够好,通常我们需要将文件格式转换为.docx和.p......
  • SciTech-Mathmatics-Statistics-NumPy and Statistics: Descriptive Statistics
    StatisticsFromNumPyOfficialDocs.https://numpy.org/doc/stable/reference/routines.statistics.htmlOrderstatisticsnumpy.percentilenumpy.percentile(a,q,axis=None,out=None,overwrite_input=False,method='linear',keepdims=False,*,weig......
  • 【Pytorch】小土堆自学日记(六)
    目录一、神经网络的基本骨架-nn.Moudle的使用1.torch.nn官方文档:2.containers文档:①Moudle的作用: ②Moudle的示例代码:③forward函数官方解释:二、卷积操作:1.常用的卷积:2.Conv2d①使用方法:②Stride参数理解:③代码:④padding参数讲解:一、神经网络的基本骨架-nn.M......
  • FastStone Capture v10.6 解锁版 (一款优秀的支持屏幕录制、滚动截图、高清长图、图片
    前言FastStoneCapture是一款极简主义的应用程序,它简单易用,可以捕捉屏幕上的任意区域,提供多种捕获模式,包括活动窗口、指定窗口/对象、矩形区域、手绘区域、整个屏幕和滚动窗口等。此外,FastStoneCapture还附带屏幕录像机、放大镜、取色器和标尺等辅助功能。其体积小巧,但功能强......
  • 过滤器(Filter)和拦截器(Interceptor)的执行顺序和区别
    https://www.cnblogs.com/kuotian/p/13176186.html过滤器FilterFilter有如下几个用处。Filter有如下几个种类。javax.servlet.Filter接口1.通过@WebFilter注解配置2.通过@Bean来配置3.SpringMVC在web.xml配置过滤器启动测试拦截器InterceptorHandlerIn......