首页 > 其他分享 >【scikit-learn基础】--『监督学习』之 LASSO回归

【scikit-learn基础】--『监督学习』之 LASSO回归

时间:2023-12-28 09:03:41浏览次数:19  
标签:error r2 -- pred 模型 scikit learn test LASSO

LASSOLeast Absolute Shrinkage and Selection Operator)回归模型一般都是用英文缩写表示,
硬要翻译的话,可翻译为 最小绝对收缩和选择算子

它是一种线性回归模型的扩展,其主要目标是解决高维数据中的特征选择和正则化问题。

1. 概述

LASSO中,通过使用L1正则化项,它能够在回归系数中引入稀疏性,
也就是允许某些系数在优化过程中缩减为零,从而实现特征的选择。

与岭回归不同的是,LASSO的损失函数一般定义为:\(L(w) = (y-wX)^2+\lambda\parallel w\parallel_1\)
其中 \(\lambda\parallel w\parallel_1\),也就是 L1正则化项(岭回归中用的是 L2正则化项)。

模型训练的过程就是寻找让损失函数\(L(w)\)最小的参数\(w\)。
也就等价于:\(\begin{align} & arg\ min(y-wX)^2 \\ & s.t. \sum |w_{ij}| < s \end{align}\)
这两个公式表示,在满足约束条件 \(\sum |w_{ij}| < s\)的情况下,计算 \((y-wX)^2\)的最小值。

2. 创建样本数据

相比于岭回归模型,LASSO回归模型不仅对于共线性数据集友好,
对于高维数据的数据集,也有不错的性能表现。

它通过将不重要的特征的系数压缩为零,帮助我们选择最重要的特征,从而提高模型的预测准确性和可解释性。
下面我们模拟创建一些高维数据,创建一个特征数比样本数还多的样本数据集。

from sklearn.datasets import make_regression

X, y = make_regression(n_samples=80, n_features=100, noise=10)

这个数据集中,只有80个样本,每个样本却有100个特征,并且噪声也设置的很大(noise=10)。

3. 模型训练

第一步,分割训练集测试集

from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

scikit-learn中的LASSO模型来训练:

from sklearn.linear_model import Lasso

# 初始化LASSO线性模型
reg = Lasso()
# 训练模型
reg.fit(X_train, y_train)

这里使用的 Lasso()的默认参数来训练模型,它的主要参数包括:

  1. alpha:正则化项系数。它控制了L1正则化项的强度,即对模型复杂度的惩罚。alpha越大,模型越简单,但过大的alpha可能会导致模型欠拟合;alpha越小,模型越复杂,但过小的alpha可能会导致模型过拟合。默认值为1.0
  2. fit_intercept:布尔值,指定是否需要计算截距b值。如果设为False,则不计算b值默认值为True
  3. normalize:布尔值。如果设为True,则在模型训练之前将数据归一化。默认值为False
  4. precompute:布尔值,指定是否预先计算X的平方和。如果设为True,则在每次迭代之前计算X的平方和。默认值为False
  5. copy_X:布尔值,指定是否在训练过程中复制X。如果设为True,则在训练过程中复制X默认值为True
  6. max_iter:最大迭代次数。默认值为1000
  7. tol:阈值,用于判断是否达到收敛条件。默认值为1e-4
  8. warm_start:布尔值,如果设为True,则使用前一次的解作为本次迭代的起始点。默认值为False
  9. positive:布尔值,如果设为True,则强制系数为正。默认值为False
  10. selection:用于在每次迭代中选择系数的算法(有“cyclic”和“random”两种选择)。默认值为“cyclic”,即循环选择。

最后验证模型的训练效果:

from sklearn import metrics

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred)

print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error))

# 运行结果
均方误差:441.07830708712186
复相关系数:0.9838880665687711
中位数绝对误差:11.643348614829785

误差看上去不小,因为这次实际生成的样本,不仅数量小(80件)且噪声大(noise=10)。

3.1. 与岭回归模型比较

单独看LASSO模型的训练结果,看不出其处理高维数据的优势。
同样用上面分割好的训练集测试集,来看看岭回归模型的拟合效果。

from sklearn.linear_model import Ridge
# from sklearn.model_selection import train_test_split

mse, r2, m_error = 0.0, 0.0, 0.0

# 初始化岭回归线性模型
reg = Ridge()
# 训练模型
reg.fit(X_train, y_train)

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred)

print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error))

# 运行结果
均方误差:6315.046844910431
复相关系数:0.7693207470296398
中位数绝对误差:60.65140692273637

对于高维数据,可以看出,岭回归模型的误差 远远大于 LASSO模型。

3.2. 与最小二乘法模型比较

同样用上面分割好的训练集测试集,再来看看线性模型(最小二乘法)的拟合效果。

from sklearn.linear_model import LinearRegression

mse, r2, m_error = 0.0, 0.0, 0.0

# 初始化最小二乘法线性模型
reg = LinearRegression()
# 训练模型
reg.fit(X_train, y_train)

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred)

