首页 > 编程问答 >train_test_split 导致 xgboost 忽略“enable_categorical”

train_test_split 导致 xgboost 忽略“enable_categorical”

时间:2024-07-29 08:57:22浏览次数:13  
标签:python machine-learning statistics xgboost

我正在使用 xgboost 版本 2.1.0
当使用 xgboost.DMatrix() 和 'enable_categorical'=True 将包含类别列的 pandas 数据帧转换为 DMatrix 时,所有行为均按预期运行,除非数据帧是 sklearn train_test_split() 返回的数据帧,尽管所有列的数据类型仍属于类别。

以下代码产生预期的行为:

import pandas as pd 
import xgboost as xgb
import seaborn as sns

tips = sns.load_dataset('tips')

X, y = tips.drop('tip', axis=1), tips['tip']
print(X.dtypes)

# convert to DMatrix
dm = xgb.DMatrix(X, y, enable_categorical=True)
dm

---------------------------

total_bill     float64
sex           category
smoker        category
day           category
time          category
size             int64
dtype: object

<xgboost.core.DMatrix at 0x23153d7cf10>

以下代码引发错误:

from sklearn.model_selection import train_test_split

X_train, y_train, X_test, y_test = train_test_split(X, y, random_state=1)
print(X_train.dtypes)

dtrain = xgb.DMatrix(X_train, y_train, enable_categorical=True)

-----------------------------

total_bill     float64
sex           category
smoker        category
day           category
time          category
size             int64
dtype: object

ValueError: DataFrame.dtypes for data must be int, float, bool or category. When categorical type is supplied, the experimental DMatrix parameter`enable_categorical` must be set to `True`.  Invalid columns:sex: category, smoker: category, day: category, time: category

我很困惑。 谢谢。


你遇到的问题是由于 train_test_split 函数返回的数组顺序造成的,而不是数据类型的问题。 train_test_split 返回的是 X_train , X_test , y_train , y_test ,而不是 X_train , y_train , X_test , y_test

将你的代码修改为以下即可解决问题:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)  # 注意顺序
print(X_train.dtypes)

dtrain = xgb.DMatrix(X_train, y_train, enable_categorical=True)

修正后的代码中, train_test_split 返回的数组顺序正确,因此可以成功创建 DMatrix 对象。

标签:python,machine-learning,statistics,xgboost
From: 78805114

相关文章

  • 《最新出炉》系列入门篇-Python+Playwright自动化测试-56- 多文件上传 - 下篇
    1.简介前边的两篇文章中,宏哥分别对input控件上传文件和非input控件上传文件进行了从理论到实践地讲解和介绍,但是后来又有人提出疑问,前边讲解和介绍的都是上传一个文件,如果上传多个文件,Playwright是如何实现的呢?宏哥看了一下官方的API也有上传多个文件的API,那么今天就来讲解和介绍......
  • 如何使用python模块捕获用户的文本输入
    我正在开发一个项目,它会检测到如果您按“(”,它会自动关闭它“[”和“{”的情况相同,但重点是它检测键盘按钮“{”或“[”不是字符,这意味着如果朋友有不同的方式输入“[”,它将无法工作,因为该程序用于检测“altgr+(”序列,这可能会影响不同语言的键盘因为您不想在按下......
  • 如何更新 numpy 2 的 python 模块?
    在带有pip的Linux上,新的numpy2似乎可以很好地与pandas配合使用:$python3-c'importnumpyasnp;print(np.__version__);importpandasaspd;print(pd.__version__)'2.0.12.2.2但是,在带有miniconda的Windows上,我得到$${localappdata}/miniconda3/en......
  • python BioChemist 数据集的数据字典/描述
    我正在使用生物化学家数据集。我在哪里可以找到包含每个变量描述的“数据字典”?这就是我正在查看的:importpandasaspdfrompydatasetimportdatadata('bioChemists')我已经用谷歌搜索并尝试寻找运算符,但没有运气!pydataset软件包不包含生物化学家数据集的描述......
  • python中的Telebot API不断断开连接
    使用远程机器人,不断断开服务。我暂时让它在发生这种情况时重新启动。下面是我的代码和错误:importrandomimporttelebotfromtelebot.typesimportInlineKeyboardMarkup,InlineKeyboardButtonfromthreadingimportTimer,Eventfromdotenvimportload_dotenvimporto......
  • 如何用Python制作Android服务?
    我想构建一个简单的Android应用程序,例如PushOver应用程序,它具有TCP服务器并接收其记录的文本消息,然后将其作为推送通知发送。这部分已经完成并且工作正常。但即使GUI应用程序关闭,我也想接收消息。我知道这是可能的,因为PushOver应用程序做到了!我想,我可能需要一......
  • Python Discord Bot 的应用程序命令的区域设置名称(多语言别名)
    如何根据用户的语言设置,使应用程序命令的名称具有不同的名称例如,如果一个用户将其discord的语言设置为英语,则用户可以看到英语的应用程序命令名称。另一方面,如果另一个用户将其不和谐语言设置为法语,则用户可以看到法语中的相同应用程序命令的名称。为此,我尝试使用ap......
  • 如何在Python中添加热键?
    我正在为游戏制作一个机器人,我想在按下热键时调用该函数。我已经尝试了一些解决方案,但效果不佳。这是我的代码:defstart():whileTrue:ifkeyboard.is_pressed('alt+s'):break...defmain():whileTrue:ifkeyboard.is_pr......
  • 在Python中解压文件
    我通读了zipfile文档,但不明白如何解压缩文件,只了解如何压缩文件。如何将zip文件的所有内容解压缩到同一目录中?importzipfilewithzipfile.ZipFile('your_zip_file.zip','r')aszip_ref:zip_ref.extractall('target_directory')将......
  • 如何在Python中从RSA公钥中提取N和E?
    我有一个RSA公钥,看起来像-----BEGINPUBLICKEY-----MIIBIDANBgkqhkiG9w0BAQEFAAOCAQ0AMIIBCAKCAQEAvm0WYXg6mJc5GOWJ+5jkhtbBOe0gyTlujRER++cvKOxbIdg8So3mV1eASEHxqSnp5lGa8R9Pyxz3iaZpBCBBvDB7Fbbe5koVTmt+K06o96ki1/4NbHGyRVL/x5fFiVuTVfmk+GZNakH5dXDq0fwvJyVmUtGYA......