首页 > 编程语言 >Python 机器学习 模型保存和加载

Python 机器学习 模型保存和加载

时间:2024-02-16 22:15:05浏览次数:40  
标签:Python 模型 joblib 序列化 pickle 加载

 

Python 机器学习中,模型保存和加载是两个非常重要的操作。模型保存可以将训练好的模型保存到文件,以便以后使用。模型加载可以将保存的文件加载到内存,以便进行预测或评估。最常用保存和加模型的库包括pickle和joblib,另外在使用特定的机器学习库,如scikit-learn、TensorFlow或PyTorch时,它们也提供了自己的保存和加载机制。

参考文档:Python 机器学习 模型保存和加载-CJavaPy

1、pickle

pickle模块是Python的一部分,提供了一个简单的方式来序列化和反序列化一个Python对象结构。训练好的模型通常需要被保存,以便于未来进行预测时能够直接加载使用,而不需要重新训练。pickle模块是Python中一个常用的进行对象序列化和反序列化的模块,它可以将Python对象转换为字节流,从而能够将对象保存到文件中,或者从文件中恢复对象。

1)模型的保存

pickle.dump()方法用于将Python对象序列化并保存到文件中。常用参数如下,

参数

类型

描述

obj

对象

要被序列化的Python对象。

file

文件对象

一个打开的文件对象,

必须以二进制写模式打开('wb')。

protocol

整数/None

指定pickle数据格式的版本号。

如果省略,则使用默认的协议。可选的协议版本号从0到5,

其中更高的版本

提供了更高的效率和新的功能。

fix_imports

布尔值

仅在Python 2和Python 3之间的互操作性中使用。

默认为True,

为了使pickle文件在不同的Python版本间

能够互相兼容。

buffer_callback

回调函数/None

一个可选的回调函数,

用于pickle协议版本5中,

为了提供对大型数据的优化处理机制。

仅在Python 3.8及以上版本中可用。

使用代码:

import pickle

# 创建一个复杂的数据结构
my_data = {
    'name': 'Python',
    'version': 3.8,
    'features': ['Speed', 'Flexibility', 'Community'],
    'rank': 1
}

# 指定pickle文件的名称
filename = 'my_data.pickle'

# 使用最新的pickle协议版本进行序列化(Python 3.8及以上支持协议5)
protocol_version = pickle.HIGHEST_PROTOCOL

# 序列化对象到文件
with open(filename, 'wb') as file:
    # 使用dump()方法并指定协议版本
    # 在这个例子中,fix_imports默认即可,因为我们不考虑跨Python主版本的兼容性
    pickle.dump(my_data, file, protocol=protocol_version)


# 打开文件
with open(filename, 'r' ,errors='ignore') as f:
    # 读取文件内容
    content = f.read()

# 打印文件内容
print(content)

2)模型的加载

要使用pickle.load()方法,首先需要有一个已经以二进制模式打开的文件,该文件包含了之前使用pickle.dump()方法序列化的Python对象。

import pickle

# 创建一个复杂的数据结构
my_data = {
    'name': 'Python',
    'version': 3.8,
    'features': ['Speed', 'Flexibility', 'Community'],
    'rank': 1
}

# 指定pickle文件的名称
filename = 'my_data.pickle'

# 使用最新的pickle协议版本进行序列化(Python 3.8及以上支持协议5)
protocol_version = pickle.HIGHEST_PROTOCOL

# 序列化对象到文件
with open(filename, 'wb') as file:
    # 使用dump()方法并指定协议版本
    # 在这个例子中,fix_imports默认即可,因为我们不考虑跨Python主版本的兼容性
    pickle.dump(my_data, file, protocol=protocol_version)


# 打开文件
with open(filename, 'r' ,errors='ignore') as f:
    # 读取文件内容
    content = f.read()

# 打印文件内容
print(content)
# 打开包含序列化数据的文件
with open(filename, 'rb') as file:
    my_object = pickle.load(file)
    print(my_object)

3)使用pickle保存和加载模型

import pickle
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# 加载示例数据集,如鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 使用train_test_split函数划分数据集
# 测试集占比30%,保持类别比例,设置随机种子为42以确保结果一致性
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# 创建KNN分类器实例
knn = KNeighborsClassifier(
    n_neighbors=3,  # 设置邻居的数量为3
    weights='distance',  # 设置权重为距离的倒数
    algorithm='kd_tree',  # 使用KD树算法
    leaf_size=40,  # 设置KD树/球树的叶子大小
    p=1,  # 设置Minkowski距离的幂参数为1(曼哈顿距离)
    metric='euclidean'  # 使用欧氏距离作为距离度量
)

