首页 > 其他分享 >深度学习从入门到精通——VOC 2012数据读取(pytorch)

深度学习从入门到精通——VOC 2012数据读取(pytorch)

时间:2022-11-01 18:02:41浏览次数:64  
标签:xml VOC self torch boxes pytorch path data 2012


from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree


class VOC2012DataSet(Dataset):
"""读取解析PASCAL VOC2012数据集"""

def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")

# read train.txt or val.txt file
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

with open(txt_path) as read:
self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines()]

# check file
assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
for xml_path in self.xml_list:
assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

# read class_indict
try:
json_file = open('./pascal_voc_classes.json', 'r')
self.class_dict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)

self.transforms = transforms

def __len__(self):
return len(self.xml_list)

def __getitem__(self, idx):
# read xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
img_path = os.path.join(self.img_root, data["filename"])
image = Image.open(img_path)
if image.format != "JPEG":
raise ValueError("Image format not JPEG")
boxes = []
labels = []
iscrowd = []
for obj in data["object"]:
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
iscrowd.append(int(obj["difficult"]))

# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

if self.transforms is not None:
image, target = self.transforms(image, target)

return image, target

def get_height_and_width(self, idx):
# read xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
data_height = int(data["size"]["height"])
data_width = int(data["size"]["width"])
return data_height, data_width

def parse_xml_to_dict(self, xml):
"""
将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
Args:
xml: xml tree obtained by parsing XML file contents using lxml.etree

Returns:
Python dictionary holding XML contents.
"""

if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}

result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}

def coco_index(self, idx):
"""
该方法是专门为pycocotools统计标签信息准备,不对图像和标签作任何处理
由于不用去读取图片,可大幅缩减统计时间

Args:
idx: 输入需要获取图像的索引
"""
# read xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
data_height = int(data["size"]["height"])
data_width = int(data["size"]["width"])
# img_path = os.path.join(self.img_root, data["filename"])
# image = Image.open(img_path)
# if image.format != "JPEG":
# raise ValueError("Image format not JPEG")
boxes = []
labels = []
iscrowd = []
for obj in data["object"]:
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
iscrowd.append(int(obj["difficult"]))

# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

return (data_height, data_width), target

@staticmethod
def collate_fn(batch):
return tuple(zip(*batch))

if __name__ == '__main__':
data = VOC2012DataSet(r"D:/",transforms=None)
print(data[0][1]["boxes"])

深度学习从入门到精通——VOC 2012数据读取(pytorch)_深度学习


标签:xml,VOC,self,torch,boxes,pytorch,path,data,2012
From: https://blog.51cto.com/u_13859040/5814627

相关文章

  • 目标检测之FasterRcnn算法——训练自己的数据集(pytorch)
    数据集数据集目录如上,VOC数据集的格式JPEGImages目录下,放上自己的训练集和测试集Annotations下,放上自己的xml文档配置,如上。在VOCdevkit\VOC2012\ImageSets\Main下,放上自己......
  • Pytorch网络结构可视化
    Pytorch网络结构可视化1.1可视化工具Netron Netron是一种支持神经网可视化络化的工具,在实际的项目中,经过会遇到各种网络模型,需要快速去了解网络结构。如果单纯的去看......
  • pytorch 生态
    torchvision.datasets*torchvision.models*CaltechCelebACIFARCityscapesEMNISTFakeDataFashion-MNISTFlickrImageNetKinetics-400KITTIKMNISTPhotoTourPl......
  • PyTorch: 张量的变换、数学运算及线性回归
    本文已收录于Pytorch系列专栏:​​Pytorch入门与实践​​专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下......
  • 安装 pytorch
    如果电脑有GPU,就安装  pytorchGPU版本,可以加速如果没有就安装CPU版本,执行速度可能会慢首先,安装的pytorch版本与python版本有关系,对应关系如下:其次,PyTorch版本要根......
  • 使用上下文装饰器调试Pytorch的内存泄漏问题
    装饰器是python上下文管理器的特定实现。本片文章将通过一个pytorchGPU调试的示例来说明如何使用它们。虽然它可能不适用于所有情况,但我它们却是非常有用。调试内存......
  • 深度学习论文: MOAT: Alternating Mobile Convolution and Attention Brings Strong V
    深度学习论文:MOAT:AlternatingMobileConvolutionandAttentionBringsStrongVisionModels及其PyTorch实现MOAT:AlternatingMobileConvolutionandAttentionB......
  • 20201208史逸霏第六章学习笔记
    6.1~6.3信号和中断中断:中断是I/O设备发送到CPU的外部请求,将CPU从正常执行转移到中断处理。信号:信号是发送给进程的请求,将进程从正常执行转移到中断处理。中断的类型:......
  • WIN2012远程桌面授权服务器许可证问题解决方法
    WIN2012服务器报错为由于没有远程桌面授权服务器可以提供许可证,远程会话连接已断开。请跟服务器管理员联系。原因是服务器安装了远程桌面服务RemoteApp,这个是需要授权的。微......
  • 【PyTorch】 torch.flatten()与nn.Flatten()的区别
    问题torch.flatten()与nn.Flatten()都可以实现展开Tensor,那么二者的区别是什么呢?方法经过查阅相关资料,发现二者主要区别有:(1)默认的dim不同,torch.flatten()默认的dim=0,而n......