首页 > 编程问答 >如何为 NYU 数据集训练 Yolo 3D

如何为 NYU 数据集训练 Yolo 3D

时间:2024-07-23 05:34:34浏览次数:15  
标签:python machine-learning computer-vision yolov5

我已经在 KITTI 数据集上训练了我的 Yolo 3D 模型,现在我想在 NYU 数据集上训练它。为了在 YOLO 3D 模型中训练它,我必须对 NYU 数据集进行哪些更改?

我想知道 YOLO 3D 接受的数据集格式。

(编辑) 我使用的模型是 YOLO 3D-lightning https://github.com/ruhyadi/yolo3d-lightning

基本上我使用预训练的权重进行推理

python inference.py \
  source_dir="./data/demo/videos/2011_09_26/image_02/data" \
  detector.model_path="./weights/detector_yolov5s.pt" \
  regressor_weights="./weights/mobilenetv3-best.pt"

实际上这个模型是在 KITI 数据集中训练的但现在我想在 NYU 数据集中训练这个模型,这就是为什么我想知道我需要更改 NYU 数据集的格式,以便我可以使用它来训练 YOLO 3D。

示例以便转换kiti 到 yolo 这个脚本给出了

"""
Convert KITTI format to YOLO format.
"""

import os
import numpy as np
from glob import glob
from tqdm import tqdm
import argparse

from typing import Tuple


class KITTI2YOLO:
    def __init__(
        self,
        dataset_path: str = "../data/KITTI",
        classes: Tuple = ["car", "van", "truck", "pedestrian", "cyclist"],
        img_width: int = 1224,
        img_height: int = 370,
    ):

        self.dataset_path = dataset_path
        self.img_width = img_width
        self.img_height = img_height
        self.classes = classes
        self.ids = {self.classes[i]: i for i in range(len(self.classes))}

        # create new directory
        self.label_path = os.path.join(self.dataset_path, "labels")
        if not os.path.isdir(self.label_path):
            os.makedirs(self.label_path)
        else:
            print("[INFO] Directory already exist...")

    def convert(self):
        files = glob(os.path.join(self.dataset_path, "label_2", "*.txt"))
        for file in tqdm(files):
            with open(file, "r") as f:
                filename = os.path.join(self.label_path, file.split("/")[-1])
                dump_txt = open(filename, "w")
                for line in f:
                    parse_line = self.parse_line(line)
                    if parse_line["name"].lower() not in self.classes:
                        continue

                    xmin, ymin, xmax, ymax = parse_line["bbox_camera"]
                    xcenter = ((xmax - xmin) / 2 + xmin) / self.img_width
                    if xcenter > 1.0:
                        xcenter = 1.0
                    ycenter = ((ymax - ymin) / 2 + ymin) / self.img_height
                    if ycenter > 1.0:
                        ycenter = 1.0
                    width = (xmax - xmin) / self.img_width
                    if width > 1.0:
                        width = 1.0
                    height = (ymax - ymin) / self.img_height
                    if height > 1.0:
                        height = 1.0

                    bbox_yolo = f"{self.ids[parse_line['name'].lower()]} {xcenter:.3f} {ycenter:.3f} {width:.3f} {height:.3f}"
                    dump_txt.write(bbox_yolo + "\n")

                dump_txt.close()

    def parse_line(self, line):
        parts = line.split(" ")
        output = {
            "name": parts[0].strip(),
            "xyz_camera": (float(parts[11]), float(parts[12]), float(parts[13])),
            "wlh": (float(parts[9]), float(parts[10]), float(parts[8])),
            "yaw_camera": float(parts[14]),
            "bbox_camera": (
                float(parts[4]),
                float(parts[5]),
                float(parts[6]),
                float(parts[7]),
            ),
            "truncation": float(parts[1]),
            "occlusion": float(parts[2]),
            "alpha": float(parts[3]),
        }

        # Add score if specified
        if len(parts) > 15:
            output["score"] = float(parts[15])
        else:
            output["score"] = np.nan

        return output


if __name__ == "__main__":

    # argparser
    parser = argparse.ArgumentParser(description="KITTI to YOLO Convertion")
    parser.add_argument("--dataset_path", type=str, default="../data/KITTI")
    parser.add_argument(
        "--classes",
        type=Tuple,
        default=["car", "van", "truck", "pedestrian", "cyclist"],
    )
    parser.add_argument("--img_width", type=int, default=1224)
    parser.add_argument("--img_height", type=int, default=370)
    args = parser.parse_args()

    kitit2yolo = KITTI2YOLO(
        dataset_path=args.dataset_path,
        classes=args.classes,
        img_width=args.img_width,
        img_height=args.img_height,
    )
    kitit2yolo.convert()

