首页 > 其他分享 >目标检测中目标数量不一致的解决方法

目标检测中目标数量不一致的解决方法

时间:2023-06-26 19:12:48浏览次数:42  
标签:__ tensor 检测 self 目标 collate 一致 label fn

问题描述

  在学习使用Pytorch进行目标检测任务时,会出现和分类任务有很大区别的一点。在进行分类任务时,当指定了图像的大小,那么Dataset中每张图的张量大小都是一致的。而在目标检测任务中,在一张图上可以只有一个目标,也可以有多个目标,在Dataset中张量大小不一致会报错,例如:

RuntimeError: stack expects each tensor to be equal size, but got [1, 4] at entry 0 and [2, 4] at entry 1

在学习动手学CV-Pytorch时,发现该问题通过以下方式解决:(该部分省略与改写了部分代码)

class VOCDataset(Dataset):
    # ...
    
    def collate_fn(self):
        images = [i[0] for i in self]
        boxes = [i[1] for i in self]
        labels = [i[2] for i in self]
        difficulties = [i[3] for i in self]

        images = torch.stack(images, dim=0)

        return images, boxes, labels, difficulties
    
    # ...

train_loader = DataLoader(
        VOCDataset(),
        shuffle=True,
        batch_size=32,
        collate_fn=VOCDataset.collate_fn,
        pin_memory=True
    )

  我们看到,该代码片使用了collate_fn这个参数并调用类中相关方法解决了这个问题,我们来探究下这个参数到底有什么用。

解决方法

  我们首先探究下collate_fn这个参数,根据官方文档的描述与官方Discuss的问题,这个参数可以在接收来自__getitem()__的数据后重新规整再输出,该Discuss表明其对变长数据的处理会有非常大的帮助,我们通过例子来解释下这个参数到底在做什么。该例源自这篇文章并做了些修改。

  我们先定义一个矩阵和它所对应的label

li = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
matrix = torch.tensor([li[i:i + 3] for i in range(10)])
label = torch.tensor([li[i:i + 1] for i in range(10)])
print('matrix:', matrix)
print('label:', label)


# >>> matrix: tensor([[ 0,  1,  2],
#                     [ 1,  2,  3],
#                 	  [ 2,  3,  4],
#                 	  [ 3,  4,  5],
#                 	  [ 4,  5,  6],
#                 	  [ 5,  6,  7],
#                 	  [ 6,  7,  8],
#                 	  [ 7,  8,  9],
#                 	  [ 8,  9, 10],
#                 	  [ 9, 10, 11]])
# >>> label: tensor([[0],
#               	 [1],
#               	 [2],
#              		 [3],
#               	 [4],
#                	 [5],
#               	 [6],
#               	 [7],
#               	 [8],
#               	 [9]])

  接下来我们写一个非常简单的Dataset

class LiDataset(Dataset):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def __getitem__(self, item):
        return self.param1[item], self.param2[item]

    def __len__(self):
        return len(self.param1)

    def collect_fn(self):
        p1 = [i[0] for i in self]
        p2 = [i[1] for i in self]

        return p1, p2

  我们用DataLoader装载下这个Dataset,这也是我们在分类任务中见到的最基础的写法

print('WITH OUT collate_fn:')
dataset1 = DataLoader(
    LiDataset(matrix, label),
    batch_size=3
)

for i in dataset1:
    print(i)
    
    
# >>> WITH OUT collate_fn:
# [
#     tensor([[0, 1, 2],
#             [1, 2, 3],
#             [2, 3, 4]]),
#     tensor([[0],
#             [1],
#             [2]])
# ]
# [
#     tensor([[3, 4, 5],
#             [4, 5, 6],
#             [5, 6, 7]]),
#     tensor([[3],
#             [4],
#             [5]])
# ]
# [
#     tensor([[6,  7,  8],
#             [7,  8,  9],
#             [8,  9, 10]]),
#     tensor([[6],
#             [7],
#             [8]])]
# [
#     tensor([[9, 10, 11]]),
#     tensor([[9]])
# ]

  这是我们最常见的输出,整个数据集被划分为多个batch,每个batch里有3条数据。虽然没有指定collate_fn,但其实这时是调用了官方默认的defaultcollate_fn并已经帮我们重组成我们现在所看到的样子,这时我们用lambda x: x定义这个参数来看看这个它原本到底是个啥样。

