**声明：博客内所有文章代码仅供参考！**
如何训练草莓成熟度检测数据集：共800余张大棚内实景拍摄图片，分为成熟、未成熟、草莓花梗三类，提供YOLO格式标注，总大小约1.4GB
草莓成熟度检测数据集,共800余张大棚内实景拍摄,区分为成熟,未成熟,草莓花梗三类,提供yolo标注,1.4GB
下面构建一个用于草莓成熟度检测的目标检测模型。注意：虽然数据集提供的是YOLO格式标注，但下面的训练脚本实际使用的是torchvision的Faster R-CNN（fasterrcnn_resnet50_fpn_v2），而不是YOLOv5。我们将会创建以下文件：
- train.py：训练脚本
- datasets.py：数据集定义
- config.yaml：配置文件
- requirements.txt：依赖项
config.yaml
首先，我们需要一个配置文件来指定训练集和验证集的图片路径、类别数（nc）以及类别名称（names）。names的顺序必须与YOLO标注文件中的类别编号（0、1、2）一一对应。
# config.yaml
# Image directories of the training / validation splits. Relative paths are
# resolved against the working directory train.py is launched from.
train: ../datasets/train/images/
val: ../datasets/valid/images/
# Number of object classes; the background class is NOT counted here
# (train.py adds 1 for it).
nc: 3
# Class names indexed by YOLO class id; the order must match the ids used in
# the label .txt files. NOTE(review): confirm this id order against the dataset.
names: ['unripe', 'ripe', 'flower']
requirements.txt
接下来,列出所有需要安装的Python包。
torch>=1.12  # first release paired with torchvision 0.13
torchvision>=0.13  # fasterrcnn_resnet50_fpn_v2 and the weights= API were added in 0.13
pycocotools
opencv-python
matplotlib
albumentations
numpy
Pillow
pyyaml
datasets.py
定义数据集类以便于加载草莓成熟度检测的数据集。
import os
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class StrawberryDataset(Dataset):
    """Strawberry ripeness dataset in YOLO layout, served in torchvision detection format.

    Expected on-disk layout (standard YOLO)::

        <split>/images/xxx.jpg
        <split>/labels/xxx.txt   # one "class_id cx cy w h" line per box, normalized to [0, 1]

    ``root_dir`` may be either the split directory (``<split>``) or the images
    directory itself (``<split>/images``, which is how config.yaml lists it).
    If the image directory is not named ``images``, label files are looked up
    in that same directory.

    Each sample is ``(image, target)``. ``target`` follows the torchvision
    detection convention used by Faster R-CNN:

    * ``boxes``  -- float32 tensor of shape (N, 4), absolute ``[x1, y1, x2, y2]``
      pixel coordinates in the (possibly transformed) image, clipped to it;
      zero-area boxes are dropped.
    * ``labels`` -- int64 tensor of shape (N,), equal to YOLO class id + 1,
      because torchvision reserves label 0 for the background class.

    Args:
        root_dir: split directory or its ``images`` sub-directory.
        transform: optional albumentations ``Compose`` called as
            ``transform(image=ndarray, bboxes=..., class_labels=...)`` with
            boxes in YOLO format. It should declare
            ``label_fields=['class_labels']`` so labels stay aligned with boxes
            that a transform drops. Without a transform the image is returned
            as a PIL image.
    """

    # Image suffixes picked up from the image directory (compared case-insensitively).
    IMG_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        images_sub = self.root_dir / 'images'
        image_dir = images_sub if images_sub.is_dir() else self.root_dir
        if not image_dir.is_dir():
            raise FileNotFoundError(f"image directory not found: {image_dir}")
        # Sorted so that sample indices are stable across runs and filesystems.
        self.img_files = sorted(
            p for p in image_dir.iterdir()
            if p.is_file() and p.suffix.lower() in self.IMG_SUFFIXES
        )
        label_dir = image_dir.with_name('labels') if image_dir.name == 'images' else image_dir
        self.label_files = [label_dir / f'{p.stem}.txt' for p in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        boxes, class_ids = self._read_yolo_labels(self.label_files[idx])

        if self.transform:
            transformed = self.transform(image=np.array(image), bboxes=boxes, class_labels=class_ids)
            image = transformed['image']
            # Normalize to plain 4-element lists / ints whatever container type is returned.
            boxes = [list(box)[:4] for box in transformed['bboxes']]
            class_ids = [int(c) for c in transformed['class_labels']]
            if len(boxes) != len(class_ids):
                raise ValueError(
                    f"transform returned {len(boxes)} boxes but {len(class_ids)} labels for "
                    f"{img_path}; declare label_fields=['class_labels'] in A.BboxParams"
                )

        height, width = self._image_hw(image)
        boxes_xyxy = self._yolo_to_xyxy(boxes, width, height)
        # +1: label 0 is the background class in torchvision detection models.
        labels = torch.as_tensor(class_ids, dtype=torch.int64).reshape(-1) + 1
        # Faster R-CNN rejects boxes with non-positive width or height.
        keep = (boxes_xyxy[:, 2] > boxes_xyxy[:, 0]) & (boxes_xyxy[:, 3] > boxes_xyxy[:, 1])
        target = {'boxes': boxes_xyxy[keep], 'labels': labels[keep]}
        return image, target

    @staticmethod
    def _read_yolo_labels(label_path):
        """Parse a YOLO label file into ``(boxes, class_ids)``.

        A missing file means an image without objects and yields two empty
        lists; blank lines are skipped. Raises ValueError on a line that does
        not have exactly five fields.
        """
        label_path = Path(label_path)
        boxes, class_ids = [], []
        if not label_path.is_file():
            return boxes, class_ids
        with open(label_path, 'r', encoding='utf-8') as fh:
            for line_no, line in enumerate(fh, start=1):
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 'class cx cy w h', got {line.strip()!r}"
                    )
                class_id, x_center, y_center, width, height = map(float, fields)
                boxes.append([x_center, y_center, width, height])
                class_ids.append(int(class_id))
        return boxes, class_ids

    @staticmethod
    def _image_hw(image):
        """Return ``(height, width)`` of a CHW tensor, an HWC array, or a PIL image."""
        if isinstance(image, torch.Tensor):
            return int(image.shape[-2]), int(image.shape[-1])
        if hasattr(image, 'shape'):
            return int(image.shape[0]), int(image.shape[1])
        width, height = image.size
        return height, width

    @staticmethod
    def _yolo_to_xyxy(boxes, width, height):
        """Convert normalized YOLO (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2).

        Coordinates are clipped to ``[0, width]`` / ``[0, height]``. Always
        returns a float32 tensor of shape (N, 4), including N == 0.
        """
        yolo = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        cx, cy, bw, bh = yolo.unbind(dim=1)
        xyxy = torch.stack(
            (
                (cx - bw / 2) * width,
                (cy - bh / 2) * height,
                (cx + bw / 2) * width,
                (cy + bh / 2) * height,
            ),
            dim=1,
        )
        xyxy[:, 0::2] = xyxy[:, 0::2].clamp(0, width)
        xyxy[:, 1::2] = xyxy[:, 1::2].clamp(0, height)
        return xyxy