knn.fit(X_train, y_train)
# 预测和评估
y_pred = knn.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

# 使用pickle保存模型
with open('model.pkl', 'wb') as file:
    pickle.dump(knn, file)

# 使用pickle加载模型
with open('model.pkl', 'rb') as file:
    loaded_model = pickle.load(file)

# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)

2、joblib

joblib是一个非常适合于大数据的序列化工具,特别是对于包含大量numpy数组的数据结构。由于机器学习模型往往包含大量的numpy数组,因此joblib在保存和加载机器学习模型方面比pickle更加高效。如未安装,则可以使用pip install joblib进行安装。

1)保存模型

使用joblib.dump方法可以将训练好的模型保存到文件中。常用参数如下,

参数类型描述

value Python对象 要序列化的Python对象。
可以是几乎任何类型的对象,
如机器学习模型、numpy数组等。
filename 字符串 用于存储序列化对象的文件名。
可以是完整的文件路径。
如果文件名以特定的压缩扩展名结尾
(如.gz.bz2.xz.lzma),
则自动应用相应的压缩。
compress 布尔值/整数/元组 控制文件压缩的选项。
如果为布尔值且为True,
使用默认的压缩方式(通常是compress=3)。
如果为整数,指定压缩级别(0-9)。
也可以是(compressor, level)
形式的元组来精确控制压缩方式和级别。
protocol 整数 指定用于序列化的pickle协议版本。
如果为None,则使用默认的pickle协议。
通过指定协议,
可以帮助保持与旧版本Python的兼容性。
cache_size (已弃用) 此参数在最新版本的joblib中已不再使用。

使用代码:

import numpy as np
from joblib import dump

# 创建一个Numpy数组
array = np.arange(100).reshape(10, 10)

# 保存数组到磁盘,不使用压缩
dump(array, 'array.joblib')

# 使用压缩保存数组到磁盘
# compress参数可以是一个整数,指定压缩级别。这里使用3作为压缩级别的示例。
dump(array, 'array_compressed.joblib', compress=3)

# 使用更细粒度的压缩控制,指定压缩方式和级别
# 例如,使用gzip压缩方式,压缩级别为9(最大压缩)
dump(array, 'array_finely_compressed.joblib', compress=('gzip', 9))
# 打开文件
with open('array_finely_compressed.joblib', 'r' ,errors='ignore') as f:
    # 读取文件内容
    content = f.read()

# 打印文件内容
print(content)

2)加载模型

使用joblib.load方法可以从文件加载之前保存的模型。

import numpy as np
from joblib import dump,load

# 创建一个Numpy数组
array = np.arange(100).reshape(10, 10)

# 保存数组到磁盘,不使用压缩
dump(array, 'array.joblib')

# 使用压缩保存数组到磁盘
# compress参数可以是一个整数,指定压缩级别。这里使用3作为压缩级别的示例。
dump(array, 'array_compressed.joblib', compress=3)

# 使用更细粒度的压缩控制,指定压缩方式和级别
# 例如,使用gzip压缩方式,压缩级别为9(最大压缩)
dump(array, 'array_finely_compressed.joblib', compress=('gzip', 9))
# 打开文件
with open('array_finely_compressed.joblib', 'r' ,errors='ignore') as f:
    # 读取文件内容
    content = f.read()

# 打印文件内容
print(content)
# 打开包含序列化数据的文件
with open('array_finely_compressed.joblib', 'rb') as file:
    my_object = load(file)
    print(my_object)

3)使用joblib保存和加载模型

from joblib import dump, load
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# 加载示例数据集,如鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 使用train_test_split函数划分数据集
# 测试集占比30%,保持类别比例,设置随机种子为42以确保结果一致性
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# 创建KNN分类器实例
knn = KNeighborsClassifier(
    n_neighbors=3,  # 设置邻居的数量为3
    weights='distance',  # 设置权重为距离的倒数
    algorithm='kd_tree',  # 使用KD树算法
    leaf_size=40,  # 设置KD树/球树的叶子大小
    p=1,  # 设置Minkowski距离的幂参数为1(曼哈顿距离)
    metric='euclidean'  # 使用欧氏距离作为距离度量
)

knn.fit(X_train, y_train)
# 预测和评估
y_pred = knn.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

# 使用pickle保存模型
with open('model.joblib', 'wb') as file:
    dump(knn, file)

# 使用pickle加载模型
with open('model.joblib', 'rb') as file:
    loaded_model = load(file)

# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)

3、特定库的保存和加载机制

TensorFlow / Keras 使用model.save(filepath)保存模型,使用keras.models.load_model(filepath)加载模型。PyTorch使用torch.save(model.state_dict(), filepath)保存模型的状态字典,使用model.load_state_dict(torch.load(filepath))加载模型状态字典。每种方法都有其特定的使用场景和优缺点,选择合适的方法可以帮助更有效地管理和复用机器学习模型。

 参考文档:Python 机器学习 模型保存和加载-CJavaPy

标签:Python,模型,joblib,序列化,pickle,加载
From: https://www.cnblogs.com/tinyblog/p/18017543

相关文章

  • 博客园跳转编辑页面没有重新加载页面 pushState
    博客园前端是用angular写的全局搜索pushState,打断点,可以看到 pushState main.6267e7d35558bee5.is:1gomain.6267e7d35558bee5.js:1setBrowserUrl main.6267e7d35558bee5.js:1 setBrowserUrl(p,I){constQ=this.urlSerializer.serialize(p)......
  • 关于thrift python接口和java通信出现问题解决
    真的无语,搞了一个下午。使用thrift出现错误,先说一下遇到第一个错误,如下图:那时候代码是这叼样```if__name__=='__main__':handler=MessageServiceHandler()processor=MessageService.Processor(handler)transport=TSocket.TServerSocket(None,"9090"......
  • 书生开源大模型训练营-第3讲-笔记
    1、大模型的局限性a、知识只能截止到训练时间;b、垂直领域的专业能力有限;c、训练成本高,定制化成本高; 2、解决大模型局限性的两种思路RAGVSFTRAG:外挂一个知识库,通过检索得到文档,再将检索到文档和问题一起输入给大模型来生成答案。优点:成本极低、知识可更新;缺点:受限于基座大......
  • 大模型语言与AI
    目录大模型语言与AI什么是大模型语言?什么是AI?AI和大模型语言的区别什么是GPT?GPT的迭代以及每一代的区别GPT-1GPT-2GPT-3GPT-4Sora其他的AI应用场景及对应AI产品如何把握GPT及类似大模型技术带来的机会如何利用TensorFlow微调模型相关链接大模型语言与AI什么是大模型语言?大模型......
  • Sora技术报告 视频生成模型作为世界模拟器 笔记
    Sora技术报告视频生成模型作为世界模拟器笔记技术报告原题目叫做Videogenerationmodelsasworldsimulators,翻译一下就是视频生成模型作为世界模拟器,地址在这里。我写的时候是翻译和笔记并行,翻译感谢gpt4出色的翻译能力。这篇博客介绍了OpenAI在视频数据上大规模训练生......
  • 第 8章 Python 爬虫框架 Scrapy(下)
    第8章Python爬虫框架Scrapy(下)8.1Scrapy对接Selenium有一种反爬虫策略就是通过JS动态加载数据,应对这种策略的两种方法如下:分析Ajax请求,找出请求接口的相关规则,直接去请求接口获取数据。使用Selenium模拟浏览器渲染后抓取页面内容。8.1.1如何对接单独使用Sc......
  • RabbitMQ 消息模型
         参考链接:  https://blog.csdn.net/qq_40991313/article/details/126801025?spm=1001.2014.3001.5501 3.5.3.总结描述下Direct交换机与Fanout交换机的差异?   Fanout交换机将消息路由给每一个与之绑定的队列   Direct交换机根据RoutingKey判断路由......
  • Python 装饰器
    Python装饰器装饰器原理定义本质是函数,用来装饰其他函数,为其他函数添加附加功能原则不能修改被装饰函数的源代码不能修改被装饰的函数的调用方式实现装饰器知识储备函数就是变量高阶函数把一个函数当作实参传给另外一个函数,在不修改被装饰函数源代码情况下为......
  • dlt开源数据加载工具
    dlt是一个开源数据加载工具,基于python开发特点一个库 dlt就是一个python包,其他地方需要我们自己开发非黑盒系统 我们可以基于代码灵活的进行自定义开发基于乘法的玩法,而不是加法自动代码生成 包含了类似dbt的一些处理cli基于python的玩法 dlt对于数据的处理是基于......
  • python - flask wsgi
    直接使用flask自带的wsgi,关闭debug模式会出现以下警告fromflaskimportFlaskapp=Flask(__name__,static_folder="./static")app.run(host="0.0.0.0",port=8080,debug=False)#WARNING:Thisisadevelopmentserver.Donotuseitinaproductiondeployme......