print('WITH lambda collate_fn:')
dataset2 = DataLoader(
    LiDataset(matrix, label),
    batch_size=3,
    collate_fn=lambda x: x
)

for i in dataset2:
    print(i)


# >>> WITH lambda collate_fn:
# [(tensor([0, 1, 2]), tensor([0])), (tensor([1, 2, 3]), tensor([1])), (tensor([2, 3, 4]), tensor([2]))]
# [(tensor([3, 4, 5]), tensor([3])), (tensor([4, 5, 6]), tensor([4])), (tensor([5, 6, 7]), tensor([5]))]
# [(tensor([6, 7, 8]), tensor([6])), (tensor([7, 8, 9]), tensor([7])), (tensor([ 8,  9, 10]), tensor([8]))]
# [(tensor([ 9, 10, 11]), tensor([9]))]

  这时,我们可以清楚的看到它原本的模样了。每个batch的数据通过__getitem()__传输过来一个列表,而每个列表由return的元素组成batch_size个元组。如下所示:

[(tensor([0, 1, 2]), tensor([0])), (tensor([1, 2, 3]), tensor([1])), (tensor([2, 3, 4]), tensor([2]))]
  ^                  ^              ^                  ^              ^                  ^
 (matrix             label)        (matrix             label)        (matrix             label)

  也就是说,原始的数据为上面这种形式,在经过DataLoadercollate_fn后可以重组数据的输出格式。这时候我们回到开头Dataset类LiDataset中的自定义方法collate_fn,其定义如下:

def collect_fn(self):
    p1 = [i[0] for i in self]
    p2 = [i[1] for i in self]

  该方法的使用是collate_fn=LiDataset.collect_fn,完整代码见下面。根据上面所说的,来解释下这个方法。self即为每一个batch,for i in self则遍历一个batch中的所有元组,i[0]__getitem()__中return中的第一个参数,这里即为matrix,同理i[1]为label。这里将matrix整合为一个列表,label整合为一个列表,那么一个batch则为两个列表组成的元组,长度便固定为2了。

print('WITH modified collate_fn:')
dataset3 = DataLoader(
    LiDataset(matrix, label),
    batch_size=3,
    collate_fn=LiDataset.collect_fn
)

for i in dataset3:
    print(i)
    

# >>> WITH modified collate_fn:
# ([tensor([0, 1, 2]), tensor([1, 2, 3]), tensor([2, 3, 4])], [tensor([0]), tensor([1]), tensor([2])])
# ([tensor([3, 4, 5]), tensor([4, 5, 6]), tensor([5, 6, 7])], [tensor([3]), tensor([4]), tensor([5])])
# ([tensor([6, 7, 8]), tensor([7, 8, 9]), tensor([ 8,  9, 10])], [tensor([6]), tensor([7]), tensor([8])])
# ([tensor([ 9, 10, 11])], [tensor([9])])

标签:__,tensor,检测,self,目标,collate,一致,label,fn
From: https://www.cnblogs.com/ToryRegulus/p/17506519.html

