如何训练自己的数据集——智慧化生产工地资产盘点,超大规模钢筋计数数据集,共23400组图像,多视角,多角度,多场景,采用voc方式标注。
智慧化生产工地资产盘点,超大规模钢筋计数数据集,共23400组图像,多视角,多角度,多场景,采用voc方式标注。
为了实现智慧工地资产盘点中的超大规模钢筋计数任务,我们可以使用 torchvision 提供的 Faster R-CNN(ResNet50-FPN v2)模型来进行目标检测。以下是详细的步骤和代码示例,包括数据集定义、配置文件、训练脚本等。
目录结构
首先,确保你的项目目录结构如下:
/rebar_counting_project
    /datasets
        /train
            /images
                *.jpg
            /annotations
                *.xml
        /valid
            /images
                *.jpg
            /annotations
                *.xml
    /scripts
        train.py
        datasets.py
        config.yaml
    requirements.txt
config.yaml
配置文件 config.yaml
包含训练参数、数据路径等信息。
# config.yaml
# Dataset split roots, resolved relative to the directory train.py is run from.
# Each root must contain an images/ and an annotations/ sub-folder: RebarDataset
# appends those folder names itself, so do not point at images/ directly.
train: ../datasets/train
val: ../datasets/valid
# Number of foreground classes (train.py adds 1 for the background class).
nc: 1
names: ['rebar']
requirements.txt
列出所有需要安装的Python包。
torch>=1.12
torchvision>=0.13
pycocotools
opencv-python
matplotlib
albumentations
labelme2coco
numpy
Pillow
PyYAML
datasets.py
定义数据集类以便于加载钢筋计数的数据集,并进行数据增强。
import os
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import xml.etree.ElementTree as ET
class RebarDataset(Dataset):
    """Pascal VOC style detection dataset for rebar counting.

    Expected layout: ``<root_dir>/images/*.jpg`` with one
    ``<root_dir>/annotations/<stem>.xml`` per image.

    Each item is ``(image, target)`` where ``target['boxes']`` is a float32
    tensor of shape ``(N, 4)`` holding ``[xmin, ymin, xmax, ymax]`` in absolute
    pixels and ``target['labels']`` is an int64 tensor of shape ``(N,)``.

    Args:
        root_dir: Split root containing ``images/`` and ``annotations/``. The
            ``images/`` folder itself is also accepted.
        transform: Optional callable invoked as
            ``transform(image=<HxWx3 array>, bboxes=..., class_labels=...)``
            and returning a dict with the same three keys. When omitted the
            image is returned as a PIL image.
    """

    def __init__(self, root_dir, transform=None):
        root = Path(root_dir)
        # Tolerate being handed ".../images" instead of the split root.
        if root.name == 'images' and not (root / 'images').is_dir():
            root = root.parent
        self.root_dir = root
        self.transform = transform
        # Sorted so that index -> sample is reproducible across runs/machines.
        self.img_files = sorted((self.root_dir / 'images').glob('*.jpg'))
        # Build annotation paths from the file stem rather than a global string
        # replace, which would also rewrite "images" elsewhere in the path.
        annotation_dir = self.root_dir / 'annotations'
        self.label_files = [
            annotation_dir / f'{img_file.stem}.xml' for img_file in self.img_files
        ]

    def __len__(self):
        return len(self.img_files)

    @staticmethod
    def _read_boxes(label_path, img_width, img_height):
        """Return the VOC boxes of one XML file as absolute pixel coordinates.

        Boxes are clipped to ``[0, img_width] x [0, img_height]``; objects
        without a ``bndbox`` and boxes with no area after clipping are skipped.
        """
        max_x = float(img_width)
        max_y = float(img_height)
        boxes = []
        for obj in ET.parse(label_path).getroot().findall('object'):
            bbox = obj.find('bndbox')
            if bbox is None:
                continue
            xmin = min(max(float(bbox.find('xmin').text), 0.0), max_x)
            ymin = min(max(float(bbox.find('ymin').text), 0.0), max_y)
            xmax = min(max(float(bbox.find('xmax').text), 0.0), max_x)
            ymax = min(max(float(bbox.find('ymax').text), 0.0), max_y)
            if xmax <= xmin or ymax <= ymin:
                continue
            boxes.append([xmin, ymin, xmax, ymax])
        return boxes

    def __getitem__(self, idx):
        image = Image.open(self.img_files[idx]).convert("RGB")
        # Clip against the real image size instead of trusting the XML <size>.
        width, height = image.size
        boxes = self._read_boxes(self.label_files[idx], width, height)
        # Single foreground class 'rebar'; main() reserves index 0 for background.
        labels = [1] * len(boxes)

        if self.transform:
            # Local import: the module's import block does not bring in numpy.
            import numpy as np

            transformed = self.transform(
                image=np.array(image), bboxes=boxes, class_labels=labels
            )
            image = transformed['image']
            boxes = transformed['bboxes']
            labels = transformed['class_labels']

        target = {
            # reshape keeps an image without boxes at shape (0, 4), not (0,).
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
        }
        return image, target
