首页 > 编程问答 >如何将输入暗淡从 fit 方法传递到 skorch 包装器?

如何将输入暗淡从 fit 方法传递到 skorch 包装器?

时间:2024-07-27 06:37:58浏览次数:11  
标签:python deep-learning pytorch skorch

我正在尝试将 PyTorch 功能合并到 scikit-learn 环境中(特别是 Pipelines 和 GridSearchCV),因此一直在研究 skorch 神经网络的标准文档示例看起来像

import torch.nn.functional as F
from torch import nn
from skorch import NeuralNetClassifier

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        ...
        ...
        self.output = nn.Linear(10, 2)
    ...
    ...

,您通过将输入和输出维度硬编码到构造函数中来显式传递它们。然而,这并不是 scikit-learn 接口真正的工作方式,其中输入和输出维度是通过 fit 方法导出的,而不是显式传递给构造函数。作为一个实际的例子,除了变压器中任何地方都必须指定输入和输出维度这一事实之外,在模型之前

# copied from the documentation
net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

# any general Pipeline interface
pipeline = Pipeline([
        ('transformation', AnyTransformer()),
        ('net', net)
        ])

gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')
gs.fit(X, y)

应用的变压器可能会改变训练集的维度(想想降维和类似),因此在神经网络构造函数中硬编码输入和输出是行不通的。 我是否误解了它应该如何工作或者建议的解决方案是什么(我正在考虑将构造函数指定为 您确实有

可用于拟合的方法,但我不确定这是一个好的做法)? forward method where you do have X available for fit already, but I am not sure this is good practice)?


你对如何在 scikit-learn 环境中使用 skorch 的理解是正确的。你不应该将输入维度硬编码到你的 PyTorch 模块中,因为这会阻止你在管道中使用它,在管道中,维度可能在拟合时之前是未知的。

你有几个选择可以将输入维度传递给你的 skorch 包装器:

1. 使用 input_dim 参数(推荐):

skorch NeuralNetClassifier NeuralNetRegressor 有一个 module__ 前缀,允许你将参数传递给底层 PyTorch 模块。你可以使用它在实例化你的模型时传递输入维度:

net = NeuralNetClassifier(
    MyModule,
    module__input_dim=20,  # 将输入维度传递给 MyModule
    max_epochs=10,
    lr=0.1,
    iterator_train__shuffle=True,
)

在你的 MyModule 中,你可以像这样访问 input_dim

class MyModule(nn.Module):
    def __init__(self, input_dim, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()
        self.dense0 = nn.Linear(input_dim, num_units)
        # ... 其他层 ...

此方法使你的代码更具可读性和灵活性,因为它使输入维度成为 NeuralNetClassifier 的显式参数。

2. 在 fit 方法中设置输入维度:

你可以在 MyModule fit 方法中访问输入数据的形状,并使用该信息来初始化你的第一层。

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()
        self.num_units = num_units
        self.nonlin = nonlin
        # ... 其他层 ...
        self.first_layer = None 

    def forward(self, X):
        if self.first_layer is None:
            self.first_layer = nn.Linear(X.shape[1], self.num_units)
        # ... 其他层 ...

这种方法不是很干净,因为它使你的 forward 方法更复杂,并且偏离了 PyTorch 模块的典型初始化模式。

3. 使用灵活的层,例如 nn.Flatten

如果你使用 nn.Flatten 作为你的第一个层,它可以处理任何输入维度,你就不需要显式地传递输入维度。

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()
        self.flatten = nn.Flatten()
        self.dense0 = nn.Linear(num_units, num_units)
        # ... 其他层 ...

这种方法很简单,但如果你需要在第一层之前对输入维度执行特定操作,则可能不适合。

总的来说,建议使用 input_dim 参数,因为它可以提供清晰度和灵活性,而不会使你的代码过于复杂。 这使你可以在 skorch 包装器中利用 scikit-learn 管道的优势,而无需硬编码输入维度。

标签:python,deep-learning,pytorch,skorch
From: 60005715

相关文章

  • Python win32serviceutil QueryServiceStatus:返回值是什么意思?
    我正在学习使用pywin32,并尝试在64位Python3.6.4上使用win32serviceutil模块以下代码:importwin32serviceutilasserviceserviceStatus=service.QueryServiceStatus("WinDefend")print(serviceStatus)返回以下元组:(16,4,197,0,0,0,0)我对wind......
  • Python request-html 未下载 Chromium
    importrequestsfrombs4importBeautifulSoupfromrequests_htmlimportHTMLSessionurl="https://dmarket.com/ingame-items/item-list/csgo-skins?title=recoil%20case"sesion=HTMLSession()response=sesion.get(url)response.html.render()soup=B......
  • VS Code 不改变 python 环境
    我正在使用VS-Code和anaconda环境作为python解释器。我通过ctrl+shift+`选择准确的anaconda基础环境,它也反映在vscode的下侧面板中。但是,当我检查python版本时,它显示我系统的默认python环境3.7.9如果您看到下面的截图,anaconda环境是3.......
  • 使用 Python 打开保存为 Parquet 文件中元数据的 R data.table
    使用R,我创建了一个Parquet文件,其中包含一个data.table作为主要数据,另一个data.table作为元数据。library(data.table)library(arrow)dt=data.table(x=c(1,2,3),y=c("a","b","c"))dt2=data.table(a=22222,b=45555)attr(dt,&......
  • Python 需要 Windows 长路径
    我尝试运行此安装:pip3installmsgraph-sdk它给了我这个错误:它说我需要使用此链接启用Windows长路径:https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation?tabs=registry#enable-long-paths-in-windows-10-versi......
  • Python griddata() 和 Matlab griddata():某些网格点的结果不同
    在将一些(相当大的物理)Matlab代码转换为Python时,我偶然发现了这种情况。当对相同的二维离散数据进行插值时,Python/Scipy的griddata()函数给出的结果与Matlab的对应函数不同。griddata()Matlab示例代码:Python示例代码:%Samplepoints(x,y):7x5=3......
  • Ebay Python SDK 仅在特定项目类别上返回错误
    我在一个项目中使用ebaySDK一段时间了。最近我尝试导入一些商品,例如手表、手机壳等...并且我使用了eBay自己通过eBay返回的英国商店页面上的类别ID他们的“get_category_suggestions”API端点,但eBay似乎有选择地决定拒绝某些项目并引发服务器错误!为了测试,我做了......
  • 使用特定的Python版本(MacOS)制作virtualenv
    我安装了brew,python3(默认和最新版本)和pip3,pyenv。TensorFlow现在不支持python3.7,所以我听说我应该制作一个独立运行3.6或更低版本的virtualenv。我安装了python3.6.7bypyenvinstall3.6.7但无法制作virtualenv-p3.6.7(mydir)因为3.6.7不在P......
  • 使用Python去除图像中的线条
    我正在尝试使用Python和cv2、numpy、skimage等从黑白图像中删除“阴影线”(如果图像中存在“阴影线”)。本质上,我的图像可以有1或2条曲线,如下例所示。但每条线都有一条1-5像素外的阴影线,需要删除。我怎样才能在Python中做到这一点?原始......
  • Python 和 OpenCV:如何裁剪半成形边界框
    我有一个为无网格表创建网格线的脚本:脚本之前:脚本之后:是否有一种简单的方法,使用OpenCV来裁剪“脚本之后”图像,使其仅包含四边边界框?示例输出:编辑:我目前正在研究一种解决方案,该解决方案可以找到垂直/水平方向的第一条/最后一条......