首页 > 其他分享 >深入理解 Scikit-Learn 中的 fit, transform 和 fit_transform

深入理解 Scikit-Learn 中的 fit, transform 和 fit_transform

时间:2024-08-05 13:28:28浏览次数:15  
标签:fit scaler Scikit transform 参数 test 数据

# 深入理解 Scikit-Learn 中的 fit, transform 和 fit_transform

在使用 Scikit-Learn 进行数据处理和机器学习建模时,经常会遇到三个重要的方法:`fit`、`transform` 和 `fit_transform`。它们是 Scikit-Learn 中用于数据预处理、特征提取和模型训练的核心方法。本文将详细解释这三个方法的作用、区别,并通过实例展示它们的使用。

 1. fit

`fit` 方法的主要作用是计算并存储模型参数。这些参数将在 `transform` 方法中用于数据转换。在数据预处理中,`fit` 通常用于估计数据的统计特性,如均值、方差、最大值、最小值等。

### 例子

from sklearn.preprocessing import StandardScaler
import numpy as np

# 创建一个样本数据
X = np.array([[1., -1., 2.],
              [2., 0., 0.],
              [0., 1., -1.]])

# 初始化 StandardScaler
scaler = StandardScaler()

# 计算并存储均值和标准差
scaler.fit(X)

# 查看计算得到的均值和标准差
print("均值:", scaler.mean_)
print("标准差:", scaler.scale_)

在上述例子中,`fit` 方法计算并存储了样本数据的均值和标准差,这些参数将用于后续的数据标准化操作,在机器学习中通常假设数据服从相同的分布,因此在标准归一化化测试集时需要用训练集的均值和标准差。

2. transform

`transform` 方法的作用是使用 `fit` 方法计算得到的参数对数据进行转换。对于数据预处理,`transform` 通常用于将数据缩放、标准化或归一化。

### 例子

# 使用计算得到的均值和标准差对数据进行标准化
X_scaled = scaler.transform(X)

print("标准化后的数据:\n", X_scaled)

在上述例子中,`transform` 方法使用先前 `fit` 方法计算得到的均值和标准差对数据进行标准化,使每个特征的均值为 0,标准差为 1。

 3. fit_transform

`fit_transform` 方法是 `fit` 和 `transform` 的组合。它首先对数据进行 `fit`(计算并存储参数),然后对数据进行 `transform`(使用计算得到的参数转换数据)。这种方法通常在数据预处理中更为简洁高效。

### 例子

# fit_transform 一步完成计算参数和数据转换
X_scaled_direct = scaler.fit_transform(X)

print("使用 fit_transform 标准化后的数据:\n", X_scaled_direct)

在上述例子中,`fit_transform` 方法一步完成了计算参数和数据转换,效果与先 `fit` 再 `transform` 相同。

 4. 总结

### fit

- 作用:计算并存储模型参数(如均值、标准差)。
- 适用对象:估计器、转换器。
- 主要使用场景:参数计算。

### transform

- 作用:使用 `fit` 方法计算得到的参数对数据进行转换。
- 适用对象:转换器。
- 主要使用场景:数据转换。

### fit_transform

- 作用:先 `fit`(计算并存储参数),再 `transform`(转换数据)。
- 适用对象:转换器。
- 主要使用场景:数据预处理时一步完成参数计算和数据转换。

 5. 应用实例

下面是一个完整的例子,展示如何使用 `fit`、`transform` 和 `fit_transform` 方法进行数据预处理和模型训练。

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# 加载 Iris 数据集
iris = load_iris()
X = iris.data
y = iris.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化 StandardScaler
scaler = StandardScaler()

# 对训练集进行 fit 和 transform
X_train_scaled = scaler.fit_transform(X_train)

# 对测试集只进行 transform
X_test_scaled = scaler.transform(X_test)

# 初始化逻辑回归模型
model = LogisticRegression()

# 训练模型
model.fit(X_train_scaled, y_train)

# 评估模型
accuracy = model.score(X_test_scaled, y_test)
print("模型准确率:", accuracy)

在这个例子中,我们首先对训练数据进行 `fit_transform`,然后对测试数据进行 `transform`,确保训练数据和测试数据使用相同的缩放参数进行标准化。最后,我们使用标准化后的数据训练并评估逻辑回归模型。

