首页 > 其他分享 >使用YOLOv4训练DeepFashion2数据集详解

使用YOLOv4训练DeepFashion2数据集详解

时间:2024-12-15 21:27:39浏览次数:6  
标签:YOLOv4 训练 img -- train DeepFashion2 path 详解

文章目录

使用YOLOv4训练DeepFashion2数据集详解

一、引言

在计算机视觉领域,目标检测是一个非常重要的任务,而YOLO(You Only Look Once)系列模型因其速度快、性能好而广受欢迎。DeepFashion2数据集是一个大规模的服装数据集,包含了丰富的服装图像和标注信息,适合用来训练和测试目标检测模型。本文将详细介绍如何使用YOLOv4模型来训练DeepFashion2数据集,以达到对服装进行检测的目的。

二、准备工作

在这里插入图片描述

1、数据集和代码准备

首先,我们需要准备好DeepFashion2数据集和YOLOv4的代码。DeepFashion2数据集可以从其官方GitHub仓库获取。YOLOv4的代码我们可以使用bubbliiiing的PyTorch实现,这份实现包含了一些训练的技巧,比如Cosine scheduler learning rate,Mosaic,CutMix,label smoothing,CIoU等。

2、环境配置

确保你的Python环境已经安装了PyTorch、CUDA等必要的库,并且你的机器拥有足够的GPU资源来支持训练。

三、数据预处理

1、生成训练和验证集标签

我们需要将DeepFashion2数据集中的标注信息转换成YOLOv4可以识别的格式。可以使用以下命令生成训练集和验证集的标签文件:

python example.py --datasets COCO --img_path ./train/image/ --label train.json --convert_output_path YOLO/ --img_type ".jpg" --manipast_path train.txt --cls_list_file fashion_classes.txt
python example.py --datasets COCO --img_path ./validation/image/ --label valid.json --convert_output_path YOLO/ --img_type ".jpg" --manipast_path valid.txt --cls_list_file fashion_classes.txt

上述命令会生成train.txtvalid.txt两个文件,这两个文件包含了训练和验证集的路径和标注信息。

2、调整数据集路径

如果标签文件中的路径不是绝对路径,需要将数据集的图像按照以下结构组织,并放到项目路径下面:

- DeepFashion2
    - train
        - image
    - validation
        - image

四、模型训练

1、修改配置文件

由于我们有验证集的标注,需要修改train_with_tensorboard.py中的代码,以使用我们生成的训练集和验证集标签文件。具体修改如下:

[line 142]:
- annotation_path = '2007_train.txt'
+ train_path = 'train.txt'
+ val_path = 'valid.txt'

[line 179]:
- val_split = 0.1
- with open(annotation_path) as f:
-     lines = f.readlines()
- np.random.seed(10101)
- np.random.shuffle(lines)
- np.random.seed(None)
- num_val = int(len(lines)*val_split)
- num_train = len(lines) - num_val
+ with open(train_path) as f:
+     train_lines = f.readlines()
+ with open(val_path) as f1:
+     val_lines = f1.readlines()
+ np.random.seed(10101)
+ np.random.shuffle(train_lines)
+ np.random.shuffle(val_lines)
+ np.random.seed(None)
+ num_train = int(len(train_lines))
+ num_val = int(len(val_lines))

2、开始训练

完成上述修改后,可以按照项目的说明文档操作开始训练。训练过程中,可以监控TensorBoard来查看损失和准确率等指标。

五、使用示例

训练完成后,我们可以使用训练好的模型来进行预测。以下是使用YOLOv4模型进行预测的简单示例代码:

from models.common import DetectMultiBackend
from utils.general import non_max_suppression, scale_coords, xyxy2xywh
from utils.torch_utils import select_device, load_classifier

# Load model
device = select_device('')
model = DetectMultiBackend('yolov4/yolov4.pt', device=device, dnn=False)
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx

# Load image
img_path = 'data/images/bus.jpg'
img = cv2.imread(img_path)  # BGR
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Inference
img = letterbox(img, new_shape=640)[0]
img = img[:, :, ::-1].transpose(2, 0, 1)
img = np.ascontiguousarray(img)

img = torch.from_numpy(img).to(device)
img = img.float()
if len(img.shape) == 3:
    img = img[None]  # expand for batch dim

pred = model(img, augment=False)[0]

# Apply NMS
pred = non_max_suppression(pred, 0.25, 0.45, classes=None, agnostic=False)

# Process detections
for i, det in enumerate(pred):  # detections per image
    if len(det):
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_path.shape).round()

        for *xyxy, conf, cls in reversed(det):
            label = f'{names[int(cls)]} {conf:.2f}'
            plot_one_box(xyxy, img_path, label=label, color=colors[int(cls)], line_thickness=3)