相关文章

  • 科技项目验收测试规范有哪些?靠谱第三方软件检测机构推荐
    随着科技的不断发展和进步,越来越多的科技项目被投入使用。为了保证这些科技项目的质量,需要进行验收测试。科技项目验收测试是一项非常重要的工作,其结果对项目的质量和功能正常使用有着直接的影响。本文将就科技项目验收测试规范和第三方软件检测机构的资质进行探讨。一、科......
  • 缓存与DB数据一致性问题解决的几个思路
    使用缓存必然会碰到缓存跟真实数据不一致的问题,虽然我们会在数据发生变化时通知缓存,但是这个延迟时间内必然会导致数据不一致,如何解决一般有下面几个思路:首先,当这个延迟如果在业务上时可以接受的,比如文章阅读、评论次数这样的缓存数据,这样的问题这里不考虑。 类似数据库分布式事务......
  • MATLAB车道偏离检测,车道线检测 这段程序主要是对图像进行处理和分析,用于检测车道线
    MATLAB车道偏离检测,车道线检测这段程序主要是对图像进行处理和分析,用于检测车道线并计算车辆的偏离率。下面我将逐步解释代码的功能和工作流程。首先,程序进行了一些初始化操作,定义了一些变量,并读取了一张图片。接下来,程序对图像进行了一系列处理步骤,包括图像切割、灰度化、滤波......
  • opencv 表识别 工业表智能识别 数字式表盘识别,指针式表盘刻度识别,分为表检测,表盘纠正,
    在工业表智能识别中,OpenCV被用于数字式表盘和指针式表盘的识别。这个过程可以分为几个步骤:表的检测、表盘的纠正、刻度的分割、刻度的拉直识别,以及指针和时刻的分割。首先,通过表的检测,确定表在图像中的位置。然后,对表盘进行纠正,将圆形表盘拉直成一条线,以便后续处理。接下来,进行刻度......
  • 目标字符串驼峰化处理
    功能函数的设计初衷是将目标字符串驼峰化的api:比如CSS样式特性与JavaScipt样式属性的切换  background-color与style.backgroundColorfont-weight与fontWeightfont-family与fontFamily  ~~~~~~~~~~~~~~  /**toCamelCase--将目标字符串进行驼峰化处理**@func......
  • BASE最终一致性
    BASE(BasicallyAvailable,SoftState,EventuallyConsistent)是一种分布式系统设计原则,它与传统的ACID(Atomicity,Consistency,Isolation,Durability)模型相对应。在构建大规模、高可用性的分布式系统时,BASE的设计原则被广泛采用。BASE所强调的最终一致性,是指系统中的数据最终......
  • 水质传感器和水质检测传感器有哪些
    水质传感器又称水质检测传感器、水质监测传感器,风途水质传感器是检测水质参数的仪器,包括很多种传感器,可以实时监测水体中的溶解氧、pH值、电导率、浊度、温度、总磷、总氮等多种参数测量。不同的行业对检测的要求不同,水质传感器的选择也不同。以下是常用的6种水质传感器。水质传感......
  • 课程介绍:YOLOv8实战火焰和烟雾检测 (视频教程)
    课程链接:https://edu.51cto.com/course/34090.htmlYOLOv8基于先前YOLO版本在目标检测任务上的成功,进一步提升性能和灵活性。本课程将手把手地教大家使用YOLOv8结合可变形卷积(deformableconvolution)训练火焰和烟雾数据集,完成一个多目标检测实战项目,可实时检测图像、视频、摄像......
  • 缓存一致性如何保障
    缓存在现代应用程序中被广泛使用,用于提高性能和降低对后端数据存储系统的负载。然而,使用缓存也带来了一个重要问题:缓存一致性。在分布式系统中,缓存一致性成为了一个挑战,因为我们需要确保缓存中的数据与后端数据存储系统的数据保持同步,以避免数据不一致的情况发生。CacheAsidePa......
  • 成功实现脚本检测手机号是否注册imessage的原理
    一、imessages数据检测的两种方式:1.人工筛选,将要验证的号码输出到文件中,以逗号分隔。再将文件中的号码粘贴到iMessage客户端的地址栏,iMessage客户端会自动逐个检验该号码是否为iMessage账号,检验速度视网速而定。红色表示不是iMessage账号,蓝色表示iMessage账号。2.编写苹果操作系......