深度学习目标检测框架训练使用YOLOv8训练钓鱼检测数据集 并构建一个基于YOLOv8的钓鱼检测系统
使用YOLOv8训练钓鱼检测数据集,如何针对钓鱼检测进行调整和实现的详细步骤。
1. 安装依赖
确保安装了必要的库。对于钓鱼检测,所需的库应该与之前提供的相同,但请根据实际情况检查是否需要额外的库。
pip install torch torchvision ultralytics pyqt5 opencv-python pandas
2. 数据准备
假设您的钓鱼检测数据集已经按照YOLO标准格式进行了标注,并且包含1062张图片和1248个标注。数据集目录结构应如下:
训练使用钓鱼检测数据集 1000张 钓鱼垂钓检测 带标注 voc yolo
datasets/
└── fishing_detection/
├── images/
│ ├── train/
│ └── val/
├── labels_yolo/
│ ├── train/
│ └── val/
每个图像对应一个同名的.txt
文件(YOLO格式),而标签文件是CSV或XML格式的注释文件。如果标注是VOC格式,您可能需要转换为YOLO格式。
将VOC格式转换为YOLO格式
如果您有VOC格式的标注文件,可以使用类似如下的Python脚本来转换它们到YOLO格式。
voc_to_yolo.py
import xml.etree.ElementTree as ET
import os
def convert_voc_to_yolo(xml_file, output_dir, class_name='fishing'):
tree = ET.parse(xml_file)
root = tree.getroot()
img_width = int(root.find('size/width').text)
img_height = int(root.find('size/height').text)
with open(output_dir + '/' + root.find('filename').text.split('.')[0] + '.txt', 'w') as f:
for obj in root.findall('object'):
if obj.find('name').text != class_name:
continue
bbox = obj.find('bndbox')
x_min = float(bbox.find('xmin').text)
y_min = float(bbox.find('ymin').text)
x_max = float(bbox.find('xmax').text)
y_max = float(bbox.find('ymax').text)
x_center = ((x_min + x_max) / 2) / img_width
y_center = ((y_min + y_max) / 2) / img_height
bbox_width = (x_max - x_min) / img_width
bbox_height = (y_max - y_min) / img_height
f.write(f"0 {x_center} {y_center} {bbox_width} {bbox_height}\n")
# 调用函数进行转换
for filename in os.listdir('path/to/voc/annotations'):
if filename.endswith('.xml'):
convert_voc_to_yolo('path/to/voc/annotations/' + filename, 'labels_yolo/train')
3. 文件内容
3.1 Config.py
配置文件用于定义数据集路径、模型路径等。对于单类别(钓鱼)问题,类的数量设置为1。
# Config.py
DATASET_PATH = 'datasets/fishing_detection/'
MODEL_PATH = 'runs/detect/train/weights/best.pt'
IMG_SIZE = 640
BATCH_SIZE = 16
EPOCHS = 50
CONF_THRESHOLD = 0.5
3.2 train.py
修改train.py
以适应钓鱼检测任务。由于只有1个类别,nc: 1
,并且类别名称是’fishing’。
from ultralytics import YOLO
import os
# Load a model
model = YOLO('yolov8n.pt') # You can also use other versions like yolov8s.pt, yolov8m.pt, etc.
# Define dataset configuration
dataset_config = f"""
train: {os.path.join(os.getenv('DATASET_PATH', 'datasets/fishing_detection/'), 'images/train')}
val: {os.path.join(os.getenv('DATASET_PATH', 'datasets/fishing_detection/'), 'images/val')}
nc: 1
names: ['fishing']
"""
# Save dataset configuration to a YAML file
with open('fishing.yaml', 'w') as f:
f.write(dataset_config)
# Train the model
results = model.train(data='fishing.yaml', epochs=int(os.getenv('EPOCHS', 50)), imgsz=int(os.getenv('IMG_SIZE', 640)), batch=int(os.getenv('BATCH_SIZE', 16)))
3.3 detect_tools.py 和 3.4 UIProgram/MainProgram.py
这些文件保持不变,因为它们是通用的YOLOv8检测工具和GUI界面代码。
3.5 requirements.txt 和 3.6 setup.py
同样保持不变,除非您需要添加额外的依赖项。
3.7 README.md
更新README文档以反映新的项目细节,例如数据集信息、训练命令和GUI启动命令。
4. 运行步骤
- 确保数据集路径正确:将您的数据集放在
datasets/fishing_detection
目录下。 - 安装必要的库:确保已安装所有所需库。
- 运行代码:
- 首先运行训练代码来训练YOLOv8模型:
python train.py
- 然后运行GUI代码来启动检测系统:
python UIProgram/MainProgram.py
- 首先运行训练代码来训练YOLOv8模型:
注意事项
- 如果您使用的是不同的预训练权重,请替换
'yolov8n.pt'
为相应的权重文件。 - 确保环境变量已正确设置,或者直接在代码中硬编码路径和其他参数。
- 根据实际需求调整
IMG_SIZE
,BATCH_SIZE
,EPOCHS
等超参数。
以上就是使用YOLOv8训练钓鱼检测数据集的完整过程。
构建一个基于YOLOv8的钓鱼检测系统涉及到多个步骤,包括环境设置、数据准备、模型训练、评估和推理部署。以下是一个详细的指南,帮助你完成整个过程。
1. 环境设置
确保你的开发环境已经安装了必要的库和工具。如果你还没有安装这些依赖项,请按照之前的建议进行安装:
pip install torch torchvision ultralytics pyqt5 opencv-python pandas
2. 数据准备
2.1 数据集结构
假设你的数据集已经按照YOLO标准格式进行了标注,并且包含1062张图片和1248个标注。确保数据集目录结构如下:
datasets/
└── fishing_detection/
├── images/
│ ├── train/
│ └── val/
├── labels_yolo/
│ ├── train/
│ └── val/
每个图像对应一个同名的.txt
文件(YOLO格式),而标签文件是CSV或XML格式的注释文件。如果标注是VOC格式,你需要将其转换为YOLO格式。
2.2 转换VOC到YOLO格式
如果你的数据是以VOC格式提供的,可以使用Python脚本将它们转换为YOLO格式:
import xml.etree.ElementTree as ET
import os
def convert_voc_to_yolo(xml_file, output_dir, class_name='fishing'):
tree = ET.parse(xml_file)
root = tree.getroot()
img_width = int(root.find('size/width').text)
img_height = int(root.find('size/height').text)
with open(os.path.join(output_dir, root.find('filename').text.split('.')[0] + '.txt'), 'w') as f:
for obj in root.findall('object'):
if obj.find('name').text != class_name:
continue
bbox = obj.find('bndbox')
x_min = float(bbox.find('xmin').text)
y_min = float(bbox.find('ymin').text)
x_max = float(bbox.find('xmax').text)
y_max = float(bbox.find('ymax').text)
x_center = ((x_min + x_max) / 2) / img_width
y_center = ((y_min + y_max) / 2) / img_height
bbox_width = (x_max - x_min) / img_width
bbox_height = (y_max - y_min) / img_height
f.write(f"0 {x_center} {y_center} {bbox_width} {bbox_height}\n")
# Example usage
for filename in os.listdir('path/to/voc/annotations'):
if filename.endswith('.xml'):
convert_voc_to_yolo(os.path.join('path/to/voc/annotations', filename), 'labels_yolo/train')
3. 文件内容
3.1 Config.py
配置文件用于定义数据集路径、模型路径等。
# Config.py
DATASET_PATH = 'datasets/fishing_detection/'
MODEL_PATH = 'runs/detect/train/weights/best.pt'
IMG_SIZE = 640
BATCH_SIZE = 16
EPOCHS = 50
CONF_THRESHOLD = 0.5
3.2 train.py
训练YOLOv8模型的脚本。
from ultralytics import YOLO
import os
# Load a model
model = YOLO('yolov8n.pt') # You can also use other versions like yolov8s.pt, yolov8m.pt, etc.
# Define dataset configuration
dataset_config = f"""
train: {os.path.join(os.getenv('DATASET_PATH', 'datasets/fishing_detection/'), 'images/train')}
val: {os.path.join(os.getenv('DATASET_PATH', 'datasets/fishing_detection/'), 'images/val')}
nc: 1
names: ['fishing']
"""
# Save dataset configuration to a YAML file
with open('fishing.yaml', 'w') as f:
f.write(dataset_config)
# Train the model
results = model.train(data='fishing.yaml', epochs=int(os.getenv('EPOCHS', 50)), imgsz=int(os.getenv('IMG_SIZE', 640)), batch=int(os.getenv('BATCH_SIZE', 16)))
3.3 detect_tools.py
用于检测的工具函数。
from ultralytics import YOLO
import cv2
import numpy as np
def load_model(model_path):
return YOLO(model_path)
def detect_objects(frame, model, conf_threshold=0.5):
results = model(frame, conf=conf_threshold)
detections = []
for result in results:
boxes = result.boxes.cpu().numpy()
for box in boxes:
r = box.xyxy[0].astype(int)
cls = int(box.cls[0])
conf = round(float(box.conf[0]), 2)
label = f"fishing {conf}"
detections.append((r, label))
return detections
def draw_detections(frame, detections):
for (r, label) in detections:
cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), (0, 255, 0), 2)
cv2.putText(frame, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
return frame
3.4 UIProgram/MainProgram.py
主程序,使用PyQt5构建图形界面。
import sys
import cv2
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QVBoxLayout, QWidget, QPushButton
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt, QTimer
from detect_tools import load_model, detect_objects, draw_detections
import os
class VideoWindow(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Fishing Detection")
self.setGeometry(100, 100, 800, 600)
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
self.layout = QVBoxLayout()
self.central_widget.setLayout(self.layout)
self.label = QLabel()
self.layout.addWidget(self.label)
self.start_button = QPushButton("Start Detection")
self.start_button.clicked.connect(self.start_detection)
self.layout.addWidget(self.start_button)
self.cap = None
self.timer = QTimer()
self.timer.timeout.connect(self.update_frame)
self.model = load_model(os.getenv('MODEL_PATH', 'runs/detect/train/weights/best.pt'))
def start_detection(self):
if not self.cap:
self.cap = cv2.VideoCapture(0) # Use webcam
self.timer.start(30)
def update_frame(self):
ret, frame = self.cap.read()
if not ret:
return
detections = detect_objects(frame, self.model, conf_threshold=float(os.getenv('CONF_THRESHOLD', 0.5)))
frame = draw_detections(frame, detections)
rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
h, w, ch = rgb_image.shape
bytes_per_line = ch * w
qt_image = QImage(rgb_image.data, w, h, bytes_per_line, QImage.Format_RGB888)
pixmap = QPixmap.fromImage(qt_image)
self.label.setPixmap(pixmap.scaled(800, 600, Qt.KeepAspectRatio))
if __name__ == "__main__":
app = QApplication(sys.argv)
window = VideoWindow()
window.show()
sys.exit(app.exec_())
3.5 requirements.txt
列出所有依赖项。
torch
torchvision
ultralytics
pyqt5
opencv-python
pandas
3.6 setup.py
用于安装项目的脚本。
from setuptools import setup, find_packages
setup(
name='fishing_detection',
version='0.1',
packages=find_packages(),
install_requires=[
'torch',
'torchvision',
'ultralytics',
'pyqt5',
'opencv-python',
'pandas'
],
entry_points={
'console_scripts': [
'train=train:main',
'detect=UIProgram.MainProgram:main'
]
}
)
3.7 README.md
项目说明文档。
# Fishing Detection System
This project uses YOLOv8 and PyQt5 to create a real-time fishing detection system for images. The system detects fishing activities in various environments.
## Installation
1. Clone the repository:
```bash
git clone https://github.com/yourusername/fishing-detection.git
cd fishing-detection
-
Install dependencies:
pip install -r requirements.txt
-
Set up environment variables (optional):
export DATASET_PATH=./datasets/fishing_detection/ export MODEL_PATH=./runs/detect/train/weights/best.pt export IMG_SIZE=640 export BATCH_SIZE=16 export EPOCHS=50 export CONF_THRESHOLD=0.5
Training
To train the YOLOv8 model:
python train.py
Running the GUI
To run the graphical user interface:
python UIProgram/MainProgram.py
Usage Tutorial
See 使用教程.xt for detailed usage instructions.
### 4. 运行步骤
- **确保数据集路径正确**:将你的数据集放在 `datasets/fishing_detection` 目录下。
- **安装必要的库**:确保已安装所有所需库。
- **运行代码**:
- 首先运行训练代码来训练YOLOv8模型:
```bash
python train.py
```
- 然后运行GUI代码来启动检测系统:
```bash
python UIProgram/MainProgram.py
```
### 5. 模型评估与优化
在训练完成后,你可以通过验证集评估模型性能,查看mAP(平均精度均值)和其他指标。根据评估结果,调整超参数如学习率、批次大小、图像尺寸等,以优化模型性能。
### 6. 结果分析与可视化
利用内置的方法或自定义脚本来分析结果和可视化预测边界框。这有助于理解模型的表现并识别可能的改进点。
### 7. 用户界面开发
为了构建用户界面,你可以使用Flask或FastAPI等框架创建RESTful服务,或者直接用Streamlit这样的快速原型开发工具。上述代码中已经包含了使用PyQt5创建的简单GUI示例。
帮助你顺利构建基于YOLOv8的钓鱼检测系统。
标签:钓鱼,检测,self,py,YOLOv8,train,fishing,os,find
From: https://blog.csdn.net/2401_88441190/article/details/145196382