首页 > 其他分享 >YOLOv10添加输出各类别训练过程指标

YOLOv10添加输出各类别训练过程指标

时间:2024-07-02 09:53:40浏览次数:25  
标签:box 输出 stats self metrics 添加 YOLOv10 save class

昨天有群友,在交流群【群号:392784757】里提到了这个需求,进行实现一下

image.png

V10 官方代码结构相较于 V8 稍微复杂一些

image.png

yolov10 是基于 v8 的代码完成开发,yolov10 进行了继承来简化代码开发

因此 V10 的代码修改 基本和 V8 这篇一致
https://blog.csdn.net/csy1021/article/details/134406419

但存在一些不同,会在下面提到

版本环境

YOLOv10 2024.07.01 版本

修改

trainer.py

1 添加 save_metrics_per_class()

在 save_metrics 函数后面,添加下面的 save_metrics_per_class 函数

def save_metrics_per_class(self, box):

    """Saves training metrics per class to a CSV file."""

    # ap ap50 p r 提示作用
    keys = ['ap', 'ap50', 'p', 'r']
    n = 4 + 1  # number of cols

    for i in box.ap_class_index:
        cur_class = self.model.names[box.ap_class_index[i]]
        save_path = self.save_dir.joinpath("result_" + cur_class + ".csv")
        vals = [box.ap[i], box.ap50[i], box.p[i], box.r[i]]
        s = '' if save_path.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header

        with open(save_path, 'a') as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')

2 validate() 修改

def validate(self):
    """
    Runs validation on test set using self.validator.

    The returned dict is expected to contain "fitness" key.
    """
    # metrics = self.validator(self)
    metrics,box = self.validator(self)
    fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
    if not self.best_fitness or self.best_fitness < fitness:
        self.best_fitness = fitness
    # return metrics, fitness
    return metrics, fitness,box

找到【这里比 v8 的判断要多】

if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
                    or final_epoch or self.stopper.possible_stop or self.stop:
                    self.metrics, self.fitness = self.validate()

修改为

if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
                    or final_epoch or self.stopper.possible_stop or self.stop:
                    # self.metrics, self.fitness = self.validate()
                    self.metrics, self.fitness,box = self.validate()

3 找到 self.save_metrics


self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
后面添加调用
self.save_metrics_per_class(box)

validator.py

找到 stats = self.get_stats()
改为 stats,box = self.get_stats()

找到 return {k: round(float(v), 5) for k, v in results.items()}
改为 return {k: round(float(v), 5) for k, v in results.items()}, box

val.py

get_stats() 【注意与 v8 不同】

def get_stats(self):
    """Returns metrics statistics and results dictionary."""
    stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
    # if len(stats) and stats["tp"].any():
    # if len(stats) and stats[0].any():
    if len(stats) :
        self.metrics.process(**stats)
    self.nt_per_class = np.bincount(
        stats["target_cls"].astype(int), minlength=self.nc
    )  # number of targets per class
    # return self.metrics.results_dict
    return self.metrics.results_dict,self.metrics.box

save_metrics_per_class() 函数 【注意与 v8 不同】

image.png 可以看到支持的指标有 all_ap (可用来计算其他ap指标),map,map50,f1,p ap,r mr ... 我在函数中使用的是 ap,ap50,p,r,需要其他的可以再添加 ==注意:添加指标,使用的是 . 而不是 ["xxxx"] 如 box.ap[i] 而不是 box['ap'][i]==
def save_metrics_per_class(self, box):

    """Saves training metrics per class to a CSV file."""

    # ap ap50 p r 提示作用
    keys = ['ap', 'ap50', 'p', 'r']
    n = 4 + 1  # number of cols

    for i in box.ap_class_index:
        cur_class = self.model.names[box.ap_class_index[i]]
        save_path = self.save_dir.joinpath("result_" + cur_class + ".csv")
        vals = [box.ap[i], box.ap50[i], box.p[i], box.r[i]]
        s = '' if save_path.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header

        with open(save_path, 'a') as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')

注意!不同点

def get_stats(self):
    """Returns metrics statistics and results dictionary."""
    stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
    # if len(stats) and stats["tp"].any(): # v10
    # if len(stats) and stats[0].any(): # v8
    if len(stats) : # 修改后
        self.metrics.process(**stats)
    self.nt_per_class = np.bincount(
        stats["target_cls"].astype(int), minlength=self.nc
    )  # number of targets per class
    # return self.metrics.results_dict
    return self.metrics.results_dict,self.metrics.box

v10
image.png
v8
image.png

image.png

如果不修改 这个判断条件

if len(stats) and stats["tp"].any(): # v10
# if len(stats) and stats[0].any(): # v8 仅作对比
if len(stats) : # 修改后