# Data augmentation pipelines.
# Boxes go in and come out in YOLO format (normalized cx, cy, w, h). The per-box
# class ids travel in the `class_labels` field, which must be declared in
# `label_fields` so albumentations drops a label together with its box whenever
# a transform removes that box.
# Images end up as float CHW tensors scaled to [0, 1] only: torchvision detection
# models (Faster R-CNN in train.py) apply ImageNet mean/std normalization
# internally, so normalizing here as well would normalize twice.
data_transforms = {
    'train': A.Compose([
        A.Resize(width=640, height=640),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        # Arbitrary-angle rotation enlarges the axis-aligned box around each
        # object; consider A.RandomRotate90 if boxes become too loose.
        A.Rotate(limit=180, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        # mean 0 / std 1: just divide uint8 pixels by 255.
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])),
    'test': A.Compose([
        A.Resize(width=640, height=640),
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])),
}
train.py
最后，编写训练脚本，在该数据集上微调torchvision的Faster R-CNN（fasterrcnn_resnet50_fpn_v2）模型。
import torch
import torch.optim as optim
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from datasets import StrawberryDataset, data_transforms
from torch.utils.data import DataLoader
import yaml
import time
# Load dataset paths and class information from config.yaml (looked up in the
# current working directory). An explicit UTF-8 encoding keeps non-ASCII content
# (e.g. Chinese class names) readable regardless of the OS locale; safe_load
# prevents the YAML file from constructing arbitrary Python objects.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
def collate_fn(batch):
    """Merge a list of ``(image, target)`` samples into one batch.

    Images are stacked into a single tensor, so they must all share one shape.
    Targets stay a plain list of per-image dicts, because each image may carry
    a different number of boxes.
    """
    image_list = []
    target_list = []
    for sample in batch:
        image_list.append(sample[0])
        target_list.append(sample[1])
    return torch.stack(image_list), target_list
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, lr_scheduler=None):
    """Train ``model`` for one pass over ``data_loader``.

    Logs with plain ``print`` instead of the torchvision-references
    ``MetricLogger``, which this script never imported.

    Args:
        model: detection model that, in train mode, maps ``(images, targets)``
            to a dict of scalar loss tensors.
        optimizer: optimizer over the model's trainable parameters.
        data_loader: iterable of ``(images, targets)`` batches; ``images`` is
            iterated image by image, ``targets`` is a list of dicts of tensors.
        device: device every image and target tensor is moved to.
        epoch: epoch index, used only in log messages.
        print_freq: log every ``print_freq`` iterations (``<= 0`` disables
            per-iteration logging).
        lr_scheduler: optional scheduler, stepped once per iteration.

    Returns:
        float: mean of the summed loss over all iterations (0.0 if the loader
        yielded no batches).

    Raises:
        RuntimeError: if the summed loss becomes NaN or infinite; continuing
            would only corrupt the weights.
    """
    import math

    model.train()
    header = f"Epoch: [{epoch}]"
    total_loss = 0.0
    num_iters = 0
    start = time.perf_counter()

    for step, (images, targets) in enumerate(data_loader):
        # torchvision detection models take a list of CHW tensors plus a list of target dicts.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        if not math.isfinite(loss_value):
            details = {k: v.item() for k, v in loss_dict.items()}
            raise RuntimeError(f"{header} non-finite loss {loss_value}: {details}")

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()

        total_loss += loss_value
        num_iters += 1
        if print_freq > 0 and step % print_freq == 0:
            parts = "  ".join(f"{k}: {v.item():.4f}" for k, v in loss_dict.items())
            lr = optimizer.param_groups[0]["lr"]
            print(f"{header} [{step}]  loss: {loss_value:.4f}  {parts}  lr: {lr:.6f}")

    mean_loss = total_loss / num_iters if num_iters else 0.0
    elapsed = time.perf_counter() - start
    print(f"{header} finished {num_iters} iterations in {elapsed:.1f}s, mean loss {mean_loss:.4f}")
    return mean_loss
