首页 > 编程问答 >pytorch CNN 模型参数存储不正确

pytorch CNN 模型参数存储不正确

时间:2024-07-23 04:34:30浏览次数:9  
标签:python pytorch

工作任务是制作一个 CNN 模型来对图像进行一些分类任务。另外,我应该能够在对图像进行分类后查看特征图,即应用卷积或池化操作后获得的图像。下面是我定义 CNN 类的方式:

class ConvNet(nn.Module):
  def __init__(self, input_channels, output_dim):
    super().__init__()
    # input 48x48
    self.architecture = {
        "conv1": self.convblock(input_channels, 128, (3,3)), # 46x46
        "conv2" : self.convblock(128, 64, (3, 3), bnorm=True), # 44x44
        "pool1" : self.poolblock((2,2)), # 22x22
        "conv3" : self.convblock(64, 16, (3,3), stride=2), #10x10
        "conv4" : self.convblock(16, 10, (3,3)), # 8x8
        "pool2" : self.poolblock((2,2), bnorm=10), # 4x4
        "feedforward" : nn.Sequential(
          nn.Flatten(), # 4x4x10 = 160
          nn.Linear(160, 128), # 128
          nn.ReLU(inplace=True),
          nn.Dropout(0.3),
          nn.Linear(128, output_dim), # 3
          nn.Softmax(dim=1)
        )                  
    }
    self.maps = {}

  def forward(self, x):
    image = x
    for name, layer in self.architecture.items():
      out = layer(image)
      self.maps[name] = out
      image = out
    return image
     
  def convblock(self, inp, out, kernel, stride=1, bnorm=False):
    if bnorm:
      return nn.Sequential(
        nn.Conv2d(inp, out, kernel, stride=stride),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out)
      )
    else:
      return nn.Sequential(
        nn.Conv2d(inp, out, kernel, stride=stride),
        nn.ReLU(inplace=True)
      )

  def poolblock(self, kernel, bnorm=None):
    if bnorm is None:
      return nn.MaxPool2d(kernel)
    else:
      return nn.Sequential(
          nn.MaxPool2d(kernel),
          nn.BatchNorm2d(bnorm)
      )

  def get_map(self, im, layer):
    fig, ax = plt.subplots(1,2, figsize=(20,10), gridspec_kw={'width_ratios': [1,3]})
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].imshow(im.reshape(im.shape[-2],im.shape[-1],1), cmap="gray") # Shows Input image
    self(im)
    map = self.maps[layer]
    map=map.reshape(map.shape[1],1,map.shape[-2],map.shape[-1])
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    rows = max(int(map.shape[0]/8), 8)
    ax[1].imshow(make_grid(map,nrow=rows).permute(1, 2, 0)) # Shows all the channels after an operation.

这个想法是将卷积层和池化层块存储在 self.architecture 字典中,名称如 'conv1' 'conv2' 'pool1' 等... 然后,在前向方法中,我将通过每个块运行输入图像,并将每个块的输出存储在 self.maps 字典中以供稍后检索( self.get_map 这样做)。

问题是模型的参数设置不正确。下面是我实例化模型和优化器的代码:

model = ConvNet(1, 3).to(device)
adam_opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

但是我遇到了以下错误:

/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
    271         param_groups = list(params)
    272         if len(param_groups) == 0:
--> 273             raise ValueError("optimizer got an empty parameter list")
    274         if not isinstance(param_groups[0], dict):
    275             param_groups = [{'params': param_groups}]

ValueError: optimizer got an empty parameter list

我打印出了参数列表,它是空的。我不明白为什么会发生这种情况。

我定义架构或前向方法的方式是否错误,PyTorch 在子类化时需要一些特定的行为 nn.Module ?如果是这样,那是什么?我该如何更改我的班级?任何关于 Pytorch 如何实际存储类参数的额外信息都是最受欢迎的。


你的代码中出现错误的原因是,你没有将 nn.Module 中的层作为属性添加到你的类中。你将它们存储在名为 self.architecture 的普通 Python 字典中,PyTorch 无法识别。

以下是如何修复代码的方法:

class ConvNet(nn.Module):
    def __init__(self, input_channels, output_dim):
        super().__init__()
        # input 48x48
        self.conv1 = self.convblock(input_channels, 128, (3,3)) # 46x46
        self.conv2 = self.convblock(128, 64, (3, 3), bnorm=True) # 44x44
        self.pool1 = self.poolblock((2,2)) # 22x22
        self.conv3 = self.convblock(64, 16, (3,3), stride=2) #10x10
        self.conv4 = self.convblock(16, 10, (3,3)) # 8x8
        self.pool2 = self.poolblock((2,2), bnorm=10) # 4x4
        self.feedforward = nn.Sequential(
            nn.Flatten(), # 4x4x10 = 160
            nn.Linear(160, 128), # 128
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, output_dim), # 3
            nn.Softmax(dim=1)
        )                  
        self.maps = {}

    def forward(self, x):
        image = x
        for name, layer in self._modules.items():
            out = layer(image)
            self.maps[name] = out
            image = out
        return image

    # ... rest of your code (convblock, poolblock, get_map)...

