首页 > 编程语言 >YOLOP 多任务算法详解

YOLOP 多任务算法详解

时间:2024-11-22 12:41:08浏览次数:1  
标签:检测 py 损失 epoch 详解 YOLOP 格式 多任务

YOLOP 是华中科技大学研究团队在 2021 年开源的研究成果,其将目标检测/可行驶区域分割和车道线检测三大视觉任务同时放在一起处理,并且在 Jetson TX2 开发板子上能够达到 23FPS。

论文标题:YOLOP You Only Look Once for Panoptic Driving Perception

论文地址:https://arxiv.org/abs/2108.11250

官方代码:https://github.com/hustvl/YOLOP

01 网络结构

YOLOP 的核心亮点就是多任务学习,而各部分都是拿其它领域的成果进行缝合,其网络结构如下图所示:

三个子任务共用一个Backbone和Neck,然后分出来三个头来执行不同的任务。

1.1 Encoder

根据论文所述,整个网络可以分成一个 Encoder 和 3 个 Decoder。

Encoder 包含 Backbone 和 Neck,Backbone 照搬了 YOLOv4 所采用的 CSPDarknet,Neck 也和 YOLOv4 类似,使用了空间金字塔(SPP)模块和特征金字塔网络(FPN)模块。

1.2 Decoders

Decoders 即三个任务头:

  1. Detect Head

    目标检测头使用了 Path Aggregation Network (PAN)结构,这个结构可以将多个尺度特征图的特征图进行融合,其实还是 YOLOv4 那一套。

  2. Drivable Area Segment Head & Lane Line Segment Head

    可行驶区域分割头和车道线检测头都属于语义分割任务,因此 YOLOP 使用了相同的网络结构,经过三次上采样,将输出特征图恢复为(W, H, 2)的大小,再进行具体任务的处理。

1.3 Loss Function

损失函数包括三部分,即三个任务的损失。

  1. 目标检测损失

    目标检测是直接照搬 YOLOv4 的,因此和 YOLOv4 采用的损失一样,经典的边界框损失、目标损失和类别损失,各自加了个权重。

  2. 语义分割损失

  1. 总体损失,总体损失为三部分损失之和:

02 代码结构

03 训练--tools/train.py

3.1 设置 DDP 参数

pytorch 中 DDP 使用:

(1)参数加载;

(2)模型转换成 DDP 模型;

(3)训练数据 sampler,来使得各个进程上的数据各不相同;

(4)分布式模型的保存。

3.2 读取网络结构

models/YOLOP.py

3.3 定义损失函数及优化器

core/loss.py    utils/utils.py

3.4 网络结构划分

用于单任务训练固定其他网络部分层。

3.5 初始化学习率

后续在 train()中 warmup 会调整学习率。

首先定义一个优化器,定义好优化器以后,就可以给这个优化器绑定一个指数衰减学习率控制器。

(1) torch.optim.lr_scheduler.LambdaLR  
语法:class torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

参数:

      optimizer (Optimizer):要更改学习率的优化器,sgd或adam;
      lr_lambda(function or list):根据epoch计算λ \lambdaλ的函数;或者是一个list的这样的   function,分别计算各个parameter groups的学习率更新用到的λ \lambdaλ;
      last_epoch (int):最后一个epoch的index,如果是训练了很多个epoch后中断了,继续训练,这个值就等于加载的模型的epoch。默认为-1表示从头开始训练,即从epoch=1开始。

04 dataset/bdd.py 文件

4.1 数据读取

1.该文件继承 AutoDriveDataset.py。

2、按比例缩放操作:letterbox()图像增加灰边

3、数据增强操作

4.2 数据增强

utils/utils.py 文件:

  • random_perspective()放射变换增强
  • augment_hsv()颜色 HSV 通道增强
  • cutout()

05 models/YOLOP.py 文件

YOLOP 包括三个检测任务,目标检测+可行驶区域检测+车道线检测。

06 损失函数

loss.py postprocess.py

build_targets 思想:

build_targets 主要为了拿到所有 targets(扩充了周围 grids)对应的类别,框,batch 中图片数索引和 anchor 索引,以及具体的 anchors。

每个 gt 按照正样本选取策略,生成相应的 5 个框,再根据与默认 anchor 匹配,计算宽高的比例值,根据阈值过滤不相符的框,得到最终正样本。

#[b, a, gj, gi]为shape=54的向量,pi为[4,3,48,80,6]维矩阵,从pi中按照b, a, gj, gi的索引挑出想要的目标,最终为[54,6]维ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets 。b, a, gj, gi为索引值,在pi中挑

6.1 目标检测损失

