首页 > 其他分享 >ByteTrak训练自定义训练集

ByteTrak训练自定义训练集

时间:2024-09-12 22:25:23浏览次数:19  
标签:训练 自定义 mot train video ByteTrak path os dir

ByteTrack目标追踪训练主要参考的博文是https://blog.csdn.net/Ddddd4431/article/details/126910083

但是这位博主的数据集准备跟我的还有点不一样,他用的是labelimg标注,我用的是Darklabel对视频直接进行标注。而ByteTrak的训练格式是COCO数据集格式。而Darklabel对视频标注生成的是MOT文件。如何使用Darklabel标注视频数据集可下面这篇博文https://blog.csdn.net/qq_61033357/article/details/136331771

1、数据集转换

前面介绍了我是直接对视频进行标注,由于我使用的Darklabel版本最多只能标注100个目标,因此我将我的视频进行裁剪,然后进行标注。标注好的视频会生成对应的MOT文件,并以csv的格式保存你的标注数据。一般MOT文件所包含的信息是[fn, id, x1, y1, w, h, c=-1, c=-1, c=-1, c=-1, cname]。

转换时,先按照一定的比例将标注好的视频和MOT文件分到train和val两个文件夹中

下面是用通义千问写的分类代码:

import os
import random
import shutil

# 定义输入和输出目录
input_dir = ''#定义你自己的路径
output_base_dir = ''#定义你自己的路径

# 定义输出目录
train_dir = os.path.join(output_base_dir, 'train')
val_dir = os.path.join(output_base_dir, 'val')

# 创建输出目录
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# 定义训练集和验证集的比例
train_ratio = 0.8

# 获取MOT文件和视频文件的列表
mot_files = [f for f in os.listdir(input_dir) if f.startswith('output_') and f.endswith('.csv')]
video_files = [f for f in os.listdir(input_dir) if f.startswith('output_') and f.endswith('.mp4')]

# 确保MOT文件和视频文件数量相同
assert len(mot_files) == len(video_files), "数量不匹配"

# 将MOT文件和视频文件配对
file_pairs = list(zip(mot_files, video_files))

# 混洗文件配对列表以随机分配
random.shuffle(file_pairs)

# 分割文件列表
split_index = int(len(file_pairs) * train_ratio)
train_pairs = file_pairs[:split_index]
val_pairs = file_pairs[split_index:]


# 定义函数来复制文件
def copy_file_pairs(file_pairs, dest_dir):
    for mot_file, video_file in file_pairs:
        # 复制MOT文件
        src_mot_path = os.path.join(input_dir, mot_file)
        dest_mot_path = os.path.join(dest_dir, mot_file)
        shutil.copy(src_mot_path, dest_mot_path)

        # 复制视频文件
        src_video_path = os.path.join(input_dir, video_file)
        dest_video_path = os.path.join(dest_dir, video_file)
        shutil.copy(src_video_path, dest_video_path)


# 复制训练集文件
copy_file_pairs(train_pairs, train_dir)

# 复制验证集文件
copy_file_pairs(val_pairs, val_dir)

print("数据集划分完成!")

分好类后转换为coco数据集训练格式,下面也是用通义千问写的转化代码:

import os
import cv2
import pandas as pd
import json
from sklearn.model_selection import train_test_split

# 定义数据集类别字典
category_dict = {'your class': 1, 'your class': 2,
                  ...}

def read_mot_file(mot_path):
    try:
        # 读取CSV文件并指定列名
        column_names = ['fn', 'id', 'x1', 'y1', 'w', 'h', 'c1', 'c2', 'c3', 'c4', 'cname']
        df = pd.read_csv(mot_path, header=None, names=column_names)

        # 检查是否有缺失值
        if df.isnull().values.any():
            print(f"MOT file {mot_path} contains missing values.")
            return None

        # 返回DataFrame
        return df
    except Exception as e:
        print(f"Failed to read MOT file {mot_path}: {e}")
        return None