# Augmentation pipelines, keyed by split.
# - Boxes use Pascal VOC format: absolute-pixel [xmin, ymin, xmax, ymax].
# - label_fields names the `class_labels` argument passed alongside `bboxes`, so
#   labels stay aligned with the boxes that survive each augmentation.
# - Pixels are only rescaled from 0-255 to 0-1 here. torchvision detection
#   models apply their own mean/std normalisation internally, so ImageNet
#   statistics must not be applied a second time.
data_transforms = {
    'train': A.Compose([
        A.Resize(width=640, height=640),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=180, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels'])),
    'test': A.Compose([
        A.Resize(width=640, height=640),
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels'])),
}
train.py
编写训练脚本来训练 Faster R-CNN 模型。
import datetime
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from datasets import RebarDataset, data_transforms
# Load the run configuration once at import time. The path is relative to the
# current working directory; the encoding is pinned so parsing does not depend
# on the platform's default locale.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
def collate_fn(batch):
    """Merge dataset samples into one batch.

    Args:
        batch: Sequence of ``(image, target)`` samples; every image must have
            the same shape because the images are stacked.

    Returns:
        ``(images, targets)``: the images stacked along a new leading
        dimension, and the per-sample targets as a plain list (targets are not
        stacked).
    """
    image_list = []
    target_list = []
    for sample in batch:
        image_list.append(sample[0])
        target_list.append(sample[1])
    return torch.stack(image_list), target_list
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Called as ``model(images, targets)`` in train mode; must return
            a dict of loss tensors.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device the images and target tensors are moved to.
        epoch: Epoch index, used only in the log header.
        print_freq: Progress is logged every ``print_freq`` batches.
    """
    model.train()
    metric_logger = MetricLogger(delimiter=" ")
    header = f"Epoch: [{epoch}]"

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for images, targets in batches:
        images = [img.to(device) for img in images]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        # Total loss is the plain sum of every component the model reports.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        metric_logger.update(loss=losses.item(), **loss_dict)
class MetricLogger(object):
    """Collect named scalar metrics and print periodic progress lines.

    Every metric name maps to a ``SmoothedValue`` that is created on first
    update. Meters can also be read as attributes (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup failed. Read `meters` via
        # __dict__ so an instance that has not run __init__ cannot recurse here.
        meters = self.__dict__.get('meters', {})
        if attr in meters:
            return meters[attr]
        raise AttributeError(f"'MetricLogger' object has no attribute '{attr}'")

    def __str__(self):
        # Rendered into the '{meters}' slot of every progress line.
        return self.delimiter.join(
            f"{name}: {meter}" for name, meter in self.meters.items()
        )

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress as they are consumed.

        Args:
            iterable: Sized iterable (``len()`` is required for the ETA).
            print_freq: Print every ``print_freq`` items and on the last item.
            header: Optional prefix for each printed line.
        """
        if not header:
            header = ""
        total_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        # Time spent waiting for the next item, as opposed to the whole step.
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(total_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total_iters - 1:
                eta_seconds = iter_time.global_avg * (total_iters - i)
                # The ETA is a plain string; a timedelta cannot be accumulated
                # in a SmoothedValue.
                fields = {
                    'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total_iters, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary line working for an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(total_iters, 1)))
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Args:
        window_size: Number of most recent values kept for ``median``, ``avg``,
            ``max`` and ``value``.
        fmt: Format string used by ``str()``; it may reference ``median``,
            ``avg``, ``global_avg``, ``max`` and ``value``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Append ``value`` to the window and add it ``n`` times to the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across distributed processes.

        Does nothing unless torch.distributed is available and initialised.
        Warning: does not synchronize the deque!
        """
        if not (dist.is_available() and dist.is_initialized()):
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the current window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the current window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Mean over every value seen so far, weighted by ``n``."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the current window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently added value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
def main():
    """Fine-tune a Faster R-CNN (ResNet50-FPN v2) detector on the rebar dataset.

    Reads dataset roots and the class count from the module-level ``config``,
    trains for 10 epochs and writes ``model_epoch_<n>.pth`` after every epoch.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dataset_train = RebarDataset(root_dir=config['train'], transform=data_transforms['train'])
    dataset_val = RebarDataset(root_dir=config['val'], transform=data_transforms['test'])

    data_loader_train = DataLoader(dataset_train, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
    # NOTE: the validation loader is built but no evaluation loop uses it yet.
    data_loader_val = DataLoader(dataset_val, batch_size=4, shuffle=False, num_workers=4, collate_fn=collate_fn)

    # `weights=` replaces the deprecated `pretrained=True` flag.
    model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
    num_classes = config['nc'] + 1  # background + number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # The ROI head needs a predictor that returns both class scores and box
    # regression deltas; a bare nn.Linear only yields one tensor.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    for epoch in range(10):  # number of epochs
        train_one_epoch(model, optimizer, data_loader_train, device=device, epoch=epoch, print_freq=10)
        # save every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'model_epoch_{epoch}.pth')


if __name__ == "__main__":
    main()
总结
以上代码涵盖了从数据准备到模型训练的所有步骤。你可以根据需要调整配置文件中的参数,并运行训练脚本来开始训练 Faster R-CNN 模型。确保你的数据集目录结构符合预期,并且所有的文件路径都是正确的。
标签:__,训练,voc,模型,torch,train,time,self,def From: https://blog.csdn.net/2401_86822270/article/details/144838861