首页 > 其他分享 >mmdetection测试模型指标(输出每个类别的误检率、漏检率、正确率)

mmdetection测试模型指标(输出每个类别的误检率、漏检率、正确率)

时间:2022-10-18 15:38:00 浏览次数:44
标签:误检率 num img labels 漏检 正确率 boxes gt miss

#!/usr/bin/env python
# -*- coding:utf-8-*-
# file: model_test1.py
# @author: jory.d
# @contact: 
# @time: 2022/01/07 22:41
# @desc:  模型测试, 查看误检和漏检


"""
python tests/model_test1_qr_code.py

"""

import os
import os.path as osp
import copy
import json

import cv2
import numpy as np
import torch
from mmcv import Config
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import (load_checkpoint)

from mmdet.core.bbox.iou_calculators import bbox_overlaps
from mmdet.datasets import build_dataset, build_dataloader
from mmdet.models import build_detector

# ---- evaluation settings --------------------------------------------------
# Number of samples per batch passed to the dataloader.
BATCH_SIZE = 1
# Minimum detection score for a box to be taken into account.
THRESHOLD = 0.6
# Minimum IoU between a detection and a GT box to count as a true positive.
TP_ASSIGN_IOU = 0.65
# Truthy -> annotated images are written below SAVE_DIR.
SAVE_IMG_FLAG = 1

# ---- paths ----------------------------------------------------------------
CONFIG = "configs/xx/cfg_2_ssd_128x128.py"
WORK_DIR = "train_result/20221011_cfg_2_ssd_128x128-11"
EPOCH = 1340
# Checkpoint and output locations are derived from WORK_DIR / EPOCH.
CKPT = "{}/epoch_{}.pth".format(WORK_DIR, EPOCH)
SAVE_DIR = "{}/test_2".format(WORK_DIR)

def build_model():
    """Build the test dataloader and the detector described by ``CONFIG``.

    The checkpoint ``CKPT`` is loaded on CPU, conv+bn layers are fused and the
    model is switched to eval mode. The test-time ``score_thr`` and NMS
    ``iou_threshold`` are overridden (0.2 / 0.5) before the model is built.

    Returns:
        tuple: ``(data_loader, model)``.

    Raises:
        FileNotFoundError: if ``CKPT`` is not an existing file.
        ValueError: if the config defines no ``test_cfg`` at all.
    """
    cfg = Config.fromfile(CONFIG)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # isfile() is already False for a non-existing path.
    if not osp.isfile(CKPT):
        raise FileNotFoundError(f'{CKPT} is not existed.')

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=BATCH_SIZE,
        workers_per_gpu=0,
        dist=False,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    # test_cfg may live inside the model dict or at the top level of the
    # config; resolve it once and make sure the model dict carries it, so the
    # overrides below are the settings the detector is actually built with.
    test_cfg = cfg.model.get('test_cfg')
    if test_cfg is None:
        test_cfg = cfg.get('test_cfg')
    if test_cfg is None:
        raise ValueError(f'no test_cfg found in {CONFIG}')
    cfg.model.test_cfg = test_cfg
    cfg.model.test_cfg['score_thr'] = 0.2
    cfg.model.test_cfg['nms']['iou_threshold'] = 0.5

    model = build_detector(cfg.model)
    checkpoint = load_checkpoint(model, CKPT, map_location='cpu')
    model = fuse_conv_bn(model)
    model.eval()
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    return data_loader, model


