首页 > 其他分享 >学习笔记425—train_test_split 函数介绍

学习笔记425—train_test_split 函数介绍

时间:2023-10-10 12:46:27浏览次数:43  
标签:None train split test stratify size

train_test_split 函数介绍

在机器学习中,我们通常将原始数据按照比例分割为“测试集”和“训练集”,从 sklearn.model_selection 中调用train_test_split 函数 

简单用法如下:

X_train,X_test, y_train, y_test =sklearn.model_selection.train_test_split(train_data,train_target,test_size=0.4, random_state=0,stratify=y_train)
# train_data:所要划分的样本特征集

# train_target:所要划分的样本结果

# test_size:样本占比,如果是整数的话就是样本的数量

# random_state:是随机数的种子。
# 随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。

stratify是为了保持split前类的分布。比如有100个数据,80个属于A类,20个属于B类。如果train_test_split(... test_size=0.25, stratify = y_all), 那么split之后数据如下: 
training: 75个数据,其中60个属于A类,15个属于B类。 
testing: 25个数据,其中20个属于A类,5个属于B类。 

用了stratify参数,training集和testing集的类的比例是 A:B= 4:1,等同于split前的比例(80:20)。通常在这种类分布不平衡的情况下会用到stratify。

将stratify=X就是按照X中的比例分配 

将stratify=y就是按照y中的比例分配 

整体总结起来各个参数的设置及其类型如下:

主要参数说明:

*arrays:可以是列表、numpy数组、scipy稀疏矩阵或pandas的数据框

test_size:可以为浮点、整数或None,默认为None

①若为浮点时,表示测试集占总样本的百分比

②若为整数时,表示测试样本样本数

③若为None时,test size自动设置成0.25

train_size:可以为浮点、整数或None,默认为None

①若为浮点时,表示训练集占总样本的百分比

②若为整数时,表示训练样本的样本数

③若为None时,train_size自动被设置成0.75

random_state:可以为整数、RandomState实例或None,默认为None

①若为None时,每次生成的数据都是随机,可能不一样

②若为整数时,每次生成的数据都相同

stratify:可以为类似数组或None

①若为None时,划分出来的测试集或训练集中,其类标签的比例也是随机的

②若不为None时,划分出来的测试集或训练集中,其类标签的比例同输入的数组中类标签的比例相同,可以用于处理不均衡的数据集

通过简单栗子看看各个参数的作用:

举例如下:

通过简单例子看看各个参数的作用:

①test_size决定划分测试、训练集比例

②random_state不同值获取到不同的数据集

 设置random_state=0再运行一次,结果同上述相同

设置random_state=None运行两次,发现两次的结果不同

③设置stratify参数,可以处理数据不平衡问题

参考链接:https://blog.csdn.net/u010687164/article/details/108252746

标签:None,train,split,test,stratify,size
From: https://www.cnblogs.com/hechangchun/p/17754369.html

相关文章

  • pytest如何设置 批量用例执行时,失败自动retry重试机制
    失败重跑插件pytest-rerunfailures1.环境前提:以下先决条件才能使用pytest-rerunfailuresPython3.5,最高3.8,orPyPy3pytest5.0或更高版本2.安装插件pipinstallpytest-rerunfailures插件重点:命令行参数:–rerunsn(重新运行次数),–reruns-delaym(等待运行秒......
  • Go - Finding the Shortest Path on a Graph
    Problem: Youwanttofindtheshortestpathbetweentwonodesonaweightedgraph.Solution: UseDijkstra’salgorithmtofindtheshortestpathbetweentwonodes.Dijkstra’salgorithmalsousesapriorityqueue,whichcanbeimplementedusingaminheap.......
  • P8511 [Ynoi Easy Round 2021] TEST_68
    题目传送门看到异或最大值,根据套路不妨考虑\(0-1trie\)。通过\(trie\)找到异或值最大的点对\((x,y)\)。那么除了\((x,y)\)到\(1\)路径上的点之外,其他的点的答案就是\((x,y)\)的异或值。接下来考虑怎么算出这\((x,y)\)到\(1\)路径上的点的答案,可以直接暴力计算!......
  • pytest + yaml 框架 -57.用例前置和后置操作执行sql
    前言前面一篇已经介绍在用例中可以查询以及在断言中实现sql,本篇讲解在用例的前置和后置操作中执行sql配置mysql环境,参考前面一篇https://www.cnblogs.com/yoyoketang/p/16977960.html有2内置的函数可以使用query_sql(sql)查询sql,查询无结果返回[],查询只有一个结果返......
  • 联考test1009
    写在前面的话感觉比以往的比赛难多了。出题人卡高精度,不好评价,但是题目还是好题。考试的时候开题顺序为\(T1-T3-T4-T2\),感觉和题目的实际难度排序差不多。考试的时候懒了,没有去拼暴力,实际得分\(80+0+100+0=180\),总体排名\(rk29\)。\(T1\)题意简述我们知道,对于一个整......
  • Vitest All In One
    VitestAllInOneVite+TestNextGenerationTestingFramework.AVite-nativetestingframework.#VitestrequiresVite>=v3.0.0andNode>=v14.18$npmi-Dvitesthttps://vitest.dev/guide/https://github.com/vitest-dev/vitesthttps://vitest.d......
  • Test
    T1题意思路代码T2题意思路代码T3题意思路代码T4题意思路代码......
  • AtCoder Beginner Contest 323 (ABC 323) D、E、F 题解
    AtCoderBeginnerContest323(ABC323)D、E、F题解D题目大意给\(n\)种数\(s_i\),每一种数有\(c_i\)个,每次可以把两个相同的数合并为一个数,问最后会剩下多少数?分析对于每一个数\(s_i\),它最多被分解\(log_2c_i\)次,并且合并出来最大的数的大小小于\(s_i\timesc_i......
  • Could not resolve all dependencies for configuration ':testCompileClasspath'. Us
    Gradleinit.gradle文件参数错误导致的Gradle加载失败 1allprojects{2repositories{3mavenLocal()4maven{name"Alibaba";url"https://maven.aliyun.com/repository/public"}//将http改为https5maven{name&......
  • AtCoder Beginner Contest 323
    E-Playlist首先需要算出第x+0.5秒后,第一首歌播放的概率1.要在x+0.5秒后播放第一首,需要在x,x-1,x-2,...,x-t[1]+1,时就要开始播放第一首,并且概率是1/n,概率之和除以n2.概率dp,dp[i]表示播放i的概率,那么可以转换成,dp[i]+=dp[i-j]/n%mod(i>=t[j])3.答案就是x,x-1,...,x-t[1]+1概率之和......