def main():
    """Fine-tune a COCO-pretrained Faster R-CNN (ResNet-50 FPN v2) on the strawberry data.

    Reads dataset paths and the class count from the module-level ``config``
    (loaded from config.yaml). The optional keys ``epochs`` (default 10),
    ``batch_size`` (4), ``workers`` (4) and ``lr`` (0.005) override the
    training defaults. A checkpoint ``model_epoch_<n>.pth`` with model and
    optimizer state is written after every epoch.
    """
    # The replacement head must be a FastRCNNPredictor: roi_heads unpacks the
    # predictor output into (class_logits, box_regression), which a plain
    # nn.Linear cannot provide.
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    epochs = int(config.get('epochs', 10))
    batch_size = int(config.get('batch_size', 4))
    num_workers = int(config.get('workers', 4))
    lr = float(config.get('lr', 0.005))

    dataset_train = StrawberryDataset(root_dir=config['train'], transform=data_transforms['train'])
    dataset_val = StrawberryDataset(root_dir=config['val'], transform=data_transforms['test'])
    data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True,
                                   num_workers=num_workers, collate_fn=collate_fn)
    # TODO: no evaluation loop exists yet; this loader is prepared for one.
    data_loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers, collate_fn=collate_fn)

    # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13,
    # the first release that ships the *_v2 model); DEFAULT is the COCO checkpoint.
    model = fasterrcnn_resnet50_fpn_v2(weights='DEFAULT')
    num_classes = config['nc'] + 1  # background + number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)

    for epoch in range(epochs):
        train_one_epoch(model, optimizer, data_loader_train, device=device, epoch=epoch, print_freq=10)
        # save every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'model_epoch_{epoch}.pth')


# The guard also stops DataLoader worker processes (num_workers > 0) from
# re-running training when Python starts them with "spawn" (e.g. on Windows).
if __name__ == "__main__":
    main()
总结
以上代码涵盖了从数据准备到模型训练的所有步骤。你可以根据需要调整配置文件中的参数，并运行训练脚本来开始训练Faster R-CNN模型（注意：并不是YOLOv5）。请确保你的数据集目录结构符合预期：每个数据划分（train/valid）下都有同级的images/和labels/两个目录，且每个标注文件与对应图片同名、扩展名为.txt。另外，请确认所有文件路径都是正确的。
标签:成熟度,检测,草莓,torch,epoch,train,images,import,self From: https://blog.csdn.net/2401_86822270/article/details/144838837