首页 > 其他分享 >亦菲喊你来学机器学习(9) --逻辑回归实现手写数字识别

亦菲喊你来学机器学习(9) --逻辑回归实现手写数字识别

时间:2024-08-23 09:53:41浏览次数:12  
标签:-- 模型 来学 train 0.99 0.98 test 手写 250

文章目录

逻辑回归

逻辑回归(Logistic Regression)虽然是一种广泛使用的分类算法,但它通常更适用于二分类问题。然而,通过一些策略(如一对多分类,也称为OvR或One-vs-Rest),逻辑回归也可以被扩展到多分类问题,如手写数字识别(通常是0到9的10个类别)。

本篇我们就来尝试一下如何通过逻辑回归来实现手写数字识别

  1. 训练模型
  2. 测试模型

实现手写数字识别

训练模型

  1. 收集数据

在这里插入图片描述

  1. 读取图片

使用opencv处理图片,将图片的像素数值读取进来,并返回的是一个三维(高,宽,颜色)numpy数组:

 pip install opencv-python==3.4.11.45
import cv2
img = cv2.imread("digits.png")
  1. 转为灰度图

将图片转化为灰度图,从而让三维数组变成二位的数组:

grey = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
  1. 处理图片信息

对图片进行处理:将其先垂直切分(横向)成50份,再将每一份水平切分(竖向)成100份,这样我们的每份图片的像素值都为20*20(训练的图片比较规范)共500个,比如:

在这里插入图片描述

import numpy as np
img_info = [np.hsplit(row,100) for row in np.vsplit(grey,50)]
  1. 装进array数组

将切分的每一份图片像素数据都装进array数组中:

x = np.array(img_info)
  1. 分隔训练集与测试集

将数据竖着分隔一半,一半作为训练集,一般作为测试集:

train_x = x[:,:50]
test_x = x[:,50:100]
  1. 调整数据结构

由于我们最后要将数据放在逻辑回归模型中训练,我们得将数据结构调整为适合逻辑回归算法训练的结构,那么我们就来改变每份图片数组的维度:reshape:

new_train_x = train_x.reshape(-1,400).astype(np.float32)
new_test_x = test_x.reshape(-1,400).astype(np.float32)
  1. Z-score标准化

逻辑回归算法进行手写数字识别时,对数据进行标准化是为了提高优化算法的收敛速度、提升模型的预测性能,并避免潜在的数值问题。将数据都进行表示话,避免参数的影响:

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
fin_train = scaler.fit_transform(new_train_x)
fin_test = scaler.fit_transform(new_test_x)
  1. 分配标签

我们训练着那么多的数据,却没有给他们具体的类别标签(图像的实际值),因为我们之前的图像处理都是在寻找图像特征,但是并没有给他们一个具体对应的类别,只有空荡荡的特征,无法分类,所以我们得给切分的每份图片打上它们对应的标签:

k = np.arange(10)
train_y = np.repeat(k,250)
test_y = np.repeat(k,250)
train_y = train_y.ravel()
  1. 交叉验证

在逻辑回归的算法中,逻辑模型的参数中,有一参数为正则化强度C,越小的数值表示越强的正则化。我们要进行调参数,看看哪个惩罚因子最为合适,使模型拟合效果更好:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

#交叉验证选择较优的惩罚因子
scores = []
c_param_range = [0.01,0.1,1,10,100] #参数:一般常用的惩罚因子

for i in c_param_range:
    lr = LogisticRegression(C = i,penalty='l2',solver='lbfgs',max_iter=1000,random_state=0)
    # C表示正则化强度,越小的数值表示越强的正则化。防止过拟合
        score = cross_val_score(lr,fin_train,train_y,cv=10,scoring='recall_macro')
    #交叉验证,将模型和数据集传入,对其进行划分,每份轮流作为测试集来测试模型。返回一个列表对象
    score_mean = sum(score)/len(score)
    scores.append(score_mean)
c_choose = c_parma[np.argmax(scores)] #argmax取出最大值的索引位置
  1. 训练模型
lr_model = LogisticRegression(C = c_choose,max_iter=1000,random_state=0)
lr_model.fit(fin_train,train_y)

测试模型

  1. 先用训练数据再次进入模型测试,查看他本身的模型训练效果怎么样:
from sklearn import metrics
train_predict = lr_model.predict(fin_train)
print(metrics.classification_report(train_y,train_predict))  #查看混淆矩阵
-------------------------------
              precision    recall  f1-score   support

           0       0.99      1.00      0.99       250
           1       0.98      1.00      0.99       250
           2       1.00      0.98      0.99       250
           3       0.98      0.98      0.98       250
           4       1.00      1.00      1.00       250
           5       0.98      0.98      0.98       250
           6       0.99      1.00      1.00       250
           7       0.98      0.99      0.98       250
           8       0.98      0.99      0.99       250
           9       0.99      0.97      0.98       250

    accuracy                           0.99      2500
   macro avg       0.99      0.99      0.99      2500
weighted avg       0.99      0.99      0.99      2500
  1. 再用分割的测试集来测试模型:
