首页 > 其他分享 >用Detr训练自定义数据

用Detr训练自定义数据

时间:2024-12-26 22:19:36浏览次数:3  
标签:定义数据 img args boxes 训练 print path model Detr

前面记录了Detr及其改进Deformable Detr这一篇记录一下用Detr训练自己的数据集。先看下Detr附录中给出的大体源码,整体非常清晰。

接下来记录大体实现过程

一、数据准备

借助labelme对数据进行标注

然后将标注数据转换成COCO格式,得到以下几个文件

其中JPEGImages存放所有图片,Visualization存放可视化结果,annotations.json保存所有图片的标注信息

二、模型训练

2.1 编写DataLoader

在detr/datasets目录下创建一个custom_data.py文件用于处理自己的数据。创建一个类,主要包含__getitem____len__方法。

在新建一个build方法用于detr构建数据。

再到当前目录下的__init__.py文件中添加新的数据类型

def build_dataset(image_set, args):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args)
    if args.dataset_file == 'coco_panoptic':
        # to avoid making panopticapi required for coco
        from .coco_panoptic import build as build_coco_panoptic
        return build_coco_panoptic(image_set, args)
    if args.dataset_file == 'tooth':
        from .custom_data import build as build_tooth  
        return build_tooth(image_set, args)

2.2 训练

修改配置参数
mian.py中新增数据路径参数

修改类别数量,在models/detr.py中修改类别数,类别数要设置为实际类型+1,加1是添加背景类。

num_classes = 2 if args.dataset_file != 'coco' else 91 

加载预训练模型

if args.resume:
    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
        # ==============================================================
        # 这一段是修改了的,去除多余的参数,并将load_state_dict设置为strict=False,这样它便会只加载模型结构相同部分的预训练参数
        del checkpoint["model"]["class_embed.weight"]
        del checkpoint["model"]["class_embed.bias"]
        del checkpoint["model"]["query_embed.weight"]
    model_without_ddp.load_state_dict(checkpoint['model'], strict=False)

开始训练

python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --tooth_path /home/jinhai_zhou/data/2D_seg/ --dataset_file tooth --output_dir ./output/path/box_model --resume "./models/detr-r50-e632da11.pth" 

我这里检测训练了500次左右开始收敛,分割训练了大概200多次开始接近收敛

如果训练分割模型,建议分两步,先训练检测模型,然后再训练分割头。

三、测试

新增一个predict.py文件,用于测试
里面主要包含检测和画图两部分内容

  • 检测
def detect(im, model, transform, threshold=0.7):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)
    print("image.shape:", img.shape)
    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    # assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # propagate through the model
    outputs = model(img)

    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled
  • 绘制结果
