
Predicting objects with a pretrained model does not work

Time: 2024-07-31 03:48:22
Tags: python tensorflow machine-learning object-detection yolo

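I want to predict objects by giving an image as input and have my model predict the labels. I trained a model with TensorFlow on an annotated dataset, adding the target object I want to detect to a pretrained model. The code I am using is below; I feed an image of the target object as input and expect to get the predicted output: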

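import cv2
import numpy as np
import tensorflow as tf  # uses the TensorFlow 1.x API (tf.placeholder, tf.Session)
from yolo.net.yolo_tiny_net import YoloTinyNet  # assumed import path (nilboy/tensorflow-yolo layout)


class MultiObjectDetection():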

    def __init__(self, classes_name):
        
        self._classes_name = classes_name
        self._num_classes = len(classes_name)

        self._common_params = {'image_size': 448, 'num_classes': self._num_classes, 
                'batch_size':1}
        self._net_params = {'cell_size': 7, 'boxes_per_cell':2, 'weight_decay': 0.0005}
        self._net = YoloTinyNet(self._common_params, self._net_params, test=True)
        
    def predict_object(self, image):
        predicts = self._net.inference(image)
        return predicts

    def process_predicts(self, resized_img, predicts, thresh=0.2):
        """
        process the predicts of object detection with one image input.
        
        Args:
            resized_img: resized source image.
            predicts: output of the model.
            thresh: thresh of bounding box confidence.
        Return:
            predicts_dict: {"stick": [[x1, y1, x2, y2, scores1], [...]]}.
        """
        cls_num = self._num_classes
        bbx_per_cell = self._net_params["boxes_per_cell"]
        cell_size = self._net_params["cell_size"]
        img_size = self._common_params["image_size"]
        p_classes = predicts[0, :, :, 0:cls_num]
        C = predicts[0, :, :, cls_num:cls_num+bbx_per_cell] # two bounding boxes in one cell.
        coordinate = predicts[0, :, :, cls_num+bbx_per_cell:] # all bounding boxes position.
        
        p_classes = np.reshape(p_classes, (cell_size, cell_size, 1, cls_num))
        C = np.reshape(C, (cell_size, cell_size, bbx_per_cell, 1))
        
        P = C * p_classes  # confidence for every class of every box: (cell_size, cell_size, boxes_per_cell, num_classes) = (7, 7, 2, 21) here.
        
        predicts_dict = {}
        for i in range(cell_size):
            for j in range(cell_size):
                temp_data = np.zeros_like(P, np.float32)
                temp_data[i, j, :, :] = P[i, j, :, :]
                position = np.argmax(temp_data)  # flat index of the best (box, class) pair in cell (i, j); at most one box per cell is kept.
                index = np.unravel_index(position, P.shape)
                
                if P[index] > thresh:
                    class_num = index[-1]
                    coordinate = np.reshape(coordinate, (cell_size, cell_size, bbx_per_cell, 4))  # last axis: [xcenter, ycenter, w, h], decoded below
                    max_coordinate = coordinate[index[0], index[1], index[2], :]
                    
                    xcenter = max_coordinate[0]
                    ycenter = max_coordinate[1]
                    w = max_coordinate[2]
                    h = max_coordinate[3]
                    
                    xcenter = (index[1] + xcenter) * (1.0*img_size /cell_size)
                    ycenter = (index[0] + ycenter) * (1.0*img_size /cell_size)
                    
                    w = w * img_size 
                    h = h * img_size 
                    # clip the box to the image: shape[1] is the width, shape[0] the height
                    xmin = max(xcenter - w / 2.0, 0)
                    ymin = max(ycenter - h / 2.0, 0)
                    xmax = min(xcenter + w / 2.0, resized_img.shape[1])
                    ymax = min(ycenter + h / 2.0, resized_img.shape[0])
                    
                    class_name = self._classes_name[class_num]
                    predicts_dict.setdefault(class_name, [])
                    predicts_dict[class_name].append([int(xmin), int(ymin), int(xmax), int(ymax), P[index]])
                    
        return predicts_dict
    
    def non_max_suppress(self, predicts_dict, threshold=0.5):
        """
        Implement non-maximum suppression on predicted bounding boxes.
        Args:
            predicts_dict: {"stick": [[x1, y1, x2, y2, scores1], [...]]}.
            threshold: IoU threshold.
        Return:
            predicts_dict processed by non-maximum suppression
        """
        for object_name, bbox in predicts_dict.items():
            bbox_array = np.array(bbox, dtype=np.float64)  # np.float no longer exists in recent NumPy
            x1, y1, x2, y2, scores = bbox_array[:,0], bbox_array[:,1], bbox_array[:,2], bbox_array[:,3], bbox_array[:,4]
            areas = (x2-x1+1) * (y2-y1+1)
            order = scores.argsort()[::-1]
            keep = []
            while order.size > 0:
                i = order[0]
                keep.append(i)
                xx1 = np.maximum(x1[i], x1[order[1:]])
                yy1 = np.maximum(y1[i], y1[order[1:]])
                xx2 = np.minimum(x2[i], x2[order[1:]])
                yy2 = np.minimum(y2[i], y2[order[1:]])
                inter = np.maximum(0.0, xx2-xx1+1) * np.maximum(0.0, yy2-yy1+1)
                iou = inter/(areas[i]+areas[order[1:]]-inter)
                indexs = np.where(iou<=threshold)[0]
                order = order[indexs+1]
            bbox = bbox_array[keep]
            predicts_dict[object_name] = bbox.tolist()
        return predicts_dict



class_names = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
                   "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
                   "small_ball"]
modelFile = ('models\\train\\model.ckpt-0')
track_object = "small_ball"
print("object detection and tracking...")

multiObjectDetect = MultiObjectDetection(class_names)
image = tf.placeholder(tf.float32, (1, 448, 448, 3))
object_predicts = multiObjectDetect.predict_object(image)



sess = tf.Session()
saver = tf.train.Saver(multiObjectDetect._net.trainable_collection)


saver.restore(sess, modelFile)

index = 0
while 1:
    
    src_img = cv2.imread("./weirdobject.jpg")
    resized_img = cv2.resize(src_img, (448, 448))
 
    np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
    np_img = np_img.astype(np.float32)
    np_img = np_img / 255.0 * 2 - 1
    np_img = np.reshape(np_img, (1, 448, 448, 3))

 
    np_predict = sess.run(object_predicts, feed_dict={image: np_img})
    predicts_dict = multiObjectDetect.process_predicts(resized_img, np_predict)
    predicts_dict = multiObjectDetect.non_max_suppress(predicts_dict)
    
    print ("predict dict = ", predicts_dict)

The problem with this code is that predicts_dict comes back as:

predict dict =  {'sheep': [[233.0, 92.0, 448.0, -103.0, 5.3531270027160645], [167.0, 509.0, 209.0, 101.0, 4.947688579559326], [0.0, 0.0, 448.0, 431.0, 3.393721580505371]], 'horse': [[374.0, 33.0, 282.0, 448.0, 5.277851581573486], [135.0, 688.0, -33.0, -14.0, 3.5144259929656982], [1.0, 117.0, 112.0, -138.0, 2.656987190246582]], 'bicycle': [[461.0, 781.0, 154.0, -381.0, 5.918102741241455], [70.0, 344.0, 391.0, -138.0, 3.031444787979126], [378.0, 497.0, 46.0, 149.0, 2.7629122734069824], [541.0, 583.0, 69.0, 307.0, 2.7170517444610596], [323.0, 22.0, 336.0, 448.0, 1.608760952949524]], 'bottle': [[390.0, 218.0, -199.0, 448.0, 4.582971096038818], [0.0, 0.0, 448.0, -410.0, 0.9097045063972473]], 'sofa': [[346.0, 102.0, 323.0, -38.0, 2.371835947036743]], 'dog': [[319.0, 254.0, -282.0, 373.0, 4.022889137268066]], 'cat': [[63.0, -195.0, 365.0, -92.0, 3.5134828090667725]], 'person': [[22.0, -122.0, 154.0, 448.0, 3.927537441253662], [350.0, 155.0, -36.0, -445.0, 2.679833173751831], [119.0, 416.0, -43.0, 292.0, 0.9529445171356201], [251.0, 445.0, 225.0, 188.0, 0.9001350402832031]], 'train': [[329.0, 485.0, -24.0, -235.0, 2.7050414085388184], [483.0, 362.0, 237.0, -86.0, 2.555817127227783], [13.0, 365.0, 373.0, 448.0, 0.6229299902915955]], 'small_ball': [[217.0, 737.0, 448.0, -315.0, 1.739920973777771], [117.0, 283.0, 153.0, 122.0, 1.5690066814422607]], 'boat': [[164.0, 805.0, 34.0, -169.0, 4.972668170928955], [0.0, 0.0, 397.0, 69.0, 2.353729486465454], [302.0, 605.0, 15.0, -22.0, 2.0259625911712646]], 'aeroplane': [[470.0, 616.0, -305.0, -37.0, 3.431873321533203], [0.0, 0.0, 448.0, -72.0, 2.836672306060791]], 'bus': [[0.0, 0.0, -101.0, -280.0, 1.2078320980072021]], 'pottedplant': [[620.0, -268.0, -124.0, 418.0, 2.158564805984497], [0.0, 0.0, 448.0, -779.0, 1.6623022556304932]], 'tvmonitor': [[0.0, 0.0, 448.0, 85.0, 3.238999128341675], [240.0, 772.0, 200.0, 91.0, 1.7443398237228394], [546.0, 155.0, 448.0, 448.0, 1.1334525346755981], [107.0, 441.0, 432.0, 219.0, 
0.5971617698669434]], 'chair': [[470.0, -187.0, 106.0, 235.0, 3.8548083305358887], [524.0, 740.0, -103.0, 99.0, 3.636549234390259], [0.0, 0.0, 275.0, -325.0, 3.0997846126556396], [711.0, -231.0, -146.0, 392.0, 2.205275535583496]], 'diningtable': [[138.0, -310.0, 111.0, 448.0, 4.660728931427002], [317.0, -66.0, 313.0, 6.0, 4.535496234893799], [0.0, 0.0, -41.0, 175.0, 1.8571208715438843], [21.0, -92.0, 76.0, 172.0, 1.2035608291625977], [0.0, 0.0, 448.0, -250.0, 1.00322687625885]], 'car': [[312.0, 232.0, 132.0, 309.0, 3.205225706100464], [514.0, -76.0, 218.0, 448.0, 1.4289973974227905], [0.0, 0.0, 448.0, 142.0, 0.7124998569488525]]}

Note that weirdobject.jpg contains a small_ball, which the model was trained to recognize.

I expected predict dict = {'small_ball': [...]}. How can I get that? Where is the prediction output? How do I fix the code?


The code appears to have several problems that could lead to the unexpected predictions:

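  1. Confidence threshold too low: process_predicts uses a confidence threshold of 0.2, so bounding boxes are returned even when the model is not confident in them. Try raising this threshold to filter out low-confidence predictions. Keep in mind, though, that each score is a box confidence multiplied by a class probability, so a well-trained model should produce values roughly between 0 and 1. Many scores in the output above are far above 1 (for example 5.35 and 4.97), so raising the threshold to 0.5 removes very few boxes, and values like these suggest that the raw network output is not yet meaningful.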

```python
predicts_dict = multiObjectDetect.process_predicts(resized_img, np_predict, thresh=0.5)  # raise the threshold
```
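A quick way to check whether a higher threshold can help is to count how many boxes survive it and how many scores exceed 1. The helper below, `score_report`, is a hypothetical name used only for this sketch; it assumes the `predicts_dict` format returned by `process_predicts` (`[xmin, ymin, xmax, ymax, score]` per box). Applied to four boxes copied from the output in the question, it shows that a threshold of 0.5 removes none of them:

```python
def score_report(predicts_dict, thresh):
    """Summarize the scores in a process_predicts-style dict.

    Each box is [xmin, ymin, xmax, ymax, score]. Returns the number of boxes,
    how many would survive `thresh`, and how many have a score above 1.
    """
    scores = [box[4] for boxes in predicts_dict.values() for box in boxes]
    return {
        "total": len(scores),
        "kept": sum(1 for s in scores if s > thresh),
        "above_one": sum(1 for s in scores if s > 1.0),
    }


# Four boxes copied from the output shown in the question.
sample = {
    "sheep": [[233.0, 92.0, 448.0, -103.0, 5.3531270027160645],
              [0.0, 0.0, 448.0, 431.0, 3.393721580505371]],
    "car": [[0.0, 0.0, 448.0, 142.0, 0.7124998569488525]],
    "train": [[13.0, 365.0, 373.0, 448.0, 0.6229299902915955]],
}
print(score_report(sample, thresh=0.5))  # {'total': 4, 'kept': 4, 'above_one': 2}
```

If `above_one` is a large fraction of `total` on the full output, tuning the threshold is unlikely to be the real fix.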

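  2. Non-maximum suppression (NMS): NMS is essential for removing overlapping bounding boxes that may detect the same object, but it is not missing here: the code already assigns its result back with the line below. Note also that NMS only removes overlapping boxes within each class; it never drops a class, so on its own it cannot reduce the result to {'small_ball': [...]}. It also assumes x2 ≥ x1 and y2 ≥ y1, while many boxes in the output above are inverted (for example [233.0, 92.0, 448.0, -103.0, ...]), which makes their areas and IoU values meaningless.

```python
predicts_dict = multiObjectDetect.non_max_suppress(predicts_dict)  # the result is already used
```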
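For reference, here is a minimal per-class NMS sketch for the same `[x1, y1, x2, y2, score]` format. It follows the same greedy scheme and `+1` pixel convention as the question's `non_max_suppress`, but it is not a drop-in copy: as an added assumption, it first discards inverted boxes (x2 < x1 or y2 < y1), whose areas would otherwise be negative. `nms_per_class` is a hypothetical name. The example shows that overlapping boxes within one class are merged while other classes are kept; `bottle` disappears only because its single box is inverted.

```python
import numpy as np


def nms_per_class(predicts_dict, iou_threshold=0.5):
    """Greedy non-maximum suppression run separately for each class.

    Boxes are [x1, y1, x2, y2, score] in pixels (inclusive, hence the +1 terms).
    Inverted boxes (x2 < x1 or y2 < y1) are dropped before suppression.
    """
    result = {}
    for name, boxes in predicts_dict.items():
        arr = np.asarray(boxes, dtype=np.float64).reshape(-1, 5)
        arr = arr[(arr[:, 2] >= arr[:, 0]) & (arr[:, 3] >= arr[:, 1])]
        if len(arr) == 0:
            continue  # no valid box left for this class
        x1, y1, x2, y2, scores = arr.T
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]  # highest score first
        keep = []
        while order.size > 0:
            i, rest = order[0], order[1:]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])
            inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
            iou = inter / (areas[i] + areas[rest] - inter)
            order = rest[iou <= iou_threshold]  # discard boxes overlapping the kept one
        result[name] = arr[keep].tolist()
    return result


boxes = {
    "small_ball": [[10, 10, 50, 50, 0.9], [12, 12, 52, 52, 0.8], [200, 200, 240, 240, 0.7]],
    "bottle": [[390, 218, -199, 448, 0.6]],  # inverted box (x2 < x1)
    "car": [[0, 0, 100, 100, 0.5]],
}
result = nms_per_class(boxes)
print({name: len(b) for name, b in result.items()})  # {'small_ball': 2, 'car': 1}
```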

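  3. Insufficient training or problematic data: if the problem persists after adjusting the confidence threshold, your model may not have been trained enough to detect small_ball accurately. Note that the restored checkpoint is model.ckpt-0. With tf.train.Saver, the number after the dash is normally the global step passed to save(), so this looks like a checkpoint written at step 0, before any training; restoring the most recent checkpoint instead (for example via tf.train.latest_checkpoint('models/train')) is worth trying first. Beyond that, the cause may be too little training data, poor-quality training data, or the model architecture itself. Consider training with more data, especially images containing small_ball, fine-tuning a different architecture, or adjusting hyperparameters.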

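  4. Input image preprocessing: make sure the input image is preprocessed exactly as it was when the model was trained. This can include resizing, pixel scaling and normalization; in the question's loop that means BGR to RGB conversion, resizing to 448×448 and scaling to [-1, 1]. A difference in any preprocessing step can degrade the model's performance.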

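  5. Errors in the class names: double-check the names in the class_names list and make sure they exactly match those in the training data, including case and any whitespace. The order matters too: process_predicts turns the predicted class index into a name by position (self._classes_name[class_num]), so the list must be in the same order as during training.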
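Because the predicted class index is mapped to a name by position, a simple comparison of the inference list against the training list catches both spelling and ordering problems. A minimal sketch; `compare_class_lists` is a hypothetical helper, and `training_names` stands in for whatever label list was used when the annotations were converted for training (an assumption, since the question does not show it):

```python
def compare_class_lists(inference_names, training_names):
    """Return human-readable differences between two class-name lists.

    Position matters: index k of the network output means training_names[k].
    """
    problems = []
    if len(inference_names) != len(training_names):
        problems.append(
            f"length differs: {len(inference_names)} at inference vs {len(training_names)} in training"
        )
    for k, (a, b) in enumerate(zip(inference_names, training_names)):
        if a != b:
            near = a.strip().lower() == b.strip().lower()
            hint = " (differs only in case or whitespace)" if near else ""
            problems.append(f"index {k}: {a!r} vs {b!r}{hint}")
    return problems


# Hypothetical example: the training list spelled the new class differently.
inference = ["aeroplane", "bicycle", "small_ball"]
training = ["aeroplane", "bicycle", "Small_ball "]
for line in compare_class_lists(inference, training):
    print(line)  # index 2: 'small_ball' vs 'Small_ball ' (differs only in case or whitespace)
```

An empty result means the two lists agree in both spelling and order.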

By addressing these potential issues, you should be able to improve the object detection model and get more accurate predictions.
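As for getting output shaped like the expected {'small_ball': [...]}: the question defines track_object but never uses it, and the simplest option is to keep only that key, e.g. {track_object: predicts_dict.get(track_object, [])}. Alternatively, only the tracked class can be decoded straight from the network output. The sketch below assumes the output layout that process_predicts already relies on (class probabilities, then box confidences, then four coordinates per box); `decode_one_class` is a hypothetical helper, and unlike process_predicts it keeps every box of that class above the threshold instead of only the best (box, class) pair per cell. Neither option helps if the restored weights are untrained.

```python
import numpy as np


def decode_one_class(predicts, class_idx, num_classes, cell_size=7,
                     boxes_per_cell=2, img_size=448, thresh=0.2):
    """Decode [xmin, ymin, xmax, ymax, score] boxes for a single class.

    Assumed layout of predicts[0] (same slicing as process_predicts):
      [..., :num_classes]               class probabilities
      [..., num_classes:num_classes+B]  box confidences
      [..., num_classes+B:]             B * (xcenter, ycenter, w, h)
    """
    out = predicts[0]
    p_class = out[:, :, class_idx]                                # (S, S)
    conf = out[:, :, num_classes:num_classes + boxes_per_cell]    # (S, S, B)
    coords = out[:, :, num_classes + boxes_per_cell:].reshape(
        cell_size, cell_size, boxes_per_cell, 4)                  # (S, S, B, 4)

    scores = conf * p_class[:, :, np.newaxis]                     # (S, S, B)
    cell_px = img_size / cell_size
    boxes = []
    for row, col, b in zip(*np.nonzero(scores > thresh)):
        x_off, y_off, w, h = coords[row, col, b]
        xc = (col + x_off) * cell_px  # the column index gives x
        yc = (row + y_off) * cell_px  # the row index gives y
        w, h = w * img_size, h * img_size
        boxes.append([
            float(max(xc - w / 2, 0.0)), float(max(yc - h / 2, 0.0)),
            float(min(xc + w / 2, img_size)), float(min(yc + h / 2, img_size)),
            float(scores[row, col, b]),
        ])
    return boxes


# Synthetic output with 2 classes: one confident class-1 box in cell (row 3, col 4).
S, B, C = 7, 2, 2
fake = np.zeros((1, S, S, C + B + 4 * B))
fake[0, 3, 4, 1] = 0.9                        # class-1 probability
fake[0, 3, 4, C] = 0.8                        # confidence of box 0
fake[0, 3, 4, C + B:C + B + 4] = [0.5, 0.5, 0.25, 0.25]
print(decode_one_class(fake, class_idx=1, num_classes=C))  # one box, roughly [232.0, 168.0, 344.0, 280.0, 0.72]
```

With the names from the question, the call would look like {track_object: decode_one_class(np_predict, class_names.index(track_object), len(class_names))}, and the result can then go through non-maximum suppression as before.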

From: 78809007