test_predict = lr_model.predict(fin_test)
print(metrics.classification_report(test_y,test_predict))
---------------------------
              precision    recall  f1-score   support

           0       0.95      0.96      0.95       250
           1       0.94      0.96      0.95       250
           2       0.88      0.86      0.87       250
           3       0.90      0.86      0.88       250
           4       0.92      0.84      0.88       250
           5       0.84      0.90      0.87       250
           6       0.92      0.95      0.93       250
           7       0.89      0.93      0.91       250
           8       0.89      0.84      0.86       250
           9       0.83      0.86      0.85       250

    accuracy                           0.90      2500
   macro avg       0.90      0.90      0.89      2500
weighted avg       0.90      0.90      0.89      2500

到这为止!!我们就训练好一个关于手写数字识别的逻辑回归模型啦!!

总结

本篇介绍了如何用逻辑回归算法实现手写数字识别:

  1. 逻辑回归更适合二分类算法,但是也可以通过一些策略,扩展到多分类问题。
  2. 注意要将读取的数据进行标准化操作,灰度图图片数据相差过大。
  3. 学会调整参数,优化模型,比如本篇在交叉验证中找寻最优的惩罚因子。

标签:--,模型,来学,train,0.99,0.98,test,手写,250
From: https://blog.csdn.net/m0_74896766/article/details/141458076

相关文章

  • ICCEMDAN+皮尔逊+小波分解降噪+重构
    ICEEMDAN+皮尔逊+小波分解降噪+重构代码获取戳此处ICEEMDAN(改进的CEEMDAN)原理:ICEEMDAN是由Colominas等人提出的信号处理方法,它是在自适应噪声完全集合经验模态分解(CEEMDAN)的基础上发展而来。与CEEMDAN不同,ICEEMDAN在分解过程中不是直接添加高斯白噪声,而是选取白噪声被E......
  • 听劝❗用AI做职场思维导图仅仅需要几秒钟啊
    本文由ChatMoney团队出品嘿,各位职场朋友们是不是常常对着密密麻麻的笔记感到焦虑呢?想整理却无从下手?别怕,ChatmoneyAI知识库来拯救你的整理困难症啦!咱们都知道,思维导图是职场中必备的神器它能帮我们理清思路,记忆知识但传统做法嘛,不是画得乱七八糟就是费时费力,真心不方便......
  • 配置PXE预启动执行环境:使用PXE装机服务器网络引导装机
    文章目录PXE概述PXE批量部署的优点基本的部署过程搭建的前提条件搭建配置PXE装机服务器1.准备CentOS7安装源(YUM仓库)2.安装并启用TFTP服务3.安装并启用DHCP服务4.准备Linux内核和初始化镜像文件5.准备PXE引导程序6.安装FTP服务并准备CentOS7安装......
  • 《深海迷航:零度之下》user32.dll丢失导致游戏无法运行实用解决方法
    当你遇到《深海迷航:零度之下》(Subnautica:BelowZero)因缺少user32.dll文件而无法正常启动的问题时,可以尝试以下几种解决方法:了解问题user32.dll是一个Windows系统文件,包含了大量用于处理窗口和对话框的函数。如果游戏启动器或游戏本身需要这个文件而找不到它,就会出现错误......
  • 导入事件至苹果日历
    1.生成日期fromdatetimeimportdatetime,timedelta#设置开始日期和结束日期start_date=datetime(2024,8,23)end_date=datetime(2024,10,30)#列表用于存储每个周期的最后两天的日期result_dates=[]current_date=start_date-timedelta(days=1)while......
  • 学习分享:如何学习 API 中的数据格式
    以下是学习API中数据格式的要点:一、了解常见数据格式JSON(JavaScriptObjectNotation):结构特点:它是一种轻量级的数据交换格式,易于人阅读和编写,也易于机器解析和生成。JSON数据格式由键值对组成,类似于Python中的字典或者JavaScript中的对象。例如:{"name":"John",......
  • 拍立淘API在商品搜索中的应用实践案例
    拍立淘API在商品搜索中的应用具有多方面的优势和价值,以下为您详细介绍:精准匹配商品:原理:利用先进的图像识别技术,对用户上传的商品图片进行分析,提取图像中的特征信息,如颜色、形状、纹理等。然后将这些特征与商品数据库中的商品图像特征进行比对和匹配。示例:比如用户看到一......
  • MyBatis 源码解读:专栏导读与学习路线
    前言MyBatis是Java开发中广泛使用的持久层框架,其简洁的配置和强大的功能使得它在开发人员中备受欢迎。然而,MyBatis的背后隐藏着许多设计巧妙的架构和复杂的实现逻辑。通过源码解读,我们可以更深入地理解MyBatis的设计思想和工作原理,从而更好地应用它。本专栏将以源码......
  • 小红书全能实战营:精准定位,爆款打造,实现轻松涨粉变现之旅
    课程目录:1.[开营仪式]小红书训练营盛大开幕_.mp42.[直播精讲]第一篇章:精准定位与个性化包装,打造独特IP.mp43.[直播赋能]第二篇章:选题与标题的艺术,吸引眼球的秘诀.mp44.[互动答疑·上]专场,解答你的小红书成长疑惑_.mp45.[互动答疑·下]继续坐镇,深度剖析小红书运......
  • Spring 源码解读专栏:从零到一深度掌握 Spring 框架
    前言Spring是Java世界中无可争议的王者框架,它以其灵活、轻量、强大而著称,成为企业级开发的首选工具。然而,很多开发者在使用Spring时,往往只停留在会用的层面,对于其内部实现和设计原理知之甚少。本专栏旨在通过系统化的Spring源码解读,从实践到源码分析,再到设计模式的......