首页 > 其他分享 >机器学习-线性回归-多项式升维-07

机器学习-线性回归-多项式升维-07

时间:2023-12-13 22:33:49浏览次数:36  
标签:升维 07 predict 多项式 poly print train test

目录

1. 为什么要升维

升维的目的是为了去解决欠拟合的问题的,也就是为了提高模型的准确率为目的的,因为当维度不够时,说白了就是对于预测结果考虑的因素少的话,肯定不能准确的计算出模型。

在做升维的时候,最常见的手段就是将已知维度进行相乘来构建新的维度,如下图所示。下图左展示的是线性不可分的情况,下图右通过升维使得变得线性可分。

属于数据预处理的手段,在sklearn模块下它处于sklearn.preprocessing模块下。它的目的就是将已有维度进行相乘,包括自己和自己相乘,来组成二阶的甚至更高阶的维度。


升维后

多项式升维PolynomialFeatures

2 代码实现

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

np.random.seed(42)

m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 0.5*X**2 + X + 2 + np.random.randn(m, 1)

X_train = X[:80]
X_test = X[80:]
y_train = y[:80]
y_test = y[80:]

plt.plot(X, y, "b.")
# plt.show()

d = {1: "g-", 2: "r+", 10: "y*"}


for i in d:
    poly_features = PolynomialFeatures(degree=i, include_bias=True)
    X_poly_train = poly_features.fit_transform(X_train)
    X_poly_test = poly_features.fit_transform(X_test)

    # print(X_train[0])
    # print(X_poly_train[0])
    #
    # print(X_train.shape)
    # print(X_poly_train.shape)

    line_reg = LinearRegression(fit_intercept=False)
    line_reg.fit(X_poly_train, y_train)

    y_train_predict = line_reg.predict(X_poly_train)  # 训练集上的预测
    y_test_predict = line_reg.predict(X_poly_test)  # 测试集的预测

    plt.plot(X_poly_train[:, 1], y_train_predict, d[i])

    print(f"degree: {i}, 训练集mse:", mean_squared_error(y_train, y_train_predict))
    print(f"degree: {i}, 测试集mse:", mean_squared_error(y_test, y_test_predict))

plt.show()


3, 总结

无论从下图,还是从上面的评估指标的对比,我们都可以发现使用多项式回归的时候,超参数degree很重要。当我们设置为1的时候欠拟合,当我们设置为10的时候过拟合,当我们设置为2的时候just right

过拟合:训练接表现ok 测试集表现不行

标签:升维,07,predict,多项式,poly,print,train,test
From: https://www.cnblogs.com/cavalier-chen/p/17900094.html

相关文章

  • 0x07.常用windows命令、搭建网站、状态码
    常用windows命令cmdwtcal 计算器control 控制面板winver 查看版本services.msc服务mstsc 远程桌面regedit 注册表ncpa.cpl 网络连接explorer 此电脑netplwiz 用户账户inetmgr IIS控制台判断windows-server版本IIS版本......
  • springboot+vue小白升级之路07-快速实现批量删除、小白升级之路08-实现批量导入导出ex
    我们接着之前的内容,全部代码我贴一下,大家参考使用。数据库droptableifexistsan_user;createtablean_user( idintnotnullauto_incrementprimarykeycomment'主键id', namevarchar(255)notnulluniquecomment'姓名', `password`varchar(255)notnullcomment......
  • [20231207]开发不应该这样写sql4.txt
    [20231207]开发不应该这样写sql4.txt--//最近在优化sql语句,发现另外一种风格,实际上以前也遇到过,感觉这就像一种病,会传染只要一个这样写后面的要么跟进要么--//不改。我觉得开发应该感谢exadata,不然我们的生产系统估计会垮掉。1.环境:XXXXXX>@ver1PORT_STRING          ......
  • 解决 OSError: [WinError -1066598274] Windows Error 0xc06d007e (xjl456852原创)
    异常OSError:[WinError-1066598274]WindowsError0xc06d007e或Processfinishedwithexitcode-1066598274(0xC06D007E)遇到问题:程序在调用PCA方法时,出现上述异常.这种PCA方法使用sklearn中的依赖包.我尝试了pip和mamba重新安装多个依赖包之后问题得到解决(只选择一......
  • TSINGSEE青犀可视化视频云平台JT/T1078接入能力在智慧物流中的应用
    一、引言随着科技的快速发展和全球贸易的蓬勃发展,智慧物流成为了现代物流业的重要发展方向。智慧物流通过引入先进的信息技术,实现了物流过程的自动化、智能化和信息化,从而提高了物流效率和准确性。在这个过程中,JT/T1078接入技术发挥着关键的作用。二、JT/T1078接入技术JT/T1078接入......
  • TSINGSEE青犀可视化视频云平台JT/T1078接入能力在智慧物流中的应用
    一、引言随着科技的快速发展和全球贸易的蓬勃发展,智慧物流成为了现代物流业的重要发展方向。智慧物流通过引入先进的信息技术,实现了物流过程的自动化、智能化和信息化,从而提高了物流效率和准确性。在这个过程中,JT/T1078接入技术发挥着关键的作用。二、JT/T1078接入技术JT/T1078......
  • [-007-]-Python3+Unittest+Selenium Web UI自动化测试之@property装饰器默认值设置
    看示例:#!/usr/bin/python3#coding:utf-8__author__='csjin'#定义@property装饰器classPPTListModels(object):def__init__(self):self._tab_name="PPT模板"@propertydefhandle(self):returnself.__handle......
  • Codeforces Round 807 (Div. 2)
    基本情况AB题秒了。C题搞了半天,搞了一个假的解法,最后还是爆空间了。D题没想下去。C.MarkandHisUnfinishedEssayProblem-C-Codeforces错误分析写出来自己的错解之后没有进一步思考,而是觉得没希望直接做D去了,实则D也没可能半小时写完。我的错解就是预处理好每个......
  • 07 java运行时数据区域
    包含堆、方法区、程序计数器、本地方法栈、虚拟机栈。这就是运行数据区的几个部分。其中堆和方法区是线程共有的,其它数据区域是线程私有的。堆中存储对象数据。方法区中储存类信息、常量及静态变量等信息。方法栈中的栈帧和线程的寿命是一致的,储存方法执行时的相关常量,比如局部变量......
  • CH32V307 ADC与触摸按键的使用
    CH32V307的ADC模块具有两个独立的ADC单元,12位分辨率,支持16个外部通道和2个内部信号源采样。CH32V307的触摸检测单元,借助ADC模块的电压转换功能,通过将电容量转换为电压量进行采样,实现触摸按键检测功能。检测通道复用ADC的16个外部通道,通过ADC模块的单次转换模式实现触摸按键检测。......