前面记录了Detr及其改进Deformable Detr。这一篇记录一下用Detr训练自己的数据集。先看下Detr附录中给出的大体源码,整体非常清晰。
接下来记录大体实现过程
一、数据准备
借助labelme对数据进行标注
然后将标注数据转换成COCO格式,得到以下几个文件
其中JPEGImages
存放所有图片,Visualization
存放可视化结果,annotations.json
保存所有图片的标注信息
二、模型训练
2.1 编写DataLoader
在detr/datasets目录下创建一个custom_data.py
文件用于处理自己的数据。创建一个类,主要包含__getitem__
和__len__
方法。
在新建一个build
方法用于detr构建数据。
再到当前目录下的__init__.py
文件中添加新的数据类型
def build_dataset(image_set, args):
if args.dataset_file == 'coco':
return build_coco(image_set, args)
if args.dataset_file == 'coco_panoptic':
# to avoid making panopticapi required for coco
from .coco_panoptic import build as build_coco_panoptic
return build_coco_panoptic(image_set, args)
if args.dataset_file == 'tooth':
from .custom_data import build as build_tooth
return build_tooth(image_set, args)
2.2 训练
修改配置参数
在mian.py
中新增数据路径参数
修改类别数量,在models/detr.py
中修改类别数,类别数要设置为实际类型+1,加1是添加背景类。
num_classes = 2 if args.dataset_file != 'coco' else 91
加载预训练模型
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
# ==============================================================
# 这一段是修改了的,去除多余的参数,并将load_state_dict设置为strict=False,这样它便会只加载模型结构相同部分的预训练参数
del checkpoint["model"]["class_embed.weight"]
del checkpoint["model"]["class_embed.bias"]
del checkpoint["model"]["query_embed.weight"]
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
开始训练
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --tooth_path /home/jinhai_zhou/data/2D_seg/ --dataset_file tooth --output_dir ./output/path/box_model --resume "./models/detr-r50-e632da11.pth"
我这里检测训练了500次左右开始收敛,分割训练了大概200多次开始接近收敛
如果训练分割模型,建议分两步,先训练检测模型,然后再训练分割头。
三、测试
新增一个predict.py
文件,用于测试
里面主要包含检测和画图两部分内容
- 检测
def detect(im, model, transform, threshold=0.7):
# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)
print("image.shape:", img.shape)
# demo model only support by default images with aspect ratio between 0.5 and 2
# if you want to use images with an aspect ratio outside this range
# rescale your image so that the maximum size is at most 1333 for best results
# assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
# propagate through the model
outputs = model(img)
# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > threshold
# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
return probas[keep], bboxes_scaled
- 绘制结果
def plot_results(pil_img, prob, boxes, output):
CLASSES = [
'N/A', 'teeth'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
plt.figure(figsize=(16,10))
plt.imshow(pil_img)
ax = plt.gca()
for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
fill=False, color=c, linewidth=3))
cl = p.argmax()
text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
ax.text(xmin, ymin, text, fontsize=15,
bbox=dict(facecolor='yellow', alpha=0.5))
plt.axis('off')
plt.savefig(output)
plt.close()
# plt.show()
测试
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion, postprocessors = build_model(args)
model.to(device)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
output_dir = Path(args.output_dir)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=args.strict)
print("load model {} is success!".format(args.resume))
else:
print("Don't load model!")
return
# standard PyTorch mean-std input image normalization
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
if args.img_path is not None:
assert Path(args.img_path).is_file(), "{} not an image path".format(args.img_path)
im = Image.open(img_path)
scores, boxes = detect(im, model, transform=transform)
print("scores: ", scores)
print("boxes: ", boxes)
if args.img_dirs is not None:
assert Path(args.img_dirs).is_dir(), "{} not a dir path".format(args.img_dirs)
img_paths = Path(args.img_dirs).glob("*.jpg")
# print("loads {} images".format(len(list(img_paths))))
for idx, img_path in enumerate(img_paths):
print(img_path)
im = Image.open(img_path)
scores, boxes = detect(im, model, transform=transform)
print(" scores: ", scores)
print("boxes: ", boxes)
out_path = Path(output_dir) / img_path.name
print("out_path: ", out_path)
plot_results(im, scores, boxes, out_path)
标签:定义数据,img,args,boxes,训练,print,path,model,Detr
From: https://www.cnblogs.com/xiaxuexiaoab/p/18632312