问题描述
在学习使用Pytorch进行目标检测任务时,会出现和分类任务有很大区别的一点。在进行分类任务时,当指定了图像的大小,那么Dataset中每张图的张量大小都是一致的。而在目标检测任务中,在一张图上可以只有一个目标,也可以有多个目标,在Dataset中张量大小不一致会报错,例如:
RuntimeError: stack expects each tensor to be equal size, but got [1, 4] at entry 0 and [2, 4] at entry 1
在学习动手学CV-Pytorch时,发现该问题通过以下方式解决:(该部分省略与改写了部分代码)
class VOCDataset(Dataset):
# ...
def collate_fn(self):
images = [i[0] for i in self]
boxes = [i[1] for i in self]
labels = [i[2] for i in self]
difficulties = [i[3] for i in self]
images = torch.stack(images, dim=0)
return images, boxes, labels, difficulties
# ...
train_loader = DataLoader(
VOCDataset(),
shuffle=True,
batch_size=32,
collate_fn=VOCDataset.collate_fn,
pin_memory=True
)
我们看到,该代码片使用了collate_fn
这个参数并调用类中相关方法解决了这个问题,我们来探究下这个参数到底有什么用。
解决方法
我们首先探究下collate_fn
这个参数,根据官方文档的描述与官方Discuss的问题,这个参数可以在接收来自__getitem()__
的数据后重新规整再输出,该Discuss表明其对变长数据的处理会有非常大的帮助,我们通过例子来解释下这个参数到底在做什么。该例源自这篇文章并做了些修改。
我们先定义一个矩阵和它所对应的label
li = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
matrix = torch.tensor([li[i:i + 3] for i in range(10)])
label = torch.tensor([li[i:i + 1] for i in range(10)])
print('matrix:', matrix)
print('label:', label)
# >>> matrix: tensor([[ 0, 1, 2],
# [ 1, 2, 3],
# [ 2, 3, 4],
# [ 3, 4, 5],
# [ 4, 5, 6],
# [ 5, 6, 7],
# [ 6, 7, 8],
# [ 7, 8, 9],
# [ 8, 9, 10],
# [ 9, 10, 11]])
# >>> label: tensor([[0],
# [1],
# [2],
# [3],
# [4],
# [5],
# [6],
# [7],
# [8],
# [9]])
接下来我们写一个非常简单的Dataset
class LiDataset(Dataset):
def __init__(self, param1, param2):
self.param1 = param1
self.param2 = param2
def __getitem__(self, item):
return self.param1[item], self.param2[item]
def __len__(self):
return len(self.param1)
def collect_fn(self):
p1 = [i[0] for i in self]
p2 = [i[1] for i in self]
return p1, p2
我们用DataLoader装载下这个Dataset,这也是我们在分类任务中见到的最基础的写法
print('WITH OUT collate_fn:')
dataset1 = DataLoader(
LiDataset(matrix, label),
batch_size=3
)
for i in dataset1:
print(i)
# >>> WITH OUT collate_fn:
# [
# tensor([[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4]]),
# tensor([[0],
# [1],
# [2]])
# ]
# [
# tensor([[3, 4, 5],
# [4, 5, 6],
# [5, 6, 7]]),
# tensor([[3],
# [4],
# [5]])
# ]
# [
# tensor([[6, 7, 8],
# [7, 8, 9],
# [8, 9, 10]]),
# tensor([[6],
# [7],
# [8]])]
# [
# tensor([[9, 10, 11]]),
# tensor([[9]])
# ]
这是我们最常见的输出,整个数据集被划分为多个batch,每个batch里有3条数据。虽然没有指定collate_fn
,但其实这时是调用了官方默认的defaultcollate_fn
并已经帮我们重组成我们现在所看到的样子,这时我们用lambda x: x
定义这个参数来看看这个它原本到底是个啥样。
print('WITH lambda collate_fn:')
dataset2 = DataLoader(
LiDataset(matrix, label),
batch_size=3,
collate_fn=lambda x: x
)
for i in dataset2:
print(i)
# >>> WITH lambda collate_fn:
# [(tensor([0, 1, 2]), tensor([0])), (tensor([1, 2, 3]), tensor([1])), (tensor([2, 3, 4]), tensor([2]))]
# [(tensor([3, 4, 5]), tensor([3])), (tensor([4, 5, 6]), tensor([4])), (tensor([5, 6, 7]), tensor([5]))]
# [(tensor([6, 7, 8]), tensor([6])), (tensor([7, 8, 9]), tensor([7])), (tensor([ 8, 9, 10]), tensor([8]))]
# [(tensor([ 9, 10, 11]), tensor([9]))]
这时,我们可以清楚的看到它原本的模样了。每个batch的数据通过__getitem()__
传输过来一个列表,而每个列表由return
的元素组成batch_size个元组。如下所示:
[(tensor([0, 1, 2]), tensor([0])), (tensor([1, 2, 3]), tensor([1])), (tensor([2, 3, 4]), tensor([2]))]
^ ^ ^ ^ ^ ^
(matrix label) (matrix label) (matrix label)
也就是说,原始的数据为上面这种形式,在经过DataLoader
的collate_fn
后可以重组数据的输出格式。这时候我们回到开头Dataset类LiDataset
中的自定义方法collate_fn
,其定义如下:
def collect_fn(self):
p1 = [i[0] for i in self]
p2 = [i[1] for i in self]
该方法的使用是collate_fn=LiDataset.collect_fn
,完整代码见下面。根据上面所说的,来解释下这个方法。self
即为每一个batch,for i in self
则遍历一个batch中的所有元组,i[0]
为__getitem()__
中return中的第一个参数,这里即为matrix,同理i[1]
为label。这里将matrix整合为一个列表,label整合为一个列表,那么一个batch则为两个列表组成的元组,长度便固定为2了。
print('WITH modified collate_fn:')
dataset3 = DataLoader(
LiDataset(matrix, label),
batch_size=3,
collate_fn=LiDataset.collect_fn
)
for i in dataset3:
print(i)
# >>> WITH modified collate_fn:
# ([tensor([0, 1, 2]), tensor([1, 2, 3]), tensor([2, 3, 4])], [tensor([0]), tensor([1]), tensor([2])])
# ([tensor([3, 4, 5]), tensor([4, 5, 6]), tensor([5, 6, 7])], [tensor([3]), tensor([4]), tensor([5])])
# ([tensor([6, 7, 8]), tensor([7, 8, 9]), tensor([ 8, 9, 10])], [tensor([6]), tensor([7]), tensor([8])])
# ([tensor([ 9, 10, 11])], [tensor([9])])
标签:__,tensor,检测,self,目标,collate,一致,label,fn
From: https://www.cnblogs.com/ToryRegulus/p/17506519.html