可能会出现 前几次 epoch 数据不记录的问题 【这里也可能是和我的数据集有关,我测试了几次,增加 batch-size 发现仍然 stats["tp"] 仍然全为 false 过不了,后面 epoch 会正常 】这里大家可以自行测试后决定,如果正常,就不需要改

其他

增加训练过程各类指标打印(可选,默认开启是有条件的)
val.py 找到 print_results() 函数 在
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) 后面
添加

for i, c in enumerate(self.metrics.ap_class_index):
    LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))

有问题,欢迎留言、进群讨论或私聊:【群号:392784757】

标签:box,输出,stats,self,metrics,添加,YOLOv10,save,class
From: https://www.cnblogs.com/caibucai/p/18279313

相关文章

  • YOLOv10改进 | 注意力篇 | YOLOv10引入24年最新Mamba注意力机制MLLAttention
    1. MLLAttention介绍1.1 摘要: Mamba是一种有效的状态空间模型,具有线性计算复杂度。最近,它在处理各种视觉任务的高分辨率输入方面表现出了令人印象深刻的效率。在本文中,我们揭示了强大的Mamba模型与线性注意力Transformer具有惊人的相似之处,而线性注意力Transform......
  • C++文件输入输出
    参考博文:https://blog.csdn.net/houbincarson/article/details/136327765/*文件输入输出fstream有三个文件流类:std::ifstream:用于从文件中读取数据的输入流对象。std::ofstream:用于向文件中写入数据的输出流对象。std::fstream:用于读写文件的输入输出流对象。*/#include<f......
  • STM32串口如何输出中文
    当你想在串口调试助手实现换行功能时却不行时,试一试将\n改为\r\n因为我用的是XCOM串口调试助手,就遇到了这样的问题而当你加入intfputc(intch,FILE*f)函数却实现不了printf,putchar调用时需要加入#include<stdio.h>并勾选魔术棒中的UseMicroLIBintfputc(intch,FILE*f)......
  • YOLOv10改进教程|C2f-CIB加入注意力机制
      一、导读    论文链接:https://arxiv.org/abs/2311.11587    代码链接:GitHub-CV-ZhangXin/AKConv YOLOv10训练、验证及推理教程二、C2f-CIB加入注意力机制2.1复制代码        打开ultralytics->nn->modules->block.py文件,复制SE......
  • 动态添加Timeline轨道和片段
    上图是利用代码制作的,下图是原来的样子:如下代码是动态创建各种Timeline轨道的代码:(控制角色碰撞到Cube触发以下的Timeline动画)usingCinemachine;usingUnityEngine;usingUnityEngine.Events;usingUnityEngine.Playables;usingUnityEngine.Timeline;publiccla......
  • 【苍穹外卖】P18通过前端页面添加员工,传过来的值为空
    漏掉了注解@RequestBodypublicResultsave(@RequestBodyEmployeeDTOemployeeDTO){//把漏掉的@RequestBody加上log.info("新增员工:{}",employeeDTO);employeeService.save(employeeDTO);returnResult.success();}重新启动项目......
  • win11添加开机自启动
    方法1win+R打开运行,输入shell:startup会打开一个文件夹将想要启动的程序快捷方式放进文件夹在设置里面搜索“启动”,可以看到开机启动项,确认已经打开。以上,针对不用管理员权限启动的程序,有效。方法2下面看需要管理员权限的:按Win+R,输入regedit,打开注册表编辑......
  • Python武器库 - 科研中常用的python图像操作 - 图像添加文字
    应用场景:在科研中,有时需要在生成结果中标注文字作为说明,或者添加文字在一行图片的开头作为标题(这个效果通常需要配合在一行图片的开头添加一张空(纯黑)图片,在该图片中添加文字作为标题,使用python-opencv来创建一张纯色图片的操作,详情见我的另一篇随笔https://www.cnblogs.com......
  • 用不同的方法输出时间记录器的时、分、秒,注意对象指针的使用方法
            对象有地址,存放对象的起始地址的指针变量就是指向对象的指针变量。对象中的成员也有地址,存放对象成员地址的指针变量就是指向对象成员的指针变量。        1.指向对象数据成员的指针        定义指向对象数据成员的指针变量的方法和定义指向......
  • 微服务服务添加数据源、认证授权、日志记录,安全处理
    为了增强SpringBoot后端服务的功能,我们可以添加数据库支持、认证授权、日志记录和安全处理。以下是如何集成这些功能的基本步骤。数据库集成添加依赖:在pom.xml或build.gradle中添加数据库驱动和SpringDataJPA的依赖。配置数据库:在src/main/resources/applicat......