首页 > 其他分享 >如何构建一个用于草莓成熟度检测的YOLOv5模型,并使用Yolov5训练使用草莓成熟度检测数据集模型 并实现可视化及评估推理 3类 yolo标注

如何构建一个用于草莓成熟度检测的YOLOv5模型,并使用Yolov5训练使用草莓成熟度检测数据集模型 并实现可视化及评估推理 3类 yolo标注

时间:2025-01-03 22:35:08浏览次数:3  
标签:成熟度 检测 草莓 torch epoch train images import self

**

声明:博客内所有文章代码仅供参考!

**

如何训练这个——草莓成熟度检测数据集,共800余张大棚内实景拍摄,区分为成熟,未成熟,草莓花梗三类,提供yolo标注,1.4GB在这里插入图片描述
草莓成熟度检测数据集,共800余张大棚内实景拍摄,区分为成熟,未成熟,草莓花梗三类,提供yolo标注,1.4GB在这里插入图片描述
在这里插入图片描述
构建一个用于草莓成熟度检测的YOLOv5模型。我们将会创建以下文件:

  1. train.py - 训练脚本
  2. datasets.py - 数据集定义
  3. config.yaml - 配置文件
  4. requirements.txt - 依赖项

config.yaml

首先,我们需要配置文件来指定训练参数、数据路径等。

# config.yaml
train: ../datasets/train/images/
val: ../datasets/valid/images/

nc: 3
names: ['unripe', 'ripe', 'flower']

requirements.txt

接下来,列出所有需要安装的Python包。

torch>=1.8
torchvision>=0.9
pycocotools
opencv-python
matplotlib
albumentations

datasets.py

定义数据集类以便于加载草莓成熟度检测的数据集。

import os
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

class StrawberryDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.img_files = list((self.root_dir / 'images').glob('*.jpg'))
        self.label_files = [Path(str(img_file).replace('images', 'labels').replace('.jpg', '.txt')) for img_file in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        label_path = self.label_files[idx]

        image = Image.open(img_path).convert("RGB")
        boxes = []
        labels = []

        with open(label_path, 'r') as file:
            lines = file.readlines()
            for line in lines:
                class_id, x_center, y_center, width, height = map(float, line.strip().split())
                boxes.append([x_center, y_center, width, height])
                labels.append(int(class_id))

        if self.transform:
            transformed = self.transform(image=np.array(image), bboxes=boxes, class_labels=labels)
            image = transformed['image']
            boxes = transformed['bboxes']
            labels = transformed['class_labels']

        target = {}
        target['boxes'] = torch.tensor(boxes, dtype=torch.float32)
        target['labels'] = torch.tensor(labels, dtype=torch.int64)

        return image, target

# 定义数据增强
data_transforms = {
    'train': A.Compose([
        A.Resize(width=640, height=640),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=180, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo')),
    'test': A.Compose([
        A.Resize(width=640, height=640),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo')),
}

train.py

最后,编写训练脚本来训练YOLOv5模型。

import torch
import torch.optim as optim
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from datasets import StrawberryDataset, data_transforms
from torch.utils.data import DataLoader
import yaml
import time

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

def collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    images = torch.stack(images)
    return images, targets

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    header = f"Epoch: [{epoch}]"
    lr_scheduler = None

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses.item(), **loss_dict)

def main():
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dataset_train = StrawberryDataset(root_dir=config['train'], transform=data_transforms['train'])
    dataset_val = StrawberryDataset(root_dir=config['val'], transform=data_transforms['test'])

    data_loader_train = DataLoader(dataset_train, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
    data_loader_val = DataLoader(dataset_val, batch_size=4, shuffle=False, num_workers=4, collate_fn=collate_fn)

    model = fasterrcnn_resnet50_fpn_v2(pretrained=True)
    num_classes = config['nc'] + 1  # background + number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torch.nn.Linear(in_features, num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    for epoch in range(10):  # number of epochs
        train_one_epoch(model, optimizer, data_loader_train, device=device, epoch=epoch, print_freq=10)

        # save every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'model_epoch_{epoch}.pth')

if __name__ == "__main__":
    main()

总结

以上代码涵盖了从数据准备到模型训练的所有步骤。你可以根据需要调整配置文件中的参数,并运行训练脚本来开始训练YOLOv5模型。确保你的数据集目录结构符合预期,并且所有的文件路径都是正确的。

标签:成熟度,检测,草莓,torch,epoch,train,images,import,self
From: https://blog.csdn.net/2401_86822270/article/details/144838837

相关文章

  • 使用Mask R-CNN模型来进行目标检测和实例分割 大规模高分辨率树种单木分割数据集 处理
    单木分割数据集。从14个不同树种类中分割和标注了23,000个树冠,采集使用了DJIPhantom4RTK无人机树种单木分割数据集。从14个不同树种类中分割和标注了23,000个树冠,采集使用了DJIPhantom4RTK无人机。正射tif影像,点云、arcgis详细标注单株树木矢量数据(并标明树木类型),数......
  • 如何利用深度学习框架训练使用 可以使用YOLOv5模型来进行目标检测 智慧化生产工地 钢
    如何训练自己的数据集——智慧化生产工地资产盘点,超大规模钢筋计数数据集,共23400组图像,多视角,多角度,多场景,采用voc方式标注。智慧化生产工地资产盘点,超大规模钢筋计数数据集,共23400组图像,多视角,多角度,多场景,采用voc方式标注。为了实现智慧工地资产盘点中的超大规模钢筋计......
  • NLP 中文拼写检测纠正论文-07-NLPTEA-2020中文语法错误诊断共享任务概述
    拼写纠正系列NLP中文拼写检测实现思路NLP中文拼写检测纠正算法整理NLP英文拼写算法,如果提升100W倍的性能?NLP中文拼写检测纠正Paperjava实现中英文拼写检查和错误纠正?可我只会写CRUD啊!一个提升英文单词拼写检测性能1000倍的算法?单词拼写纠正-03-leetcodeedit-d......
  • NLP 中文拼写检测纠正论文-07-NLPTEA-2020中文语法错误诊断共享任务概述
    拼写纠正系列NLP中文拼写检测实现思路NLP中文拼写检测纠正算法整理NLP英文拼写算法,如果提升100W倍的性能?NLP中文拼写检测纠正Paperjava实现中英文拼写检查和错误纠正?可我只会写CRUD啊!一个提升英文单词拼写检测性能1000倍的算法?单词拼写纠正-03-leetcode......
  • 利用MATLAB实现了视频图像行人识别与检测
    利用MATLAB实现了视频图像行人识别与检测资源文件列表piotr_toolbox/bbNms.m , 8611piotr_toolbox/pNms.rar , 22300piotr_toolbox/toolbox/channels/chnsCompute.m , 9239piotr_toolbox/toolbox/channels/chnsPyramid.m , 10558piotr_toolbox/toolbox/channels/chn......
  • DL00684-山体滑坡实例/语义分割检测完整python代码含数据集
    https://item.taobao.com/item.htm?ft=t&id=872378688356山体滑坡是引发重大自然灾害的常见地质现象,尤其在山区、丘陵等地带,滑坡不仅对人民生命财产安全构成威胁,还会造成环境破坏和基础设施损毁。传统的山体滑坡检测方法依赖人工监测、地质勘探和局部传感器,这些方法不仅反应速度......
  • 遥感目标检测 数据集
    遥感目标检测——DOTA_数据集-飞桨AIStudio星河社区全称DOTA:ALarge-scaleDatasetforObjectDeTectioninAerialImages数据集,包括1.0、1.5、2.0共2个版本,用于遥感目标检测源地址:https://captain-whu.github.io/DOTA/index.htmlhttps://aistudio.baidu.com/datasetdeta......
  • 【芳心科技】E. 超声波倒车视频检测
    实物效果图:实现功能:1.摄像头+显示。2.多个超声波测距3.把测距得到的距离以可视化线条表示在屏幕上4.显示温湿度5.显示时间6.距离报警,阈值可设定原理图:程序源码:视频链接:可前往抖音、B站、快手等视频平台搜索【芳心科技】或【芳芯科技】查看演示视频。资料......
  • H7-TOOL固件2.27发布,新增加40多款芯片脱机烧录,含多款车轨芯片,发布LUA API手册,CAN助手
    H7-TOOL详细介绍(含操作手册):http://www.armbbs.cn/forum.php?mod=viewthread&tid=89934【PC软件】V2.271.脱机烧录功能更新:  -prog_lib.c1拖16时部分成功时,修改start_prog_0()为返回ERROR  -高级脚本范例中新增1拖16使用VOUT输出结果状态  -修正GD32H7xx_1M和......
  • CPU-Z处理器检测工具 v2.13.0中文绿色单文件
    点击上方蓝字关注我前言CPU-Z是一个非常厉害的CPU检测小帮手。它能识别很多种类的CPU,而且打开和检测的速度都很快。这个工具能清楚地告诉我们关于CPU、主板、内存、显卡等硬件的详细信息,比如是哪个厂家生产的、处理器的名字、是怎么做出来的、封装技术怎么样,还有它们的运行频率......