类似地我想得到一个将 NYU 转换为 yolo 的脚本


好的,了解想将 NYU 数据集格式转换为 YOLO 3D 训练所需的格式,尤其是使用的是 yolo3d-lightning 实现。

以下是需要了解的有关 NYU 数据集和 YOLO 格式的信息:

  1. NYU 数据集格式: NYU 数据集通常提供 RGB 图像和深度图,以及相机姿态信息。深度信息以每个像素的距离(以米为单位)存储。标注可能因特定下载的数据集版本而异,但通常包括边界框和类别信息。

  2. YOLO 3D 格式: yolo3d-lightning 模型期望数据采用类似 KITTI 的格式,每个图像有一个文本文件。每个文本文件包含每行一个对象的以下信息:

    • 类别 ID: 类的数值索引(例如,0 代表汽车,1 代表行人)。
    • 中心 x: 归一化边界框中心 x 坐标(0 到 1)。
    • 中心 y: 归一化边界框中心 y 坐标(0 到 1)。
    • 宽度: 归一化边界框宽度(0 到 1)。
    • 高度: 归一化边界框高度(0 到 1)。
    • 深度: 到对象的距离(以米为单位)。
    • 航向角: 围绕垂直轴的对象方向(以弧度为单位)。

将 NYU 转换为 YOLO 格式的步骤:

  1. 准备数据:

    • 创建一个名为“图像”的文件夹来存储图像,以及一个名为“标签”的文件夹来存储对应的标签文本文件。
    • 确保拥有 RGB 图像、深度图、相机校准信息和标注。
  2. 生成标签文件: 对于每个图像,需要创建一个具有相同名称的文本文件(例如,00001.jpg 对应于 00001.txt)。

  3. 解析标注:

    • 从 NYU 标注文件中读取边界框、类别和深度信息。
    • 计算边界框中心、宽度和高度。
    • 从深度图中提取边界框中心的深度值。
    • 根据的 NYU 数据集的可用信息计算或获取航向角(如果需要)。
  4. 将边界框坐标标准化:

    • 将边界框中心 x、中心 y、宽度和高度除以图像的相应宽度和高度。
  5. 将数据写入标签文件:

    • 对于图像中的每个对象,将类别 ID、归一化的边界框中心 x、中心 y、宽度、高度、深度和航向角(如果可用)写入相应的文本文件。

Python 转换脚本示例:

import os
import numpy as np
from math import cos, sin

# 设置 NYU 数据集和输出目录的路径
nyu_dataset_path = '/path/to/nyu/dataset'
output_path = '/path/to/output'

# 创建输出目录(如果不存在)
os.makedirs(os.path.join(output_path, 'labels'), exist_ok=True)

# 设置图像宽度和高度
image_width = 640  # 根据的数据调整
image_height = 480

# 定义类名和 ID 的映射
class_names = {
    0: 'person',
    1: 'chair',
    # ... 添加的其他类
}

# 循环遍历图像
for image_name in os.listdir(os.path.join(nyu_dataset_path, 'images')):
    if not image_name.endswith('.jpg'):
        continue

    # 获取图像 ID
    image_id = image_name.split('.')[0]

    # 加载深度图
    depth_map = np.load(os.path.join(nyu_dataset_path, 'depth', f'{image_id}.npy'))

    # 加载标注(根据的标注格式调整)
    with open(os.path.join(nyu_dataset_path, 'labels', f'{image_id}.txt')) as f:
        labels = [line.strip().split() for line in f]

    # 创建标签文件
    label_file_path = os.path.join(output_path, 'labels', f'{image_id}.txt')
    with open(label_file_path, 'w') as label_file:
        # 处理每个对象
        for label in labels:
            class_id = int(label[0])
            xmin = float(label[1])
            ymin = float(label[2])
            xmax = float(label[3])
            ymax = float(label[4])

            # 计算边界框中心、宽度和高度
            bbox_center_x = ((xmax - xmin) / 2 + xmin) / image_width
            bbox_center_y = ((ymax - ymin) / 2 + ymin) / image_height
            bbox_width = (xmax - xmin) / image_width
            bbox_height = (ymax - ymin) / image_height

            # 获取边界框中心的深度
            depth = depth_map[int(bbox_center_y * image_height), int(bbox_center_x * image_width)]

            # 计算航向角(如果需要,请根据的数据调整)
            # yaw = ...

            # 将数据写入标签文件
            label_file.write(f'{class_id} {bbox_center_x:.6f} {bbox_center_y:.6f} {bbox_width:.6f} {bbox_height:.6f} {depth:.6f} {yaw:.6f}\n')