def extract_and_save_frames(df, video_path, output_dir):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Failed to open video file: {video_path}")
        return None, None

    # 创建输出目录
    os.makedirs(output_dir, exist_ok=True)

    frame_numbers = df['fn'].unique()
    annotations = []
    image_info = []
    image_id = 0  # 图像ID计数器
    annotation_id = 0  # 注释ID计数器

    video_name = os.path.splitext(os.path.basename(video_path))[0]

    for fn in frame_numbers:
        cap.set(cv2.CAP_PROP_POS_FRAMES, fn - 1)  # OpenCV的帧索引从0开始
        ret, frame = cap.read()
        if not ret:
            print(f"Failed to read frame {fn} from {video_path}")
            continue

        # 构建输出文件名
        output_path = os.path.join(output_dir, f"{video_name}_{fn:06d}.jpg")
        cv2.imwrite(output_path, frame)

        # 记录图像信息
        image_info.append({
            "id": image_id,
            "file_name": f"{video_name}_{fn:06d}.jpg",
            "width": frame.shape[1],
            "height": frame.shape[0]
        })

        # 获取该帧的所有标注信息
        frame_df = df[df['fn'] == fn]

        # 添加标注信息
        for index, row in frame_df.iterrows():
            annotations.append({
                "id": annotation_id,
                "image_id": image_id,
                "category_id": category_dict[row['cname']],
                "bbox": [int(row['x1']), int(row['y1']), int(row['w']), int(row['h'])],
                "area": int(row['w']) * int(row['h']),
                "iscrowd": 0
            })
            annotation_id += 1

        image_id += 1

    cap.release()

    return annotations, image_info

def generate_coco_annotations(annotations, image_info, output_path):
    coco_data = {
        "images": image_info,
        "annotations": annotations,
        "categories": [{"id": v, "name": k} for k, v in category_dict.items()]
    }

    # 写入JSON文件
    with open(output_path, 'w') as f:
        json.dump(coco_data, f, indent=4)

def process_videos(train_dir, val_dir, output_dir):
    # 创建输出目录
    train_output_dir = os.path.join(output_dir, 'train')
    val_output_dir = os.path.join(output_dir, 'val')

    os.makedirs(train_output_dir, exist_ok=True)
    os.makedirs(val_output_dir, exist_ok=True)

    # 创建annotations目录
    annotations_dir = os.path.join(output_dir, 'annotations')
    os.makedirs(annotations_dir, exist_ok=True)

    # 获取所有MOT文件和视频文件的路径
    train_mot_files = [os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.endswith('.csv')]
    train_video_files = [os.path.join(train_dir, f.replace('.csv', '.mp4')) for f in os.listdir(train_dir) if
                         f.endswith('.csv')]

    val_mot_files = [os.path.join(val_dir, f) for f in os.listdir(val_dir) if f.endswith('.csv')]
    val_video_files = [os.path.join(val_dir, f.replace('.csv', '.mp4')) for f in os.listdir(val_dir) if
                       f.endswith('.csv')]

    # 初始化总的图像信息和标注信息列表
    all_train_image_info = []
    all_train_annotations = []
    all_val_image_info = []
    all_val_annotations = []

    # 处理训练集
    for mot_path, video_path in zip(train_mot_files, train_video_files):
        mot_filename = os.path.basename(mot_path)
        video_filename = os.path.basename(video_path)

        if mot_filename.split('.')[0] != video_filename.split('.')[0]:
            print(f"Filename mismatch between MOT file {mot_path} and video file {video_path}.")
            continue

        mot_df = read_mot_file(mot_path)
        if mot_df is not None:
            annotations, image_info = extract_and_save_frames(mot_df, video_path, train_output_dir)
            if annotations and image_info:
                all_train_annotations.extend(annotations)
                all_train_image_info.extend(image_info)

    # 处理验证集
    for mot_path, video_path in zip(val_mot_files, val_video_files):
        mot_filename = os.path.basename(mot_path)
        video_filename = os.path.basename(video_path)

        if mot_filename.split('.')[0] != video_filename.split('.')[0]:
            print(f"Filename mismatch between MOT file {mot_path} and video file {video_path}.")
            continue

        mot_df = read_mot_file(mot_path)
        if mot_df is not None:
            annotations, image_info = extract_and_save_frames(mot_df, video_path, val_output_dir)
            if annotations and image_info:
                all_val_annotations.extend(annotations)
                all_val_image_info.extend(image_info)

    # 生成总的标注文件
    if all_train_image_info and all_train_annotations:
        json_output_path = os.path.join(annotations_dir, 'train.json')
        generate_coco_annotations(all_train_annotations, all_train_image_info, json_output_path)

    if all_val_image_info and all_val_annotations:
        json_output_path = os.path.join(annotations_dir, 'val.json')
        generate_coco_annotations(all_val_annotations, all_val_image_info, json_output_path)

# 示例调用
train_dir = ''#你自己的位置
val_dir = ''#你自己的位置
output_dir = ''#你自己的位置

process_videos(train_dir, val_dir, output_dir)

 如果最后你运行出来的结果是下面这样的,那基本上可以训练了~

