首页 > 编程语言 >《python机器学习从入门到高级》

《python机器学习从入门到高级》

时间:2024-07-05 14:57:49浏览次数:20  
标签:plt 机器 入门 python 分类 train import mnist sklearn

《python机器学习从入门到高级》分类算法:

引言

我们在之前的文章已经介绍了机器学习的一些基础概念,当拿到一个数据之后如何处理、如何评估一个模型、以及如何对模型调参等。接下来,我们正式开始学习如何实现机器学习的一些算法。 回归和分类是机器学习的两大最基本的问题,对于分类算法的详细理论部分。 本文主要从python代码的角度来实现分类算法。

# 导入相关库
import sklearn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

  1. 数据准备
    =========

下面我们以mnist数据集为例进行演示,这是一组由美国人口普查局的高中生和雇员手写的70000个数字图像。每个图像都用数字表示。也是分类问题非常经典的一个数据集

# 导入mnist数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()

dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

其中data是我们输入的特征,target0-9的数字

X, y = mnist["data"], mnist["target"]
X.shape,y.shape

((70000, 784), (70000,))

可以看出一共有70000图像,其中X一共有784个特征,这是因为图像是28×28的,每个特征是0-255之间的。下面我们通过imshow()函数将其进行还原

%matplotlib inline
import matplotlib as mpl
digit = X[0]
digit_image = digit.reshape((28, 28))#还原成28×28
plt.imshow(digit_image, cmap=mpl.cm.binary)
plt.axis("off")
plt.savefig("some_digit_plot")
plt.show()

图片


从我们人类角度来看,我们很容易辨别它是5,我们要做的是,当给机器一张图片时,它能辨别出正确的数字吗?我们来看看y的值

我们要实现的就是,给我们一张图片,不难发现这是一个多分类任务,下面我们正式进入模型建立,首先将数据集划分为训练集和测试集,这里简单的将前60000个划分为训练集,后10000个为测试集,具体代码如下

y = y.astype(np.uint8)#将y转换成整数
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

2.简单二元分类实现

在实现多分类任务之前,我们先从一个简单的问题考虑,现在假设我只想知道给我一张图片,它是否是7(我最喜欢的数字)。这个时候就是一个简单的二分类问题,首先我们要将我们的目标变量进行转变,具体代码如下

y_train_7 = (y_train == 7)
y_test_7 = (y_test == 7)

现在,我们选择一个分类器并对其进行训练。我们先使用SGD(随机梯度下降)分类器

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=123)#设置random_state为了结果的重复性
sgd_clf.fit(X_train, y_train_7)

SGDClassifier(random_state=123)

训练好模型之后我们可以进行预测,以第一张图片为例,我们预测一下它是否是7(很显然我们知道不是)

sgd_clf.predict(X[0].reshape((1,-1)))

array([False])

可以看出判断正确了,在之前我们讨论了模型评估的方法,详细介绍看这篇文章:Python机器学习从入门到高级:模型评估和选择(含详细代码) 下面演示如何用代码实现各个评估指标

3.模型评估

我们根据分类评估指标来看看SGD分类器效果

3.1 准确率

from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_7, cv=3, scoring="accuracy")

array([0.97565, 0.97655, 0.963  ])

3.2 混淆矩阵

y_train_pred = sgd_clf.predict(X_train)

from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_7, y_train_pred)

array([[53304,   431],
       [  550,  5715]], dtype=int64)

3.3 召回率和精确度

from sklearn.metrics import precision_score, recall_score

print('precision:',precision_score(y_train_7, y_train_pred))
print('recall:',recall_score(y_train_7,y_train_pred))

precision: 0.929873088187439
recall: 0.9122106943335994

下面要用的matplotlib,想了解matplotlib可以看这篇文章:Python数据可视化大杀器之地阶技法:matplotlib(含详细代码)

3.4 ROC曲线

from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_train_7, y_scores)
plt.plot(fpr, tpr, linewidth=2)
plt.plot([0, 1], [0, 1], 'k--') 
plt.axis([0, 1, 0, 1])                                   
plt.xlabel('False Positive Rate (Fall-Out)', fontsize=16) 
plt.ylabel('True Positive Rate (Recall)', fontsize=16)    
plt.grid(True)                  