def dummy_img(img_path="./14.png"):
    """Load one image from disk and preprocess it into a network input.

    The image is resized to 128x128, converted BGR -> RGB, normalised with
    mean 128 / std 128 per channel and laid out as ``[1, C, H, W]``.

    Args:
        img_path (str): image file to read. Defaults to ``"./14.png"``.

    Returns:
        np.ndarray: the preprocessed image with a leading batch axis.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    input_w, input_h, input_c = 128, 128, 3
    mean = [128., ] * input_c
    std = [128., ] * input_c

    img0 = cv2.imread(img_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if img0 is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')

    # ========================== preprocess img =============================
    resized = cv2.resize(img0, (input_w, input_h))
    ori_image = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    if 1 == input_c:
        # single channel: [h,w] -> [1,h,w]
        ori_image = cv2.cvtColor(ori_image, cv2.COLOR_RGB2GRAY)
        input_image = ori_image.astype(np.float32) - np.asarray(mean)
        input_image /= np.asarray(std)
        input_image = input_image[np.newaxis, ...]
    else:
        # [h,w,c] -> [c,h,w]
        input_image = ori_image.astype(np.float32) - np.asarray([[mean]])
        input_image /= np.asarray([[std]])
        input_image = np.transpose(input_image, [2, 0, 1])

    # add the batch axis: [c,h,w] -> [1,c,h,w]
    return input_image[np.newaxis, :, :, :]


def _match_class_detections(det_bboxes_src, cls_gt_boxes, iou_thr):
    """Split one class's detections into TP / FP and list the missed GT boxes.

    A detection is a true positive when its IoU with any GT box of the same
    class reaches ``iou_thr`` (the first such GT box is taken).
    NOTE(review): several detections may be assigned to the same GT box, as in
    the original implementation; confirm whether one-to-one matching is wanted.

    Args:
        det_bboxes_src (np.ndarray): score-filtered detections, one row per
            box as ``[x1, y1, x2, y2, score]``.
        cls_gt_boxes (torch.Tensor): GT boxes of this class, one row per box.
        iou_thr (float): minimum IoU for a true positive.

    Returns:
        tuple: ``(tp_boxes, fp_boxes, missed_gt_inds)`` - two lists of
        detection rows and a sorted list of indices into ``cls_gt_boxes``.
    """
    det_num = det_bboxes_src.shape[0]
    gt_num = cls_gt_boxes.shape[0]
    if det_num == 0 or gt_num == 0:
        # Nothing to match: every detection is false, every GT box is missed.
        return [], list(det_bboxes_src), list(range(gt_num))

    ious = bbox_overlaps(
        torch.from_numpy(det_bboxes_src[:, :4]), cls_gt_boxes, mode='iou')  # [n,m]
    tp_boxes, fp_boxes = [], []
    matched_gt = set()
    for det_idx, det_ious in enumerate(ious):
        hit = None
        for gt_idx, iou in enumerate(det_ious):
            if iou >= iou_thr:
                hit = gt_idx
                break
        if hit is None:
            fp_boxes.append(det_bboxes_src[det_idx])
        else:
            matched_gt.add(hit)
            tp_boxes.append(det_bboxes_src[det_idx])

    missed_gt_inds = sorted(set(range(gt_num)) - matched_gt)
    return tp_boxes, fp_boxes, missed_gt_inds


def _draw_scored_boxes(img, labels, boxes, color):
    """Draw detections (box + ``cls:<id>(<score>)`` caption) onto ``img``."""
    h = img.shape[0]
    for cls, box in zip(labels, boxes):
        x1, y1, x2, y2 = list(map(int, box[:4]))
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
        cv2.putText(img, "cls:{}({:.2f})".format(int(cls), box[-1]),
                    (x1, min(h - 10, y2 + 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, color, 1)


def _tensor_to_bgr(img_tensor, norm_cfg):
    """Undo normalisation of a ``[c,h,w]`` tensor and return a BGR uint8 image."""
    img = img_tensor.permute(1, 2, 0).cpu().numpy()  # [h,w,c]
    img = img * np.asarray(norm_cfg['std'], dtype=np.float32) \
        + np.asarray(norm_cfg['mean'], dtype=np.float32)
    # Clip before the cast so out-of-range values do not wrap around.
    img = np.ascontiguousarray(np.clip(img, 0, 255).astype(np.uint8))
    if img.shape[-1] == 1:
        return cv2.cvtColor(np.squeeze(img, axis=-1), cv2.COLOR_GRAY2BGR)
    # Swap channels exactly once, and only when the pipeline produced RGB.
    # assumes img_norm_cfg carries mmdet's `to_rgb` flag - TODO confirm
    if norm_cfg.get('to_rgb', True):
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return img


def _build_report(num_class, total_num, ignore_num, miss_num, error_num, true_num):
    """Assemble the per-class report; ratios are 0.0 for classes without GT."""

    def _ratio(count, total):
        return count / total if total else 0.0

    return {
        i: {
            "total": total_num[i],
            "ignore": ignore_num,
            "miss": miss_num[i],
            "miss_p": _ratio(miss_num[i], total_num[i]),
            "error": error_num[i],
            "error_p": _ratio(error_num[i], total_num[i]),
            "true": true_num[i],
            "true_p": _ratio(true_num[i], total_num[i]),
        }
        for i in range(num_class)
    }


@torch.no_grad()
def model_test():
    """Run the detector over the test set and report per-class statistics.

    For every class the number of missed GT boxes, false detections and true
    detections is accumulated (detections below ``THRESHOLD`` are dropped,
    a match needs IoU >= ``TP_ASSIGN_IOU``). When ``SAVE_IMG_FLAG`` is set,
    annotated images are written to ``SAVE_DIR/{miss,error,true}``. The final
    report is printed and dumped to ``SAVE_DIR/report.json``.
    """
    _dataloader, _model = build_model()
    _num_class = len(_model.CLASSES)
    miss_num = {i: 0 for i in range(_num_class)}
    error_num, true_num, total_num = dict(miss_num), dict(miss_num), dict(miss_num)
    ignore_num = 0
    for data in _dataloader:
        img_metas = data['img_metas'].data[0]
        img_tensors = data['img'].data[0]
        assert len(img_metas) == img_tensors.shape[0]

        results = _model(return_loss=False, rescale=False,
                         img=[img_tensors], img_metas=[img_metas])
        for jj, meta in enumerate(img_metas):
            fname = osp.basename(meta['filename'])
            gt_boxes = meta['gt_bboxes'].data
            gt_boxes_ignore = meta['gt_bboxes_ignore'].data
            gt_labels = meta['gt_labels'].data
            det_result = results[jj]  # [class_0_res, class_1_res, ...]

            for gt_label in gt_labels.numpy():
                total_num[int(gt_label)] += 1
            ignore_num += int(gt_boxes_ignore.shape[0])

            # Flat, per-image lists: one label and one box per entry.
            miss_labels, miss_boxes = [], []
            error_labels, error_boxes = [], []
            true_labels, true_boxes = [], []

            for class_idx in range(_num_class):
                if class_idx < len(det_result):
                    det_bboxes = np.asarray(det_result[class_idx])  # [n,5]
                else:
                    # No result slot for this class: treat as "no detections".
                    det_bboxes = np.zeros((0, 5), dtype=np.float32)
                det_bboxes_src = det_bboxes[det_bboxes[:, -1] >= THRESHOLD]
                cls_gt_boxes = gt_boxes[gt_labels == class_idx]

                tp_boxes, fp_boxes, missed = _match_class_detections(
                    det_bboxes_src, cls_gt_boxes, TP_ASSIGN_IOU)

                # Counters advance by the number of boxes, in every case.
                true_num[class_idx] += len(tp_boxes)
                error_num[class_idx] += len(fp_boxes)
                miss_num[class_idx] += len(missed)

                true_labels.extend([class_idx] * len(tp_boxes))
                true_boxes.extend(tp_boxes)
                error_labels.extend([class_idx] * len(fp_boxes))
                error_boxes.extend(fp_boxes)
                if missed:
                    miss_labels.extend([class_idx] * len(missed))
                    miss_boxes.extend(cls_gt_boxes[missed].numpy())

            print(f'miss_num: {miss_num}', f'error_num: {error_num}', f'true_num: {true_num}',
                  f'total_num: {total_num}')
            _error_flag = len(error_labels) > 0
            _miss_flag = len(miss_labels) > 0
            _true_flag = len(true_labels) > 0

            if not SAVE_IMG_FLAG:
                continue

            img = _tensor_to_bgr(img_tensors[jj], meta['img_norm_cfg'])

            # ground truth: green
            for cls, box in zip(gt_labels, gt_boxes):
                x1, y1, x2, y2 = list(map(int, box))
                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 1)
                cv2.putText(img, f"cls:{int(cls)}", (x1, max(5, y1 - 5)),
                            cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 255, 0), 1)
            # true detections: blue
            _draw_scored_boxes(img, true_labels, true_boxes, (255, 0, 0))
            # missed ground truth: thick cyan
            for box in miss_boxes:
                x1, y1, x2, y2 = list(map(int, box))
                cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 0), 3)
            # false detections: red
            _draw_scored_boxes(img, error_labels, error_boxes, (0, 0, 255))

            img = cv2.resize(img, dsize=(0, 0), fx=2, fy=2)
            # The same annotated image goes into every category it belongs to.
            for flag, sub_dir in ((_miss_flag, 'miss'), (_error_flag, 'error'),
                                  (_true_flag, 'true')):
                if flag:
                    _savefilepath = f'{SAVE_DIR}/{sub_dir}/{fname}'
                    os.makedirs(osp.dirname(_savefilepath), exist_ok=True)
                    cv2.imwrite(_savefilepath, img)

    #####################################################################
    _report_json_path = f'{SAVE_DIR}/report.json'
    _data = _build_report(_num_class, total_num, ignore_num,
                          miss_num, error_num, true_num)

    from pprint import pprint
    pprint(_data)
    os.makedirs(osp.dirname(_report_json_path), exist_ok=True)
    with open(_report_json_path, 'w', encoding='utf-8') as f:
        json.dump(_data, f, indent=4)

    print('done.')

if __name__ == '__main__':
    model_test()

 

标签:误检率,num,img,labels,漏检,正确率,boxes,gt,miss
From: https://www.cnblogs.com/dxscode/p/16802672.html

相关文章