通过以上内容,希望大家能够更好地理解 `fit`、`transform` 和 `fit_transform` 的区别和使用场景,在实际项目中灵活运用它们。

---

标签:fit,scaler,Scikit,transform,参数,test,数据
From: https://blog.csdn.net/jjqhj/article/details/140906700

相关文章

  • Scalable Diffusion Models with Transformers(DIT)代码笔记
    完整代码来源:DiTDiT模型主要是在diffusion中,使用transformer模型替换了UNet模型,使用class来控制图像生成。根据论文,模型越大,patchsize越小,FID越小。模型越大,参数越多,patchsize越小,参与计算的信息就越多,模型效果越好。模型使用了Imagenet训练,有1000个分类,class_labe......
  • 数据变换 Transforms
    通常情况下,直接加载的原始数据并不能直接送入神经网络进行训练,此时我们需要对其进行数据预处理。MindSpore提供不同种类的数据变换(Transforms),配合数据处理Pipeline来实现数据预处理。所有的Transforms均可通过map方法传入,实现对指定数据列的处理。mindspore.dataset提供了......
  • 【创新未发表】Matlab实现蚁狮优化算法ALO-Kmean-Transformer-LSTM组合状态识别算法研
    蚁狮优化算法(AntLionOptimisation,ALO)是一种启发式优化算法,灵感来源于蚁狮捕食过程中的行为。这种算法模拟了蚁狮捕食中的策略,其中蚁狮通过在环境中设置虚拟陷阱来吸引蚂蚁,然后捕食这些落入陷阱的蚂蚁。在算法中,蚁狮代表潜在解决方案,而虚拟陷阱代表目标函数的局部最小值。......
  • Transformer 工作流程(大白话版)
    Transformer工作流程:通俗易懂的解释想象一下,你在参加一个创意写作班,你和其他几位同学一起写一篇故事。老师会让每个人轮流写一段,但在写之前,你们可以参考之前同学写的内容。这有点像Transformer的工作流程。让我们一步步来解释。编码器(Encoder)1.输入嵌入层(InputEmbed......
  • WPF C# implement scaletransform and translatetransfrom programmatically
    privatevoidInitRenderTransfrom(){TransformGrouptg=newTransformGroup();ScaleTransformst=newScaleTransform();if(!tg.Children.Contains(st)){tg.Children.Add(st);scaler=st;}TranslateTransformtt=n......
  • Pytorch笔记|小土堆|P10-13|transforms
    transforms对图像进行改造最靠谱的办法:根据help文件自行学习transforms包含哪些工具(类)以及如何使用————————————————————————————————————自学一个类时,应关注:1、如何使用各种工具(类)的使用思路:创建对象(实例化)——>传入参数,调用函数(如有__......
  • [CSS] max-content, min-content, fit-content
    max-contenthttps://developer.mozilla.org/en-US/docs/Web/CSS/max-contentThe max-content sizingkeywordrepresentsthemaximum intrinsicsize ofthecontent.Fortextcontentthismeansthatthecontentwillnotwrapatallevenifitcausesoverflows.......
  • 深度学习扫盲——Transforms
    在PyTorch中,torchvision是一个常用的库,它提供了对图像和视频数据的处理功能,包括数据加载、转换等。transforms是torchvision.transforms模块的一部分,它定义了一系列的图像转换操作,这些操作可以单独使用或者组合成转换序列(通过transforms.Compose),以便于在数据加载时自动应用到图像......
  • Seurat-SCTransform与harmony整合学习
    目录基础介绍SCTransform与harmony联合代码测试1)报错解决2)SCTransform标准化3)harmony去批次基础介绍源于Rtips:Seurat之SCTransform方法原理(qq.com)Seurat对象在经过SCTransform处理后会增加一个SCT的Assay,里面的scaled.data就是经过scale之后的pearsonresidual值......
  • 如何理解词向量、Transformer模型以及三个权重矩阵
    词向量与transformer 生成词向量的过程和训练Transformer的过程是两个不同的过程,但它们都是自然语言处理中的重要组成部分。#词向量的生成词向量(如Word2Vec、GloVe、FastText等)通常是通过预训练的词嵌入模型得到的。这些模型在大规模文本数据上训练,捕捉词与词之间的语义关系,......