# Save and show results
cv2.imwrite('output.jpg', img)

六、总结

通过上述步骤,我们详细介绍了如何使用YOLOv4模型来训练DeepFashion2数据集。这个过程包括了数据准备、数据预处理、模型训练和使用示例。希望这篇文章能帮助你更好地理解和应用YOLOv4模型在服装检测任务上的应用。


版权声明:本博客内容为原创,转载请保留原文链接及作者信息。

参考文章

标签:YOLOv4,训练,img,--,train,DeepFashion2,path,详解
From: https://blog.csdn.net/NiNg_1_234/article/details/144493068

相关文章

  • MyBatis详解---关联映射
    目录引入一、创建表结构1.学生表2.老师表二、查询学生对应的老师1.第一种形式连表查询 ①:设置实体类②:查询语句2.第二种形式分步查询(分段查询--支持懒加载)①:设置实体类②:查询语句三、查询教师的学生(一对多)1.第一种形式:按照结果嵌套处理 ①.设置实体类②......
  • Kubernetes Service 详解:如何轻松管理集群中的服务
    KubernetesService详解:如何轻松管理集群中的服务在Kubernetes中,Service是一个非常核心的概念。它解决了容器之间的通信问题,确保了无论容器如何启动或销毁,服务都能保持稳定的访问方式。今天,我想通过一篇简单易懂的文章,带大家一起探讨一下Kubernetes中的Service,它的作用......
  • 微信native支付对接案例详解
    微信native支付对接案例详解效果展示native支付产品介绍接入前准备开发指引API列表支付通知开发者社区整体原则就是按照官方文档一步一步来支付产品微信认证注意:只有服务号才能对接微信支付。每年都需要花300块认证费用。......
  • OJ题目详解——1.8~05:计算鞍点
    描述给定一个5*5的矩阵,每行只有一个最大值,每列只有一个最小值,寻找这个矩阵的鞍点。鞍点指的是矩阵中的一个元素,它是所在行的最大值,并且是所在列的最小值。例如:在下面的例子中(第4行第1列的元素就是鞍点,值为8)。11356912478101056911864721510112025......
  • OJ题目详解——1.8~06:图像相似度
    描述给出两幅相同大小的黑白图像(用0-1矩阵)表示,求它们的相似度。说明:若两幅图像在相同位置上的像素点颜色相同,则称它们在该位置具有相同的像素点。两幅图像的相似度定义为相同像素点数占总像素点数的百分比。输入第一行包含两个整数m和n,表示图像的行数和列数,中间用单个空格......
  • OJ题目详解——1.8~11:图像旋转
    描述输入一个n行m列的黑白图像,将它顺时针旋转90度后输出。输入第一行包含两个整数n和m,表示图像包含像素点的行数和列数。1<=n<=100,1<=m<=100。接下来n行,每行m个整数,表示图像的每个像素点灰度。相邻两个整数之间用单个空格隔开,每个元素均在0~255之间。输出m行,每行......
  • OJ题目详解——1.8~14:扫雷游戏地雷数计算
    描述扫雷游戏是一款十分经典的单机小游戏。它的精髓在于,通过已翻开格子所提示的周围格地雷数,来判断未翻开格子里是否是地雷。现在给出n行m列的雷区中的地雷分布,要求计算出每个非地雷格的周围格地雷数。注:每个格子周围格有八个:上、下、左、右、左上、右上、左下、右下。输入......
  • 提货卡系统有哪些?功能特点与适用场景详解
    提货卡系统以其便捷性逐渐成为消费者青睐的选择。提货卡并非只是一个简单的购物工具,而是连接用户与商家的桥梁。很多人认为提货卡的存在仅仅是为了方便提货,其实其中蕴含的功能特点和系统类型远不止于此。提货卡系统的多样性让人倍感期待,究竟有哪些类型的提货卡系统值得关注?有......
  • 数据结构:Win32 API详解
    目录一.Win32API的介绍二.控制台程序(Console)与COORD1..控制台程序(Console):2.控制台窗口坐标COORD:3.GetStdHandle函数:(1)语法:(2)参数:4.GetConsoleCursorInfo函数:(1)语法:(2)参数:(3)CONSOLE_CURSOR_INFO结构体:5.SetConsoleCursorInfo函数:实例:6.SetConsoleCursorPosition......
  • 使用任务队列TaskQueue和线程池ThreadPool技术实现自定义定时任务框架详解
    前言在桌面软件开发中,定时任务是一个常见的需求,比如定时清理日志、发送提醒邮件或执行数据备份等操作。在C#中有一个非常著名的定时任务处理库Hangfire,不过在我们深入了解Hangfire之前,我们可以手动开发一个定时任务案例,用以帮助我们理解Hangfire的核心原理。我们可以利用......