def plot_results(pil_img, prob, boxes, output):
    CLASSES = [
         'N/A', 'teeth'
    ]

    # colors for visualization
    COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
            [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.savefig(output)
    plt.close()
    # plt.show()

测试

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu') 
        model.load_state_dict(checkpoint['model'], strict=args.strict)
        print("load model {} is success!".format(args.resume))
    else:
        print("Don't load model!")
        return
    
    # standard PyTorch mean-std input image normalization
    transform = T.Compose([
        T.Resize(800),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    if args.img_path is not None:
        assert Path(args.img_path).is_file(), "{} not an image path".format(args.img_path)
        im = Image.open(img_path)
        scores, boxes = detect(im, model, transform=transform)
        print("scores: ", scores)
        print("boxes: ", boxes)

    if args.img_dirs is not None:
        assert Path(args.img_dirs).is_dir(), "{} not a dir path".format(args.img_dirs)
        img_paths = Path(args.img_dirs).glob("*.jpg")
        # print("loads {} images".format(len(list(img_paths))))
        for idx, img_path in enumerate(img_paths):
            print(img_path)
            im = Image.open(img_path)
            scores, boxes = detect(im, model, transform=transform)
            print(" scores: ", scores)
            print("boxes: ", boxes)
            out_path = Path(output_dir) / img_path.name
            print("out_path: ", out_path)
            plot_results(im, scores, boxes, out_path) 

标签:定义数据,img,args,boxes,训练,print,path,model,Detr
From: https://www.cnblogs.com/xiaxuexiaoab/p/18632312

相关文章

  • Manus手套动作捕捉AI训练灵巧手
    人工智能(AI)和机器人技术的融合日益紧密,使用真实动作数据+AI扩容训练机器人的方式正在被用于开发更富表现力的机器人。Manus手套凭借精准的动作捕捉技术和导出数据的强大兼容性,在灵巧手的研发和应用中发挥了重要作用。手部动作精确捕捉Manus手套在AI训练灵巧手方面的优势显......
  • 【神经网络训练过程可视化】
    一、直方图可视化数据分布1.知识介绍在PyTorch模型的每一层注册一个forwardhook,从而能够捕获每层的输出简单列表存储形式(只能顺序查看每层输出,下文会有改进版用字典将层名字和层输出值对应)activations=[]defhook_fn(module,input,output):activations.appe......
  • PyCharm专项训练4 最小生成树算法
    一、实验目的:本文的实验目的是通过编程实践,掌握并应用Prime算法和Kruskal算法来求解给定图的最小生成树问题。二、实验内容:数据准备:使用networkx库创建一个图G,并添加指定的节点和带权重的边。算法实现:实现Kruskal算法,通过构建最小生成树T,并找出构成最小生成树的边......
  • PyCharm专项训练5 最短路径算法
    一、实验目的    本文的实验目的是通过编程实践,掌握并应用Dijkstra(迪杰斯特拉)算法和Floyd(弗洛伊德)算法来解决图论中的最短路径问题。二、实验内容数据准备:使用邻接表的形式定义两个图graph_dijkstra和graph_floyd,图中包含节点以及节点之间的边和对应的权重。算......
  • 【数据集】【YOLO】【目标检测】灭火器识别数据集 3261 张,YOLO灭火器识别算法实战训练
     一、数据集介绍【数据集】灭火器识别数据集3261张,目标检测,包含YOLO/VOC格式标注。数据集中包含1种分类:names:['extinguisher'],表示"灭火器"。数据集图片来自国内外网站、网络爬虫、监控采集等;可用于监控和移动设备灭火器识别。检测场景为工业园区、办公大楼、居民楼......
  • 模型训练中性能指标
    在机器学习和深度学习的模型训练过程中,评估模型性能是至关重要的一环。不同的任务和应用场景可能会采用不同的评估指标,常见的包括准确率(Accuracy)、精确率或精度(Precision)、召回率(Recall)和均值平均精度(mAP)。本文将介绍这些评估指标的定义、计算方法及其在实际中的应用。1.Accur......
  • 【NLP】关于大模型训练常见概念讲解
    随着LLM学界和工业界日新月异的发展,不仅预训练所用的算力和数据正在疯狂内卷,后训练(post-training)的对齐和微调等方法也在不断更新。下面笔者根据资料整理一些关于大模型训练常见概念解释。前排提示,文末有大模型AGI-CSDN独家资料包哦!1Pre-training(预训练)预训练是指在模型......
  • Flink 训练项目教程
    Flink训练项目教程Flink训练项目教程flink-training-exercises项目地址:https://gitcode.com/gh_mirrors/fli/flink-training-exercises项目的目录结构及介绍Flink训练项目的目录结构如下:flink-training-exercises/├──build.gradle├──gradlew├──gradlew.ba......
  • RT-DETR学习笔记(1)
    视频教程:RT-DETR|2、backbone_哔哩哔哩_bilibili 一、图像预处理经过图像预处理、图像增强后的图片尺寸都为640*640超参数multi_scale设置了不同的尺寸sz是经过对multi-scale随机选择得到的一个尺寸,这里假设是576则640*640图像会通过双线性插值(interpolate)方法resize到576*......
  • springboot毕设少儿体能训练在线课程预约管理系统程序+论文+部署
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、研究背景随着社会发展和人们健康意识的提高,少儿体能训练受到越来越多的关注。现代社会中,少儿面临着各种电子设备的诱惑,户外活动和体能锻炼相对不足。同时......