predictions[0] 目标检测分支[[4,3,48,80,6],[4,3,24,40,6],[4,3,12,20,6]]。

targets[0] 目标检测标签 [32,6],格式为[batch_num,class,x1,y1,x2,y2]。根据 build_targets 在每个检测层生成 相 应的正样本 tbox[]。

将每层的预测结果 tensor pi 根据正样本格式得到 ps = pi[b, a, gj, gi]。

计算每个检测层预测与正样本之间的 ciou 坐标损失。

iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
 lbox += (1.0 - iou).mean()  # iou loss 坐标损失

obj 损失:

cls 类别损失:

6.2 可行驶区域损失

6.3 车道线损失

07 网络模型输出格式形式

7.1 网络模型检测输出格式

det_out:障碍物检测输出格式:[25200,6] 其中 6 表示[x1,y1,x2,y2,conf,cls],25200 :(80x80+40x40+20x20)x3。

7.2 网络模型车道线输出格式

lane_line_seg : 车道线分割输出格式:1,2,640,640。

7.3 网络模型可行驶区域输出格式

drive_area_seg : 可行驶区域分割输出格式:1,2,640,640。

08 前视停车场数据集检测效果图

标签:检测,py,损失,epoch,详解,YOLOP,格式,多任务
From: https://www.cnblogs.com/horizondeveloper/p/18562547

相关文章

  • Java教程:SE进阶【十万字详解】(中)
    ✨博客主页:https://blog.csdn.net/m0_63815035?type=blog......
  • 【C++】右值引用与移动语义详解:如何利用万能引用实现完美转发
    C++语法相关知识点可以通过点击以下链接进行学习一起加油!命名空间缺省参数与函数重载C++相关特性类和对象-上篇类和对象-中篇类和对象-下篇日期类C/C++内存管理模板初阶String使用String模拟实现Vector使用及其模拟实现List使用及其模拟实现容器适配器Stack与QueuePriority......
  • 做大模型备案的企业看过来,详解大模型备案重难点【评估测试题+备案源文件】
            通过对大模型备案所涉及的测试题、安全评估报告以及其他相关材料的深入分析,本文详细探讨了大模型备案过程中的关键点和难点问题。我们不仅审视了备案流程中的各个环节,还对可能遇到的挑战进行了全面的讨论,以确保大模型的安全性和合规性。文章目录(一)适用主体(......
  • 【机器学习】解锁AI密码:神经网络算法详解与前沿探索
    ......
  • git使用详解
     一、git介绍1、git简介Git是一个开源的分布式版本控制系统(最先进的,没有之一),用于敏捷高效地处理任何或小或大的项目。Git是LinusTorvalds为了帮助管理Linux内核开发而开发的一个开放源码的版本控制软件。Git与常用的版本控制工具CVS,Subversion(SVN)等不同,它采......
  • Java反序列化-Commons Collections4利用链详解
    前言CC4的构造方式与CC3相似,主要的区别在于触发反序列化的方式不同。CC4通过使用PriorityQueue(优先队列)来触发反序列化,而恶意代码加载方式依旧沿用了CC3。exp:TemplatesImpltemplates=newTemplatesImpl();Classc=templates.getClass();FieldnameField=c......
  • STM32相关知识——DMA的基本概念与工作原理详解
    STM32相关知识——DMA的基本概念与工作原理详解目录什么是DMADMA的作用DMA与CPU的区别DMA的工作原理DMA控制器数据传输流程DMA传输模式优先级和通道管理STM32中DMA的应用外设与内存之间的数据传输内存与内存的数据传输示例应用场景数学公式数据传输速率计算总线带宽......
  • STM32定时器知识——看门狗详解
    STM32定时器知识——看门狗详解目录引言STM32看门狗概述看门狗的工作原理看门狗的主要组成4.1独立看门狗(IWDG)4.2窗口看门狗(WWDG)看门狗的主要参数5.1时钟源5.2预分频器5.3重载值看门狗的配置步骤6.1配置独立看门狗(IWDG)6.2配置窗口看门狗(WWDG)看门狗的数学公式......
  • 从0开始学习Linux——Shell编程详解【04】
     期目录:从0开始学习Linux——简介&安装从0开始学习Linux——搭建属于自己的Linux虚拟机从0开始学习Linux——文本编辑器从0开始学习Linux——Yum工具从0开始学习Linux——远程连接工具从0开始学习Linux——文件目录从0开始学习Linux——网络配置从0开始学习Linux——......
  • Day35--static关键字详解
    Day35--static关键字详解示例:packagecom.liu.oop.demo07;//staticpublicclassStudent{privatestaticintage;//静态的变量privatedoublescore;//非静态的变量publicstaticvoidmain(String[]args){Students1=newS......