首页 > 其他分享 >机器学习day03

机器学习day03

时间:2024-06-19 20:11:32浏览次数:29  
标签:机器 模型 day03 train 学习 estimator test todo image

机器学习day03

超参数选择方法--交叉验证、网格搜索、手写数字识别案例

1交叉验证

1.1 什么是交叉验证?

是一种数据集的分割方法,将训练集划分为 n份,拿一份做验证集 (测试集)、其他n-1份做训练集

1.2交叉验证法原理:将数据集划分为 cv=4

  1. 第一次:把第一份数据做验证集,其他数据做训练
  2. 第二次:把第二份数据做验证集,其他数据做训练
  3. ... 以此类推,总共训练4次,评估4次。
  4. 使用训练集+验证集多次评估模型,取平均值做交叉验证为模型得分
  5. 若k=5模型得分最好,再使用全部训练集(训练集+验证集) 对k=5模型再训练 一边,再使用测试集对k=5模型做评估

交叉验证法,是划分数据集的一种方法,目的就是为了得到更加准确可信的模型评分。

2.网格搜索

2.1为什么需要网格搜索?

  • 模型有很多超参数,其能力也存在很大的差异。需要手动产生很多超参数组合,来训练模型
  • 每组超参数都采用交叉验证评估,最后选出最优参数组合建立模型。

网格搜索是模型调参的有力工具。寻找最优超参数的工具! 只需要将若干参数传递给网格搜索对象,它自动帮我们完成不同超参数的组合、模型训练、模型评估, 最终返回一组最优的超参数。

网格搜索 + 交叉验证的强力组合 (模型选择和调优)

  • 交叉验证解决模型的数据输入问题(数据集划分)得到更可靠的模型
  • 网格搜索解决超参数的组合
  • 两个组合再一起形成一个模型参数调优的解决方案

3.交叉验证网格搜索 – API和应用举例

3.1交叉验证网格搜索API介绍

sklearn.model_selection.GridSearchCV(estimator,param_grid=None,cv= None)
对估计器的指定参数值进行详尽的搜索
estimator:估计对象
param_grid:估计器参数(dict){"n_neighbors":[1,3,5]}
cv:指定几折交叉验证
fit:输入训练数据
score:准确率
结果分析:
	bestscore_:在交叉验证中的最好结果
    bestestimator:最好的参数模型
    cvresults:每次交叉验证后的验证集准确结果和训练集准确率结果

4.案例

4.1利用KNN算法对鸢尾花分类 – 交叉验证网格搜索

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
----------------------------------
@Project :pythonclass 
@File    :cv_iris_05.py
@IDE     :PyCharm 
@Author  :chizhayuehaiyuyumao
@Date    :2024/6/19 11:15 
----------------------------------
    Deos TODO 鸢尾花案例,补充超参调优
    # todo 1.导包
    # todo 2.加载数据集
    # todo 3.数据基本处理
    # todo 3.1 划分数据集并
    # todo 3.2 划分数据集并
    # todo 4.实例化模型
    # todo 5.超参调优
    # todo 6.模型预测