coco数据集的格式是,annotations文件只保存train.json和val.json,这两个文件记录了你所标注视频的每一帧的所有信息,因为我没有找到直接训练视频的,所找到的信息都是将视频转化成视频帧然后进行训练的,所以上述代码可以直接将视频按照帧号裁剪并将标注信息输入到annotations文件内夹下;train和val两个文件夹保存的是裁剪出来的视频帧!是jpg!

2、训练

打开终端,打开你新建的环境,输入训练代码:

python D:/bytetrack/ByteTrack-main/tools/train.py -f D:/bytetrack/ByteTrack-main/exps/example/mot/tree_yolox_x_ch.py -d 0 -b 3 --fp16 -o -c D:/bytetrack/ByteTrack-main/pretrained/yolox_m.pth

其中,看过文章置顶博文就会明白,mot文件夹下的tree_yolox_x_ch.py文件是你自己的配置文件,我训练的是树,所以是tree开头,具体命名格式根据自己喜好来。

训练截图如下:

3、训练结果保存

训练结束后会保存到一个名为YOLOX_outputs的文件夹下面

txt文件是训练日志,三个压缩文件包类似于yolo训练生成的权重。

4、检测训练结果

检测代码

python D:/bytetrack/ByteTrack-main/tools/demo_track.py video -f D:/bytetrack/ByteTrack-main/exps/example/mot/tree_yolox_x_ch.py --path D:/bytetrack/ByteTrack-main/datasets/xxx.MP4 -c D:/bytetrack/ByteTrack-main/YOLOX_outputs/tree_yolox_x_ch/last_epoch_ckpt.pth.tar --fp16 --fuse --save_result

检测截图

会将输出结果保存到

5、总结

我也是刚开始玩目标追踪不久,也不是专业码农出身哈哈哈哈哈~写这篇博文主要是发现相关训练博文太少了,一是记录自己,二是给别人一个参考~有问题也欢迎讨论哦~

标签:训练,自定义,mot,train,video,ByteTrak,path,os,dir
From: https://blog.csdn.net/2201_75281851/article/details/142177568

相关文章

  • 代码随想录算法训练营,9月12日 | 513.找树左下角的值,112. 路径总和,106.从中序与后序遍
    513.找树左下角的值题目链接:513.找树左下角的值文档讲解︰代码随想录(programmercarl.com)视频讲解︰找树左下角的值日期:2024-09-12想法:1.迭代:用层序遍历,遍历每层时记录下第一个节点的值,到最后一层就是要求的值;2.递归:根据最大的深度来找目标值。Java代码如下://迭代classSolut......
  • 旋转按钮—C#自定义控件1
    C#自定义控件—旋转按钮 C#用户控件之旋转按钮按钮功能:手自动旋转,标签文本显示、点击二次弹框确认(源码在最后边);【制作方法】找到控件的中心坐标,画背景外环、内圆;再绘制矩形开关,进行角度旋转即可获得;【关键节点】No.1获取中心坐标,思考要绘制图形的相对坐标、宽度......
  • 【人脸检测】SCRFD:训练数据采样和计算分配策略结合的高效人脸检测方法
    SampleandComputationRedistributionforEfficientFaceDetection论文链接:http://arxiv.org/abs/2105.04714代码链接:https://github.com/deepinsight/insightface/tree/master/detection/scrfd一、摘要 文中指出训练数据采样和计算分配策略是实现高效准确人脸检......
  • C#中设置自定义控件工具箱图标
    在设计自定义控件时,系统默认生成的图标比较单一且难看,如何为控件设计自己的图标呢,这里给出了一种基于ToolBoxBitmap 属性设置自定义控件工具箱图标的方法。1、首先将图标文件名改为自定义控件名,如自定义控件类为: public partial class UserDefindControl: UserControl {......
  • 图论篇--代码随想录算法训练营第五十七天打卡| 最小生成树问题
    题目链接:53.寻宝(第七期模拟笔试)题目描述:在世界的某个区域,有一些分散的神秘岛屿,每个岛屿上都有一种珍稀的资源或者宝藏。国王打算在这些岛屿上建公路,方便运输。不同岛屿之间,路途距离不同,国王希望你可以规划建公路的方案,如何可以以最短的总公路距离将所有岛屿联通起来(注意:这......
  • Yolo第Y2周:如何正确解读YOLO算法训练结果的各项指标
    目录Yolo第Y2周:如何正确解读YOLO算法训练结果的各项指标weights文件夹:最终的仙丹results.png:训练总图要略loss系列:打明牌的能力box_loss边界框损失:衡量画框cls_loss分类损失:判断框里的物体dfl_loss分布式焦点损失:精益求精验证集:学得好,不一定考得好精度和召回率:又准又全的考量r......