Yolov8 源码解析(二十六)
.\yolov8\tests\test_engine.py
# 导入所需的模块和库
import sys # 系统模块
from unittest import mock # 导入 mock 模块
# 导入自定义模块和类
from tests import MODEL # 导入 tests 模块中的 MODEL 对象
from ultralytics import YOLO # 导入 ultralytics 库中的 YOLO 类
from ultralytics.cfg import get_cfg # 导入 ultralytics 库中的 get_cfg 函数
from ultralytics.engine.exporter import Exporter # 导入 ultralytics 库中的 Exporter 类
from ultralytics.models.yolo import classify, detect, segment # 导入 ultralytics 库中的 classify, detect, segment 函数
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR # 导入 ultralytics 库中的 ASSETS, DEFAULT_CFG, WEIGHTS_DIR 变量
def test_func(*args): # 定义测试函数,用于评估 YOLO 模型性能指标
"""Test function callback for evaluating YOLO model performance metrics."""
print("callback test passed") # 打印测试通过消息
def test_export():
"""Tests the model exporting function by adding a callback and asserting its execution."""
exporter = Exporter() # 创建 Exporter 对象
exporter.add_callback("on_export_start", test_func) # 添加回调函数到导出开始事件
assert test_func in exporter.callbacks["on_export_start"], "callback test failed" # 断言回调函数已成功添加
f = exporter(model=YOLO("yolov8n.yaml").model) # 导出模型
YOLO(f)(ASSETS) # 使用导出后的模型进行推理
def test_detect():
"""Test YOLO object detection training, validation, and prediction functionality."""
overrides = {"data": "coco8.yaml", "model": "yolov8n.yaml", "imgsz": 32, "epochs": 1, "save": False} # 定义参数覆盖字典
cfg = get_cfg(DEFAULT_CFG) # 获取默认配置
cfg.data = "coco8.yaml" # 设置配置数据文件
cfg.imgsz = 32 # 设置配置图像尺寸
# Trainer
trainer = detect.DetectionTrainer(overrides=overrides) # 创建检测训练器对象
trainer.add_callback("on_train_start", test_func) # 添加回调函数到训练开始事件
assert test_func in trainer.callbacks["on_train_start"], "callback test failed" # 断言回调函数已成功添加
trainer.train() # 执行训练
# Validator
val = detect.DetectionValidator(args=cfg) # 创建检测验证器对象
val.add_callback("on_val_start", test_func) # 添加回调函数到验证开始事件
assert test_func in val.callbacks["on_val_start"], "callback test failed" # 断言回调函数已成功添加
val(model=trainer.best) # 使用最佳模型进行验证
# Predictor
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) # 创建检测预测器对象
pred.add_callback("on_predict_start", test_func) # 添加回调函数到预测开始事件
assert test_func in pred.callbacks["on_predict_start"], "callback test failed" # 断言回调函数已成功添加
# 确认 sys.argv 为空没有问题
with mock.patch.object(sys, "argv", []):
result = pred(source=ASSETS, model=MODEL) # 执行预测
assert len(result), "predictor test failed" # 断言预测结果不为空
overrides["resume"] = trainer.last # 设置训练器的恢复模型
trainer = detect.DetectionTrainer(overrides=overrides) # 创建新的检测训练器对象
try:
trainer.train() # 执行训练
except Exception as e:
print(f"Expected exception caught: {e}") # 捕获并打印预期的异常
return
Exception("Resume test failed!") # 报告恢复测试失败
def test_segment():
"""Tests image segmentation training, validation, and prediction pipelines using YOLO models."""
overrides = {"data": "coco8-seg.yaml", "model": "yolov8n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False} # 定义参数覆盖字典
cfg = get_cfg(DEFAULT_CFG) # 获取默认配置
cfg.data = "coco8-seg.yaml" # 设置配置数据文件
cfg.imgsz = 32 # 设置配置图像尺寸
# YOLO(CFG_SEG).train(**overrides) # works
# Trainer
trainer = segment.SegmentationTrainer(overrides=overrides) # 创建分割训练器对象
trainer.add_callback("on_train_start", test_func) # 添加回调函数到训练开始事件
assert test_func in trainer.callbacks["on_train_start"], "callback test failed" # 断言回调函数已成功添加
trainer.train() # 执行训练
# Validator
val = segment.SegmentationValidator(args=cfg) # 创建分割验证器对象
# 添加回调函数到“on_val_start”事件,使其在val对象开始时调用test_func函数
val.add_callback("on_val_start", test_func)
# 断言确认test_func确实添加到val对象的“on_val_start”事件回调列表中
assert test_func in val.callbacks["on_val_start"], "callback test failed"
# 使用trainer.best模型对val对象进行验证,验证best.pt模型的性能
val(model=trainer.best) # validate best.pt
# 创建SegmentationPredictor对象pred,覆盖参数imgsz为[64, 64]
pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
# 添加回调函数到“on_predict_start”事件,使其在pred对象开始预测时调用test_func函数
pred.add_callback("on_predict_start", test_func)
# 断言确认test_func确实添加到pred对象的“on_predict_start”事件回调列表中
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
# 使用指定的模型进行预测,源数据为ASSETS,模型为WEIGHTS_DIR / "yolov8n-seg.pt"
result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolov8n-seg.pt")
# 断言确保结果非空,验证预测器的功能
assert len(result), "predictor test failed"
# 测试恢复功能
overrides["resume"] = trainer.last # 设置恢复参数为trainer的最后状态
trainer = segment.SegmentationTrainer(overrides=overrides) # 使用指定参数创建SegmentationTrainer对象
try:
trainer.train() # 尝试训练模型
except Exception as e:
# 捕获异常并输出异常信息
print(f"Expected exception caught: {e}")
return
# 如果发生异常未被捕获,则抛出异常信息“Resume test failed!”
Exception("Resume test failed!")
def test_classify():
"""Test image classification including training, validation, and prediction phases."""
# 定义需要覆盖的配置项
overrides = {"data": "imagenet10", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False
# 根据默认配置获取配置对象
cfg = get_cfg(DEFAULT_CFG)
# 调整配置项中的数据集为 imagenet10
cfg.data = "imagenet10"
# 调整配置项中的图片尺寸为 32
cfg.imgsz = 32
# YOLO(CFG_SEG).train(**overrides) # works
# 创建分类训练器对象,应用 overrides 中的配置项
trainer = classify.ClassificationTrainer(overrides=overrides)
# 添加在训练开始时执行的回调函数 test_func
trainer.add_callback("on_train_start", test_func)
# 断言 test_func 是否成功添加到训练器的 on_train_start 回调中
assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
# 开始训练
trainer.train()
# 创建分类验证器对象,使用 cfg 中的配置项
val = classify.ClassificationValidator(args=cfg)
# 添加在验证开始时执行的回调函数 test_func
val.add_callback("on_val_start", test_func)
# 断言 test_func 是否成功添加到验证器的 on_val_start 回调中
assert test_func in val.callbacks["on_val_start"], "callback test failed"
# 执行验证,使用训练器中的最佳模型
val(model=trainer.best)
# 创建分类预测器对象,应用 imgsz 为 [64, 64] 的配置项
pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
# 添加在预测开始时执行的回调函数 test_func
pred.add_callback("on_predict_start", test_func)
# 断言 test_func 是否成功添加到预测器的 on_predict_start 回调中
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
# 使用 ASSETS 中的数据源和训练器中的最佳模型进行预测
result = pred(source=ASSETS, model=trainer.best)
# 断言预测结果不为空,表示预测器测试通过
assert len(result), "predictor test failed"
.\yolov8\tests\test_explorer.py
# 导入必要的库和模块:PIL 图像处理库和 pytest 测试框架
import PIL
import pytest
# 从 ultralytics 包中导入 Explorer 类和 ASSETS 资源
from ultralytics import Explorer
from ultralytics.utils import ASSETS
# 使用 pytest 的标记 @pytest.mark.slow 标记此函数为慢速测试
@pytest.mark.slow
def test_similarity():
"""测试 Explorer 中相似性计算和 SQL 查询的正确性和返回长度。"""
# 创建 Explorer 对象,使用配置文件 'coco8.yaml'
exp = Explorer(data="coco8.yaml")
# 创建嵌入表格
exp.create_embeddings_table()
# 获取索引为 1 的相似项
similar = exp.get_similar(idx=1)
# 断言相似项的长度为 4
assert len(similar) == 4
# 使用图像文件 'bus.jpg' 获取相似项
similar = exp.get_similar(img=ASSETS / "bus.jpg")
# 断言相似项的长度为 4
assert len(similar) == 4
# 获取索引为 [1, 2] 的相似项,限制返回结果为 2 个
similar = exp.get_similar(idx=[1, 2], limit=2)
# 断言相似项的长度为 2
assert len(similar) == 2
# 获取相似性索引
sim_idx = exp.similarity_index()
# 断言相似性索引的长度为 4
assert len(sim_idx) == 4
# 执行 SQL 查询,查询条件为 'labels LIKE '%zebra%''
sql = exp.sql_query("WHERE labels LIKE '%zebra%'")
# 断言 SQL 查询结果的长度为 1
assert len(sql) == 1
@pytest.mark.slow
def test_det():
"""测试检测功能,并验证嵌入表格是否包含边界框。"""
# 创建 Explorer 对象,使用配置文件 'coco8.yaml' 和模型 'yolov8n.pt'
exp = Explorer(data="coco8.yaml", model="yolov8n.pt")
# 强制创建嵌入表格
exp.create_embeddings_table(force=True)
# 断言表格中的边界框列的长度大于 0
assert len(exp.table.head()["bboxes"]) > 0
# 获取索引为 [1, 2] 的相似项,限制返回结果为 10 个
similar = exp.get_similar(idx=[1, 2], limit=10)
# 断言相似项的长度大于 0
assert len(similar) > 0
# 执行绘制相似项的操作,返回值应为 PIL 图像对象
similar = exp.plot_similar(idx=[1, 2], limit=10)
# 断言返回值是 PIL 图像对象
assert isinstance(similar, PIL.Image.Image)
@pytest.mark.slow
def test_seg():
"""测试分割功能,并确保嵌入表格包含分割掩码。"""
# 创建 Explorer 对象,使用配置文件 'coco8-seg.yaml' 和模型 'yolov8n-seg.pt'
exp = Explorer(data="coco8-seg.yaml", model="yolov8n-seg.pt")
# 强制创建嵌入表格
exp.create_embeddings_table(force=True)
# 断言表格中的分割掩码列的长度大于 0
assert len(exp.table.head()["masks"]) > 0
# 获取索引为 [1, 2] 的相似项,限制返回结果为 10 个
similar = exp.get_similar(idx=[1, 2], limit=10)
# 断言相似项的长度大于 0
assert len(similar) > 0
# 执行绘制相似项的操作,返回值应为 PIL 图像对象
similar = exp.plot_similar(idx=[1, 2], limit=10)
# 断言返回值是 PIL 图像对象
assert isinstance(similar, PIL.Image.Image)
@pytest.mark.slow
def test_pose():
"""测试姿势估计功能,并验证嵌入表格是否包含关键点。"""
# 创建 Explorer 对象,使用配置文件 'coco8-pose.yaml' 和模型 'yolov8n-pose.pt'
exp = Explorer(data="coco8-pose.yaml", model="yolov8n-pose.pt")
# 强制创建嵌入表格
exp.create_embeddings_table(force=True)
# 断言表格中的关键点列的长度大于 0
assert len(exp.table.head()["keypoints"]) > 0
# 获取索引为 [1, 2] 的相似项,限制返回结果为 10 个
similar = exp.get_similar(idx=[1, 2], limit=10)
# 断言相似项的长度大于 0
assert len(similar) > 0
# 执行绘制相似项的操作,返回值应为 PIL 图像对象
similar = exp.plot_similar(idx=[1, 2], limit=10)
# 断言返回值是 PIL 图像对象
assert isinstance(similar, PIL.Image.Image)
.\yolov8\tests\test_exports.py
# 导入所需的库和模块
import shutil # 文件操作工具,用于复制、移动和删除文件和目录
import uuid # 用于生成唯一的UUID
from itertools import product # 用于生成迭代器的笛卡尔积
from pathlib import Path # 用于处理文件和目录路径的类
import pytest # 测试框架
# 导入测试所需的模块和函数
from tests import MODEL, SOURCE
from ultralytics import YOLO # 导入YOLO模型
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS # 导入配置信息
from ultralytics.utils import (
IS_RASPBERRYPI, # 检查是否在树莓派上运行
LINUX, # 检查是否在Linux系统上运行
MACOS, # 检查是否在macOS系统上运行
WINDOWS, # 检查是否在Windows系统上运行
checks, # 各种系统和Python版本的检查工具集合
)
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13 # Torch相关的工具函数和版本检查
# 测试导出 YOLO 模型到 ONNX 格式,使用不同的配置和参数进行测试
def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 ONNX 格式的文件
file = YOLO(TASK2MODEL[task]).export(
format="onnx",
imgsz=32,
dynamic=dynamic,
int8=int8,
half=half,
batch=batch,
simplify=simplify,
)
# 使用导出的模型进行推理,传入相同的源数据多次以达到批处理要求
YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference
# 清理生成的文件
Path(file).unlink() # cleanup
@pytest.mark.slow
@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
# 测试导出 YOLO 模型到 TorchScript 格式,考虑不同的配置和参数组合
def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 TorchScript 格式的文件
file = YOLO(TASK2MODEL[task]).export(
format="torchscript",
imgsz=32,
dynamic=dynamic,
int8=int8,
half=half,
batch=batch,
)
# 使用导出的模型进行推理,传入特定的源数据以达到批处理要求
YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32) # exported model inference at batch=3
# 清理生成的文件
Path(file).unlink() # cleanup
@pytest.mark.slow
# 在 macOS 上测试导出 YOLO 模型到 CoreML 格式,使用各种参数配置
@pytest.mark.skipif(not MACOS, reason="CoreML inference only supported on macOS")
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
@pytest.mark.parametrize(
"task, dynamic, int8, half, batch",
[ # 生成所有组合,但排除 int8 和 half 都为 True 的情况
(task, dynamic, int8, half, batch)
for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
if not (int8 and half) # 排除 int8 和 half 都为 True 的情况
],
)
def test_export_coreml_matrix(task, dynamic, int8, half, batch):
# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 CoreML 格式的文件
file = YOLO(TASK2MODEL[task]).export(
format="coreml",
imgsz=32,
dynamic=dynamic,
int8=int8,
half=half,
batch=batch,
)
# 使用导出的模型进行推理,传入特定的源数据以达到批处理要求
YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3
# 清理生成的文件夹
shutil.rmtree(file) # cleanup
@pytest.mark.slow
# 在 Python 版本大于等于 3.10 时,在 Linux 上测试导出 YOLO 模型到 TFLite 格式
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
@pytest.mark.parametrize(
"task, dynamic, int8, half, batch",
[ # 生成所有组合,但排除 int8 和 half 都为 True 的情况
(task, dynamic, int8, half, batch)
for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
if not (int8 and half) # 排除 int8 和 half 都为 True 的情况
],
)
# 测试导出 YOLO 模型到 TFLite 格式,考虑各种导出配置
def test_export_tflite_matrix(task, dynamic, int8, half, batch):
# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 TFLite 格式的文件
file = YOLO(TASK2MODEL[task]).export(
format="tflite",
imgsz=32,
dynamic=dynamic,
int8=int8,
half=half,
batch=batch,
)
# 使用导出的模型进行推理,传入特定的源数据以达到批处理要求
YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3
# 清理生成的文件夹
shutil.rmtree(file) # cleanup
# 使用指定任务的模型从YOLO导出模型,并以tflite格式输出到文件
file = YOLO(TASK2MODEL[task]).export(
format="tflite",
imgsz=32,
dynamic=dynamic,
int8=int8,
half=half,
batch=batch,
)
# 使用导出的模型进行推理,输入为[SOURCE]的重复项,批量大小为3,图像尺寸为32
YOLO(file)([SOURCE] * batch, imgsz=32) # 批量大小为3时导出模型的推理
# 删除导出的模型文件,进行清理工作
Path(file).unlink() # 清理
# 根据条件跳过测试,若 TORCH_1_9 为假则跳过,提示 PyTorch<=1.8 不支持 CoreML>=7.2
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
# 若在 Windows 系统上则跳过,提示 CoreML 在 Windows 上不受支持
@pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows") # RuntimeError: BlobWriter not loaded
# 若在树莓派上则跳过,提示 CoreML 在树莓派上不受支持
@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi")
# 若 Python 版本为 3.12 则跳过,提示 CoreML 不支持 Python 3.12
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
def test_export_coreml():
"""Test YOLO exports to CoreML format, optimized for macOS only."""
if MACOS:
# 在 macOS 上导出 YOLO 模型到 CoreML 格式,并优化为指定的 imgsz 大小
file = YOLO(MODEL).export(format="coreml", imgsz=32)
# 使用导出的 CoreML 模型进行预测,仅支持在 macOS 上进行,对于 nms=False 的模型
YOLO(file)(SOURCE, imgsz=32) # model prediction only supported on macOS for nms=False models
else:
# 在非 macOS 系统上导出 YOLO 模型到 CoreML 格式,使用默认的 nms=True 和指定的 imgsz 大小
YOLO(MODEL).export(format="coreml", nms=True, imgsz=32)
# 若 Python 版本小于 3.10 则跳过,提示 TFLite 导出要求 Python>=3.10
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
# 若不在 Linux 系统上则跳过,提示在 Windows 和 macOS 上 TensorFlow 安装可能会冲突
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
def test_export_tflite():
"""Test YOLO exports to TFLite format under specific OS and Python version conditions."""
# 创建 YOLO 模型对象
model = YOLO(MODEL)
# 导出 YOLO 模型到 TFLite 格式,使用指定的 imgsz 大小
file = model.export(format="tflite", imgsz=32)
# 使用导出的 TFLite 模型进行预测
YOLO(file)(SOURCE, imgsz=32)
# 直接跳过此测试,无特定原因说明
@pytest.mark.skipif(True, reason="Test disabled")
# 若不在 Linux 系统上则跳过,提示 TensorFlow 在 Windows 和 macOS 上安装可能会冲突
@pytest.mark.skipif(not LINUX, reason="TF suffers from install conflicts on Windows and macOS")
def test_export_pb():
"""Test YOLO exports to TensorFlow's Protobuf (*.pb) format."""
# 创建 YOLO 模型对象
model = YOLO(MODEL)
# 导出 YOLO 模型到 TensorFlow 的 Protobuf 格式,使用指定的 imgsz 大小
file = model.export(format="pb", imgsz=32)
# 使用导出的 Protobuf 模型进行预测
YOLO(file)(SOURCE, imgsz=32)
# 直接跳过此测试,无特定原因说明
@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirementsk conflict.")
def test_export_paddle():
"""Test YOLO exports to Paddle format, noting protobuf conflicts with ONNX."""
# 导出 YOLO 模型到 Paddle 格式,使用指定的 imgsz 大小
YOLO(MODEL).export(format="paddle", imgsz=32)
# 标记为慢速测试
@pytest.mark.slow
def test_export_ncnn():
"""Test YOLO exports to NCNN format."""
# 导出 YOLO 模型到 NCNN 格式,使用指定的 imgsz 大小
file = YOLO(MODEL).export(format="ncnn", imgsz=32)
# 使用导出的 NCNN 模型进行预测
YOLO(file)(SOURCE, imgsz=32) # exported model inference
.\yolov8\tests\test_integrations.py
# Ultralytics YOLO
标签:二十六,ultralytics,cfg,YOLO,Yolov8,源码,test,import,model
From: https://www.cnblogs.com/apachecn/p/18398107