修改脚本: * 用的 NYU 数据集的实际路径替换 nyu_dataset_path output_path 。 * 根据的数据格式调整 image_width image_height 。 * 修改 class_names 以匹配的数据集中的类。 * 根据 NYU 标注文件结构调整加载标注的部分。 * 如果需要,请包含航向角计算。

此脚本将读取的 NYU 数据集并为 yolo3d-lightning 训练生成 YOLO 格式的标签文件。

请记住: 这个脚本提供了一个基本框架,可能需要根据的 NYU 数据集的特定格式和 yolo3d-lightning 模型的要求对其进行调整。

标签:python,machine-learning,computer-vision,yolov5
From: 78779752

相关文章

  • 如何使用 Python Flask 将新的咖啡馆(元素)添加到数据库(SQLite)?
    这是我的代码:@app.route("/add",methods=["POST"])defpost_new_cafe():new_cafe=Cafe(name=request.form.get("name"),map_url=request.form.get("map_url"),img_url=request.form.get("img......
  • 使用 tkinter 为 python 创建 GUI 时如何解决语法错误?
    我是一名Python初学者,决定使用tkinter制作一个小型GUI,该GUI接受用户(潜在餐馆)的3个输入,然后单击按钮后随机输出其中一家餐馆。我不断收到语法错误,并认为它与我的buttonfunc或调用它的命令有关。此代码尚未包含在GUI上输出餐厅的任何位置。任何帮助将不胜感激#Pyth......
  • 在 python 中打开 gnome 终端立即显示为僵尸
    作为背景,我正在编写一个脚本来训练多个pytorch模型。我有一个训练脚本,我希望能够在gnome终端中作为子进程运行。这样做的主要原因是我可以随时关注训练进度。如果我可能有多个GPU,我想在单独的窗口中多次运行我的训练脚本。为了实现这一点,我一直在使用popen。以下代码用于打......
  • python threading.Condition 的意外行为
    我正在尝试同步多个线程。我期望使用threading.Condition和threading.Barrier时的脚本输出大致相同,但事实并非如此。请解释一下为什么会发生这种情况。一般来说,我需要线程在一个无限循环中执行工作(一些IO操作),但是每个循环都是以主线程的权限开始的,而权限是仅在......
  • Python - 逆透视数据框
    我有一个按日期时间索引的表,每个日期时间都有多个层(中心和交货间隔):日期时间中心交货间隔结算点价格2024-01-0101:00:00休斯顿中心1......
  • 试图理解这个错误:致命的Python错误:PyEval_RestoreThread:该函数必须在持有GIL的情况下
    我有一个小型tkinter应用程序,我一直在其中实现最小的“拖放”,主要作为学习实验。我真正关心的是删除文件的文件路径。一切实际上都工作正常,直到我尝试在拖放后打包标签小部件。下面的最小工作示例。有问题的行会用注释指出。我通常不会在调试方面遇到太多麻烦,但我只是不知......
  • 如何使代码格式再次适用于 Python(Mac 上的 Visual Studio Code)?
    在Mac上,Option+Shift+F现在会显示“没有安装用于‘python’文件的格式化程序”。消息框:我尝试安装这个插件,但没有看到这种情况的变化:我已经为Python安装了这两个插件:但是正如@starball提到的,它可能已经减少了支持现在。......
  • 无法在 python 中安装 pip install expliot - bluepy 的 Building Wheel (pyproject.t
    在此处输入图像描述当我尝试在Windows计算机中通过cmd安装pipinstallexpliot包时,我收到2个错误名称×Buildingwheelforbluepy(pyproject.toml)didnotrunsuccessfully.│exitcode:1**AND**opt=self.warn_dash_deprecation......
  • python 用单斜杠-反斜杠替换url字符串中的双斜杠
    我的URL包含错误的双斜杠(“//”),我需要将其转换为单斜杠。不用说,我想保持“https:”后面的双斜杠不变。可以在字符串中进行此更改的最短Python代码是什么?我一直在尝试使用re.sub,带有冒号否定的正则表达式(即,[^:](//)),但它想要替换整个匹配项(包括前面......
  • 如何使用 Selenium Python 搜索 Excel 文件中的文本
    我有一些数据在Excel文件中。我想要转到Excel文件,然后搜索文本(取自网站表),然后获取该行的所有数据,这些数据将用于在浏览器中填充表格。示例:我希望selenium搜索ST0003然后获取名称,该学生ID的父亲姓名,以便我可以在大学网站中填写此信息。我想我会从网站......