首页 > 其他分享 >pytorch不定长数据的dataloader读取

pytorch不定长数据的dataloader读取

时间:2023-02-27 18:12:37浏览次数:70  
标签:读取 dataloader batch padding pytorch collate fn bbox

  参考资料:

  https://pytorch.org/docs/stable/data.html#dataloader-collate-fn

  https://blog.csdn.net/anshiquanshu/article/details/112868740

  在使用Pytorch深度学习框架的时候,一定绕不开的就是dataset和dataloader,后者依赖于前者,并给出了高效加载数据的解决方案(多线程,batch训练等)。

  以RGB图片为例,dataset出来的数据形状是(3, H, W),而dataloader出来的数据形状是(batch_size, 3, H, W)。很明显,多了一维即batch维度。这显然是dataloader将数据给“叠”了起来。事实上,dataloader是有一个参数为collate_fn的,它的默认值是None,即当你在使用dataloader并不指定collate_fn的时候,实际上pytorch调用了默认的collate_fn函数,将数据“叠”起来之后再交给你。

  然而,当你的数据是不定长的数据的时候,它就没有办法成功把数据叠起来,比如我就遇到了如下报错:

  RuntimeError: stack expects each tensor to be equal size, but got [2, 4] at entry 0 and [5, 4] at entry 1

  一个数据长度为2,一个数据长度为5,显然无法直接stack?此时在面对不定长数据的时候就需要自定义collate_fn进行填充了。譬如,pytorch文档上有这么一段话:

  A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch.

  那么,如何自定义一个collate_fn?这个collate_fn的输入和输出又是什么?我们来看一下这个例子:

def padding_collate_fn(data_batch):
    batch_bbox_list = [item['bbox'] for item in data_batch]
    batch_label_list = [item['label'] for item in data_batch]
    batch_filename_list = [item['filename'] for item in data_batch]
    
    padding_bbox = pad_sequence(batch_bbox_list, batch_first=True, padding_value=0)
    padding_label = pad_sequence(batch_bbox_list, batch_first=True, padding_value=5)
    
    result = dict()
    result["bbox"] = padding_bbox
    result["label"] = padding_label
    result["filename"] = batch_filename_list
    
    return result

  首先我原始的dataset输出是一个字典,上述代码就是把字典中的值取出来再叠起来,最后放到大字典中返回。其中pad_sequence这个函数在torch.nn.utils.rnn这个包里,很好用。

  实际上,batch就是你的dataset[index] ~ dataset[index + batch_size] 构成的列表,知道这一点后问题就迎刃而解了。

标签:读取,dataloader,batch,padding,pytorch,collate,fn,bbox
From: https://www.cnblogs.com/chester-cs/p/17160501.html

相关文章