print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error))

# 运行结果
均方误差:5912.442445894787
复相关系数:0.7840272859181612
中位数绝对误差:62.89225147465376

可以看出,线性模型的训练效果和岭回归模型差不多,但是都远远不如LASSO模型

4. 总结

总的来说,LASSO回归模型是一种流行的线性回归扩展,具有一些显著的优势和劣势。
比如,在特征选择上,LASSO通过将某些系数压缩为零,能够有效地进行特征选择,这在高维数据集中特别有用。
此外,LASSO可以作为正则化工具,有助于防止过拟合。

不过,LASSO会假设特征是线性相关的,对于非线性关系的数据,效果可能不佳。
而且,如果数据存在复杂模式或噪声,LASSO可能会过度拟合这些模式。

标签:error,r2,--,pred,模型,scikit,learn,test,LASSO
From: https://www.cnblogs.com/wang_yb/p/17931877.html

相关文章

  • 02 USB_JTAG驱动安装
    1概述一般安装vitis(vivado)的过程中勾选了安装JTAGcable驱动就会默认安装好JTAG驱动,但是如果vivado无法正确识别到JTAG,那么可以试下重新手动安装驱动2准备工作安装驱动前,必须关闭所有的vivado,vitis-sdk并且拔掉USBJTAG以免导致安装失败3USB_JTAG驱动安装找到vivado安......
  • 03 CP2104串口驱动安装
    1概述串口是最常用的一种调试工具,开发过程中我们经常会使用串口输出一些调试信息,在LINUX下也会用串口控制台控制LINUX系统。目前的串口,大部分都是USB转串口。CP2104是一款非常稳定好用的USB转串口芯片。接下来我们看下如何进行驱动安装。2软件下载登录米联客技术论坛https://......
  • Rust爬取大A股票数据.rs
    externcratesimple_excel_writerasexcel;useexcel::*;fnmain()->Result<(),Box<dynstd::error::Error>>{  leturl:&str="http://94.push2.eastmoney.com/api/qt/clist/get?cb=jQuery1124040399874179311124_1685159655748&pn......
  • 洛阳师范学院Luoyang normal university
    洛阳师范学院是一所省属普通高等本科院校,位于千年帝都、牡丹花城、丝路起点——洛阳。学校地处伊水之滨,万安山下,东汉太学便发端于此。南望二程故里,传颂着程门立雪、鲁台望道的佳话;西望关林和世界文化遗产龙门石窟,绽放着世界文化遗产的璀璨光芒。学校前身是始建于1916年的河南省立......
  • 河南警察学院 Henan police college
    河南警察学院是我省唯一的省属公安本科院校,前身是1949年2月成立的中共豫西区委保卫干部训练班,历经河南省公安干部学校、河南省人民警察学校、河南公安高等专科学校等时期。2010年3月经教育部批准成立河南警察学院。2012年8月开封警校、洛阳警校并入河南警察学院。2019年11月通过教......
  • 【Datahub系列教程】Datahub入门必学——DatahubCLI之Docker命令详解
    大家好,我是独孤风,今天的元数据管理平台Datahub的系列教程,我们来聊一下DatahubCLI。也就是Datahub的客户端。我们在安装和使用Datahub的过程中遇到了很多问题。如何安装Datahub?为什么总是拉取镜像?如何启动Datahub?这些Datahub的Docker命令都是做什么的?有很多同学虽然搜......
  • 基于深度学习网络的美食检测系统matlab仿真
    1.算法运行效果图预览  2.算法运行软件版本matlab2022a 3.算法理论概述      美食检测是一项利用计算机视觉技术来识别和分类食物图像的任务。       特征提取是食品检测的核心步骤,其目的是从输入图像中提取出有效的特征,以便于后续的分类。常见的......
  • Redis进阶 使用Lua编写Redis脚本
    前面学习了Lua的基本语法,接下来是使用Lua编写脚本1.可以使用redis.call来调用redis命令使用redis.call会将redis命令返回的类型转换成对应的Lua数据类型。关系如下 与redis.call想类似的就是redis.pcall。【redis.call与redis.pcall的区别】当命令出错的时候,redis.pcall......
  • CF1910G Pool Records记录
    题目链接:https://codeforces.com/contest/1910/problem/G题意简述有两个运动员以未知的固定速度\(v_1\nev_2\)在一个长为\(50\)米的游泳池中游泳,一旦到边缘就立即掉头。现在有他们前\(n\)次相遇时间\(t_i\)(递增,均为整数)的记录,问这个记录是否合法。\(n\le2\times10^......
  • 多语言应用监控最优选,ARMS 应用监控 eBPF 版正式发布
    作者:古琦、千陆、彦鸿随着Kubernetes、Serverless等云原生技术引领研发、运维模式变革。应用架构从单体架构逐步演进为分布式、微服务化应用,随着业务的发展,多语言、多框架、多协议的微服务在企业内部越来越多,微服务的复杂度越来越高,如何通过可观测来快速发现、定位微服务的问......