首页 > 编程问答 >使用适用于灰度图像的模型的指导

使用适用于灰度图像的模型的指导

时间:2024-07-28 06:29:31浏览次数:10  
标签:python deep-learning pytorch neural-network computer-vision

我有一个由 3 个 conv2D 层和 ReLU 激活组成的模型。它将标准化为区间 [0,1] 的灰度图像作为输入。输入图像有一些黑色区域、一些白色区域和其他区域。

但是,输出动态范围被压缩到 [0.4,0.401]。所有图像都是灰色的,即使在重新规范化将它们带回 [0,255] 之后也是如此。

我有点迷失,我不明白为什么会出现这种情况。

我尝试过的:

  • 情节梯度直方图,尽管我不确定如何解释它。对于某些参数,模型似乎在几个时期后停止学习。对于其他层,一层会停止学习。
  • 修改内核大小
  • 添加更多层
  • 尝试很多不同的学习率值

下面您将找到我的模型和我的训练循环。

class Model(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
           nn.Conv2d(in_channels=1, out_channels=64, kernel_size=13, stride=1, padding=5),
           nn.ReLU(True),
           nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2),
           nn.ReLU(True),
           nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, stride=1, padding=2),
          )


    def forward(self, input_image):
        output_image = self.model(input_image)
        return output_image
def train_one_epoch(
    model: nn.Module,
    dataloader:  torch.utils.data.DataLoader,
    loss_fn: nn.MSELoss,
    optimizer: optim.Adam,
    epoch_index,
    scaler: amp.GradScaler, 
):
   
    model.train()
    running_loss = 0
    last_loss = 0
    batch_index = 0

    for batch, loader in enumerate(train_dataloader):
                
        input_img = loader['input'].to(device, non_blocking=True)
        gt_img = loader['gt'].to(device, non_blocking=True)

        model.zero_grad(set_to_none=True)
      
        with amp.autocast():
            output = model(input_img)
            loss = loss_fn(output, gt_img) 
       
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        if batch_index % 100 == 99:
            last_loss = running_loss / 1000
            running_loss = 0.0
        
        batch_index += 1

    return last_loss

任何感谢帮助。如果需要更多信息,请告诉我。


以下是的模型输出范围限于一小段的原因以及如何解决该问题:

  1. 激活函数缺失: 的模型架构中缺少输出层的激活函数。目前,的模型应用三个卷积层,然后是 ReLU 激活。这会导致线性输出,并且由于的输入在 [0, 1] 范围内进行归一化,因此模型的输出也倾向于保持在有限的范围内。在最后一层添加一个 sigmoid 或 tanh 激活函数可以帮助扩展输出范围。
  2. sigmoid 会将输出压缩到 [0, 1] 范围内,这对于的情况是合适的,因为的目标输出也是灰度图像。
  3. tanh 会将输出压缩到 [-1, 1] 范围内,在训练过程中可能需要将其重新缩放回 [0, 1] 范围。

  4. 权重初始化: 如果即使添加了激活函数,的输出范围仍然很小,那么问题可能出在权重初始化上。尝试使用不同的初始化技术来查看它们是否对输出范围有影响。PyTorch 提供了许多初始化函数,例如 torch.nn.init.kaiming_normal_() torch.nn.init.xavier_uniform_() ,这些函数可以帮助设置适合 ReLU 激活的权重。

  5. 数据分布: 提到的输入图像有一些黑色、白色和其他区域。如果这些区域之间的对比度较低或数据集的整体动态范围有限,则可能会导致模型输出范围较小。请确保的数据具有足够的对比度和变化以使模型能够学习有意义的表示。还可以尝试使用直方图均衡化等技术来增强训练数据的对比度。

  6. 学习率: 虽然尝试了许多不同的学习率值,但确保使用合适的学习率仍然很重要。学习率过高或过低都会导致收敛问题。考虑使用学习率调度器来在训练过程中调整学习率。