解释:

  1. 直接将层作为属性: 我们删除了 self.architecture 字典,并将每一层( self.conv1 , self.conv2 , etc.)直接添加为 ConvNet 类的属性。这使得 PyTorch 能够识别它们为需要优化的参数。
  2. 使用 self._modules forward 方法中,我们使用 self._modules 来迭代模块。 self._modules 是一个包含所有已注册子模块(即你的层)的 OrderedDict。这确保了以正确的顺序执行层。

通过这些更改,PyTorch 现在可以正确识别模型中的参数,并且你应该能够在不出现错误的情况下创建优化器。

关于 PyTorch 如何存储参数的额外信息:

  • 当你将一个 nn.Module 子类化并添加 nn.Parameter 或其他 nn.Module 实例作为属性时,PyTorch 会自动跟踪它们。
  • nn.Parameter 是一个 Tensor 的包装器,它告诉 PyTorch 该张量需要梯度计算。
  • model.parameters() 方法递归地收集模型及其子模块中的所有参数,并返回一个迭代器。
  • 优化器使用此迭代器来访问和更新模型训练期间的参数。

标签:python,pytorch
From: 78315870

相关文章

  • 并行 GPU 中的 Pytorch 推理
    我有一个经过训练的模型和4个GPU。我有一个数据集,我想将其分割为4个(并在每个GPU上使用相同的批量大小进行处理),并且基本上将我从每个GPU获得的结果相加。一般来说,我对并行处理很陌生,我对torch.multiprocessing等的谷歌搜索没有帮助。我的问题的一个最小示例如下:我......
  • 无法在 python 中安装 pip install expliot - bluepy 的 Building Wheel (pyproject.t
    在此处输入图像描述当我尝试在Windows计算机中通过cmd安装pipinstallexpliot包时,我收到2个错误名称×Buildingwheelforbluepy(pyproject.toml)didnotrunsuccessfully.│exitcode:1**AND**opt=self.warn_dash_deprecation......
  • python 用单斜杠-反斜杠替换url字符串中的双斜杠
    我的URL包含错误的双斜杠(“//”),我需要将其转换为单斜杠。不用说,我想保持“https:”后面的双斜杠不变。可以在字符串中进行此更改的最短Python代码是什么?我一直在尝试使用re.sub,带有冒号否定的正则表达式(即,[^:](//)),但它想要替换整个匹配项(包括前面......
  • 如何使用 Selenium Python 搜索 Excel 文件中的文本
    我有一些数据在Excel文件中。我想要转到Excel文件,然后搜索文本(取自网站表),然后获取该行的所有数据,这些数据将用于在浏览器中填充表格。示例:我希望selenium搜索ST0003然后获取名称,该学生ID的父亲姓名,以便我可以在大学网站中填写此信息。我想我会从网站......
  • Python 套接字请求在很多情况下都会失败
    我在python中尝试了超过5种不同的方法,尽管人们说它在其他论坛上有效,但所有这些方法都惨遭失败。importsocketmessage="test"clientsocket=socket.socket(socket.AF_INET,socket.SOCK_STREAM)clientsocket.connect(('1.1.1.1',80))clientsocket.send(mes......
  • Python 网络套接字
    我一直尝试通过Python访问该网站的websocket,但是需要绕过CloudFlare,现在我尝试通过cookie进行绕过,但是这不起作用。我已经尝试在没有cookie的情况下执行此操作,但这也不起作用。importwebsocketimportbase64importosdriver=selenium.webdriver.Firefox()driver.ge......
  • 如何在Python中使用Selenium提取data-v-xxx?
    因为我想查看每个class='num'内的文本是否大于0。如果测试通过,那么我需要获取venuen-name内的文本。我观察到,data-v是相同的。所以我的方法是获取相同的data-v-<hashvalue>来查找场地名称。我尝试了不同的方法来提取,但仍然无法提取。有什么建议吗?这是DOM<div......
  • Python:添加异常上下文
    假设我想提出一个异常并提供额外的处理信息;最好的做法是什么?我想出了以下方法,但对我来说有点可疑:definternal_function():raiseValueError("smellysocks!")defcontext_function():try:internal_function()exceptExceptionase:......
  • 【视频】Python遗传算法GA优化SVR、ANFIS预测证券指数ISE数据-CSDN博客
    全文链接:https://tecdat.cn/?p=37060本文旨在通过应用多种机器学习技术,对交易所的历史数据进行深入分析和预测。我们帮助客户使用了遗传算法GA优化的支持向量回归(SVR)、自适应神经模糊推理系统(ANFIS)等方法,对数据进行了特征选择、数据预处理、模型训练与评估。实验结果表明,这些方法......
  • Python学习笔记42:游戏篇之外星人入侵(三)
    前言在之前我们已经创建好了目录,并且编写好了游戏入口的模块。今天的内容主要是讲讲需求的分析以及项目各模块的代码初步编写。在正式编写代码前,碎碎念几句。在正式编写一个项目代码之前,实际是有很多工作要做的。就项目而言,简单的定项,需求对齐,项目架构设计,实际的代码编写,......