----------------------------------
'''
# todo 1.导包

from sklearn.neighbors import KNeighborsClassifier  # 导入knn分类算法
from sklearn.datasets import load_iris  # 导入鸢尾花数据集
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split,GridSearchCV

# todo 2.加载数据集
iris_data = load_iris()
print(iris_data.DESCR)

# todo 3.数据基本处理
x = iris_data.data
y = iris_data.target
# todo 3.1划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=22)

# todo 3.2 数据预处理(标准化)
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)  # 为什么不用fit 是因为可以直接使用已经拟合好的函数

# todo 4.实例化模型
estimator = KNeighborsClassifier()

# todo 5.超参调优
estimator = GridSearchCV(estimator, param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, cv=5)
estimator.fit(x_train, y_train)

print(f'最佳参数:{estimator.best_params_}')
print(f'最佳分数:{estimator.best_score_}')
print(f'最佳模型:{estimator.best_estimator_}')

# todo 6.模型预测
estimator = estimator.best_estimator_
score = estimator.score(x_test, y_test)
print(f'准确率:{score}')

4.2利用KNN算法实现手写数字识别

已知数据

  • MNIST手写数字识别
  • 1999年发布,成为分类算法基准测试的基础
  • MNIST仍然是研究人员和学习者的可靠资源

需求:

  • 从数万个手写图像的数据集中正确识别数字

数据介绍:

  1. 数据文件 train.csv 和 test.csv 包含从 0 到 9 的手绘数字的灰度图像。
  2. 每个图像高 28 像素,宽28 像素,共784个像素。
  3. 每个像素取值范围[0,255],取值越大意味着该像素颜色越深
  4. 训练数据集(train.csv)共785列。 第一列为 "标签",为该图片对应的手写数字。其余784列为该图像的像素值
  5. 训练集中的特征名称均有pixel前缀,后面的数字([0,783])代表了像素的序号。

图片处理:

from PIL import Image, ImageOps
import numpy as np


def process_image(input_path, output_path):
    # 1. 加载图像
    image = Image.open(input_path)

    # 2. 转换为灰度图像
    gray_image = image.convert("L")

    # 3. 二值化处理
    threshold = 128
    binary_image = gray_image.point(lambda p: p > threshold and 255)

    # 4. 调整尺寸到28x28像素
    resized_image = binary_image.resize((28, 28), Image.LANCZOS)

    # 5. 转换为黑底白字
    inverted_image = ImageOps.invert(resized_image)

    # 6. 保存处理后的图像
    inverted_image.save(output_path)
    inverted_image.show()


# 示例调用
input_image_path = r"D:\pycharm\pythonclass\jiqixuexi\first\pythonProject2\data\test_9.png"
output_image_path = r"D:\pycharm\pythonclass\jiqixuexi\first\pythonProject2\data\picture_out\test_new.png"
process_image(input_image_path, output_image_path)

训练模型:

def knn_train():
    # todo 2.加载数据集并查看
    data = pd.read_csv(r'D:\pycharm\pythonclass\jiqixuexi\first\pythonProject2\data\手写数字识别.csv')
    # print(data.head(5))

    # todo 3.数据基本处理
    # todo 3.1归一化
    scaler = MinMaxScaler()
    x = scaler.fit_transform(data.iloc[:, 1:])
    y = data.iloc[:,0]

    # todo 3.2 划分数据集
    x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=20)

    # todo 4.实例化模型
    estimator = KNeighborsClassifier()
    # todo 5.超参调优
    param = {'n_neighbors':range(3,10,2)}
    estimator = GridSearchCV(estimator,param_grid=param,cv=5)
    estimator.fit(x_train,y_train)

    # todo 6.模型预测
    estimator = estimator.best_estimator_
    score = estimator.score(x_test,y_test)
    print(f'准确率:{score}')

    # todo 7.模型保存
    joblib.dump(estimator,r'D:\pycharm\pythonclass\jiqixuexi\first\pythonProject2\data\knn_digit.pth')
if __name__ == '__main__':
    knn_train()

结果预测:

def knn_predict():
    # todo 1.加载图片
    img = plt.imread(r'D:\pycharm\pythonclass\jiqixuexi\first\pythonProject2\data\picture_out\test_new.png')

    # todo 2.图片转换成矩阵
    img = img.reshape(1,-1)

    # todo 3.加载模型
    estimator = joblib.load(r'D:\pycharm\pythonclass\jiqixuexi\first\pythonProject2\data\knn_digit.pth')

    # todo 4.预测图片
    y_pre =  estimator.predict(img)
    print(f'预测结果:{y_pre}')

if __name__ == '__main__':
    knn_predict()

标签:机器,模型,day03,train,学习,estimator,test,todo,image
From: https://www.cnblogs.com/luoxuezhixing/p/18256342

相关文章

  • 论如何使用机器学习,预测客户流失率,轻松实现客户精准维护
    01、案例说明首先我们学习最经典的机器学习模型,就是监督学习(SupervisedLearning)中的分类模型。这边使用的是一个电信公司的案例,通过客户的基本资料和一些简单的互动信息,建立一个模型,以预测哪些客户有较高的可能性流失,从而进行补救。因为研究显示得到一个新客户的成本是维......
  • Javascript入门博客【入门复习(学习)使用】
    JavaScript是一门高级,解释形语言,大量用于关于web网站的开发,可以和网页联动做出更多有趣的动画效果。其运行方式大都是嵌入在网页中运行。其实在定义方面如果过你是初学者来学习和这方面相关的知识,知道上面这些就已经足够了。我们可以在浏览器中直接进行对代码的控制,进入浏览器......
  • mybits学习1
    所花时间(包括上课): 2h代码量(行): 150左右搏客量(篇): 1了解到的知识点:mybits备注(其他): private static SqlSessionFactorysqlSessionFactory;   static {       try {           Stringresource= "mybati......
  • mybits学习2
    所花时间(包括上课): 2h代码量(行): 150左右搏客量(篇): 1了解到的知识点:mybits备注(其他): @TestpublicvoidaddStudent(){SqlSessionsqlSession=mybatisUtil.getSqlSession();studentMapperstuMapper=sqlSession.getMapper(studen......
  • python学习3
    所花时间(包括上课): 2h代码量(行): 150左右搏客量(篇): 1了解到的知识点:python备注(其他): 破解百度翻译importrequestsimportjsonif__name__=='__main__':#UA伪装:让爬虫对应的请求载体身份标识伪装成某一款浏览器header......
  • C++学习(22)
    #学习自用#计时计时可以计算出执行代码时花费了多长时间,对于同样的目的,我们可以通过不同的代码实现,而执行时间长短是评价一串代码性能如何的指标。#include<iostream>#include<string>#include<chrono>#include<thread>usingnamespacestd;intmain(){ autostar......
  • 深度学习原理
    1简介        AIGC(ArtificialIntelligenceGeneratedContent,即人工智能生成内容)是一种利用人工智能技术自动创建文本、图像、音频和视频等内容的技术。AIGC的核心是通过机器学习和深度学习算法,让计算机模型学会理解和生成人类语言,从而能够自动产生有价值的内容。......
  • DevOps学习回顾02-实践的通用路径-需求分析的拆解-CI的理解-质量体系的实践路径
    参考来源:极客时间专栏:DevOps实战笔记,作者:石雪峰课程链接:https://time.geekbang.org/column/intro/235DevOps学习回顾02-实践的通用路径-需求分析的拆解-CI的理解-质量体系的实践路径DevOps实践的通用路径第一步:寻找合适的试点项目一个合适的项目应该具备以下几个特......
  • 机器学习(一)
    机器学习1.机器学习概述1.1人工智能概述1.1.1机器学习与人工智能、深度学习的关系1.1.2人工智能的起点1.1.3机器学习、深度学习能做什么?1.2什么是机器学习?1.2.1定义1.2.2数据集的构成1.3机器学习算法1.4机器学习开发流程2.特征工程2.1数据集2.1.1可用数据......
  • nodejs学习08——会话控制 session cookie token
    会话控制一、介绍所谓会话控制就是对会话进行控制HTTP是一种无状态的协议,它没有办法区分多次的请求是否来自于同一个客户端,无法区分用户而产品中又大量存在的这样的需求,所以我们需要通过会话控制来解决该问题常见的会话控制技术有三种:cookiesessiontoken二、cooki......