在这里插入图片描述

本章的介绍到此介绍,下一章介绍分类算法(下):如何完成多分类任务

本文借鉴网络,如有侵权,请联系删除。

标签:plt,机器,入门,python,分类,train,import,mnist,sklearn
From: https://blog.csdn.net/2401_86168842/article/details/140178961

相关文章

  • python数据结构(树和二叉树)
    树非线性结构一对多根结点(无前驱)多个叶子结点(无后继)其他数据元素(一个前驱,多个后驱)树与二叉树转换树与二叉树均可用二叉链表作为存储结构,则以二叉链表为媒介可导出树之间的一个对应关系-----即给定一颗树,可以找到唯一一颗二叉树与之对应。把树转化为二叉树步骤一:加线......
  • Box,一个字典操作python库
     Box介绍Box是一个让字典操作变得异常简单与直观,支持通过属性访问字典内容的库。 特点概述属性访问Box允许用户像访问对象属性一样访问字典的值,提升了代码的可读性和易用性。无缝嵌套自动将嵌套的字典转换为Box对象,使得处理复杂字典结构变得轻而易举。灵活性......
  • Python速通(条件语句)
    (牛牛的选择)牛牛在牛客网经过了两次笔试分别获得了Tencent和Alibaba的面试资格,不巧的是这两次面试的时间冲突了。两家公司牛牛都想去,他决定通过笔试的成绩判断去参加哪家公司的面试。现在输入两行浮点数,分别表示牛牛在Tencent和Alibaba的笔试成绩,请比较两个成绩,输出笔试成绩较高的......
  • 小白也能看懂的Python基础教程(9)
    目录Python文件操作1、文件操作概述什么是文件?文件操作包含哪些内容呢?文件操作的作用2、文件的基本操作open()打开函数mode访问模式详解读操作相关方法read()方法:readlines()方法:readline()方法:file读取文件之readfile读取文件之readlines和reanline相对和绝对......
  • 全网最全网络安全入门指南(2024版)零基础可学_网络安全学习指南
    下一个十年的饭碗就是它了!据悉,2019年9月27日,工信部发布**《关于促进网络安全产业发展的指导意见(征求意见稿)》,明确提出2025年培育形成一批营收20亿元以上的网络安全企业,网络安全产业规模超过2000亿元的发展目标;据市场调研机构Gartner预测,我国网络安全预计将以......
  • 三菱FX PLC入门之定时器和计数器
    PLC中,定时器和计数器是两个非常主要的编程元件,是PLC程序编制不可或缺的环节。我在之前的文章中简单地扯了一下这两个元件,而现在就是揭秘时刻了,让我们一起来看看它们的庐山真面目吧!一、定时器说到定时器,其实我们生活中就有很多它的应用,例如洗衣机的定时选择,烤箱的定时旋......
  • threejs入门2:Creating a scene
    参考:https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-sceneThegoalofthissectionistogiveabriefintroductiontothree.js.Wewillstartbysettingupascene,withaspinningcube.Aworkingexampleisprovidedatthebottomof......
  • ipython的使用技巧整理
    IPython是一个强大的交互式Python环境,提供了许多高级功能和快捷键,以下是非常详细的IPython使用技巧整理,覆盖了每个知识点(但本文是基于有一定基础的同学看的):IPython的使用基础:一、安装与基本操作安装Anaconda建议直接下载安装Anaconda,其中包含丰富的库,以及我们需要使用......
  • python爬取的数据存放在哪
    大家好,本文将围绕python数据爬取有哪些库和框架展开说明,python爬取数据保存到数据库是一个很多人都想弄明白的事情,想搞清楚python爬取数据存入数据库需要先了解以下几个事情。经常游弋在互联网爬虫行业的程序员来说,如何快速的实现程序自动化,高效化都是自身技术的一种沉淀的......
  • python作业题百度网盘,python大作业总结
    大家好,小编来为大家解答以下问题,python作业题百度网盘,python大作业总结,现在让我们一起来看看吧!大家好,本文将围绕python大作业代码及文档展开说明,python大作业代码100行是一个很多人都想弄明白的事情,想搞清楚python期末大作业题目需要先了解以下几个事情。大家好,给大家分......