以下是应该进行的代码更改以解决这些问题:

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
           nn.Conv2d(in_channels=1, out_channels=64, kernel_size=13, stride=1, padding=5),
           nn.ReLU(True),
           nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2),
           nn.ReLU(True),
           nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, stride=1, padding=2),
           nn.Sigmoid()  # 添加 sigmoid 激活函数
          )

    def forward(self, input_image):
        output_image = self.model(input_image)
        return output_image

# ...(其余代码保持不变)

为的模型尝试这些更改:

  1. 添加 sigmoid 或 tanh 激活函数。
  2. 尝试不同的权重初始化技术。
  3. 验证的数据是否具有足够的对比度和变化。
  4. 微调学习率并考虑使用学习率调度器。

通过解决这些问题,应该能够扩展模型的输出范围并生成更生动的灰度图像。

标签:python,deep-learning,pytorch,neural-network,computer-vision
From: 77273071

相关文章

  • 当我们创建一个在 https 服务器上获取和发送数据的 python 应用程序时,我们应该如何处
    python脚本使用Urllib3,我的服务器是在Node.js上编写的脚本。我担心(并且不清楚)证书:我是否需要将我的python应用程序上的证书另存为变量?例如我这样做了,http=urllib3.PoolManager(cert_reqs="CERT_REQUIRED",ca_certs='client-cert.pem')并且我不知道......
  • 使用 Python 进行 Web 抓取以获取数据 NoneType ERROR
    我正在努力为我的学校项目获取美元和价格。所以我决定为此使用网络抓取,但我有一个问题。当我尝试在服务器上使用我的代码时,它给我NoneType错误。它可以在googlecolab上使用,但我无法在我的电脑或服务器上使用。我该如何解决这个问题?网页抓取代码;defdolar():he......
  • Python 请求 - response.json() 未按预期工作
    我正在尝试从Python的requests模块调用API。在邮递员上,返回的响应标头中的Content-Type是application/json;charset=utf-8,响应json数据是我期望的样子。但是,在python上的API的get方法之后运行response.json()会抛出错误simplejson.errors......
  • Python 中的“样板”代码?
    Google有一个Python教程,他们将样板代码描述为“不幸的”,并提供了以下示例:#!/usr/bin/python#importmodulesusedhere--sysisaverystandardoneimportsys#Gatherourcodeinamain()functiondefmain():print'Hellothere',sys.argv[1]#Command......
  • Python 3.9.1 中的 collections.abc.Callable 是否有 bug?
    Python3.9包含PEP585并弃用typing模块中的许多类型,转而支持collections.abc中的类型,现在它们支持__class_getitem__例如Callable就是这种情况。对我来说,typing.Callable和collections.abc.Ca......
  • 列表子类的 Python 类型
    我希望能够定义列表子类的内容必须是什么。该类如下所示。classA(list):def__init__(self):list.__init__(self)我想包含键入内容,以便发生以下情况。importtypingclassA(list:typing.List[str]):#Maybesomethinglikethisdef__init__(self):......
  • Python 中类型友好的委托
    考虑以下代码示例defsum(a:int,b:int):returna+bdefwrap(*args,**kwargs):#delegatetosumreturnsum(*args,**kwargs)该代码运行良好,只是类型提示丢失了。在Python中使用*args,**kwargs来实现​​委托模式是很常见的。如果有一种方法可......
  • 使用 python 支持构建自定义 vim 二进制文件
    背景Debian11vim软件包不包含python3支持。请参阅标题为“Debian11vim中不支持python-证据”的部分下面我需要vim支持python3YouCompleteMevim插件为了构建一个新的,我将vim9.0tarball下载到v......
  • 如何在Python 3.12+中正确使用泛型来提高代码质量?
    我正在尝试使用泛型来改进FastAPI应用程序中的类型注释。我有一个抽象存储库类,在其中使用泛型:fromabcimportABC,abstractmethodfromtypingimportListclassAbstractRepository[T](ABC):@abstractmethodasyncdefadd_one(self,data:dict)->T:......
  • python中的while循环不退出
    我试图完成第一年的python商业课程作业,但我的while循环无法退出,有人能帮忙吗?commisionTable=[{"admin_fee":100,"comm_rate":0.10},{"admin_fee":125,"comm_rate":0.12},{"admin_fee":150,"comm_rate":......