首页 > 其他分享 >猫狗检测分类系统

猫狗检测分类系统

时间:2023-04-30 18:23:25浏览次数:28  
标签:detect img 检测 分类 系统 im0 im print path

  • 源码及演示地址:

    演示地址:https://www.wchime.xyz:8083/#/

    后端代码:https://gitee.com/mom925/pet-web-api

    前端代码:https://gitee.com/mom925/uniapp-pets

 

  • 项目说明
    • 项目技术:django+mysql+uwsgi+nginx+uniapp
    • 逻辑:前端用户上传图片,django接收图片传给模型预测,首先做检测,检测是猫还是狗,然后将检测到的框裁剪下来,送进分类算法,对其进行具体得分类,得到结果再将其返回给前端
    • 部署:采用docker容器部署,可见我docker分类文章
    • 演示图:

       

  

  • 主要代码说明  

    主要代码是其中对图片进行预测的算法。

              

    其中检测算法代码结构如上,都在predict_deploy文件夹下

    

    首先是主文件my_detect.py,二进制图片文件进行预测入口。

    

import os
import sys
from pathlib import Path

from tools_detect import draw_box_and_save_img, dataLoad, predict_classify, detect_img_2_classify_img, get_time_uuid

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.general import (non_max_suppression)
from utils.plots import save_one_box

import config as cfg

conf_thres = cfg.conf_thres
iou_thres = cfg.iou_thres

detect_size = cfg.detect_img_size
classify_size = cfg.classify_img_size


def detect_img(img, device, detect_weights='', detect_class=[], save_dir=''):
    # 选择计算设备
    # device = select_device(device)
    # 加载数据
    imgsz = (detect_size, detect_size)
    im0s, im = dataLoad(img, imgsz, device)
    # print(im0)
    # print(im)
    # 加载模型
    model = DetectMultiBackend(detect_weights, device=device)
    stride, names, pt = model.stride, model.names, model.pt
    # print((1, 3, *imgsz))
    model.warmup(imgsz=(1, 3, *imgsz))  # warmup

    pred = model(im, augment=False, visualize=False)
    # print(pred)
    pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)
    # print(pred)
    im0 = im0s.copy()
    # 画框,保存图片
    # ret_bytes= None
    ret_bytes = draw_box_and_save_img(pred, names, detect_class, save_dir, im0, im)
    ret_li = list()
    # print(pred)
    im0_arc = int(im0.shape[0]) * int(im0.shape[1])
    count = 1
    for det in reversed(pred[0]):
        # print(det)
        # print(det)
        # 目标太小跳过
        xyxy_arc = (int(det[2]) - int(det[0])) * (int(det[3]) - int(det[1]))
        # print(xyxy_arc)
        if xyxy_arc / im0_arc < 0.01:
            continue
        # 裁剪图片
        xyxy = det[:4]
        im_crop = save_one_box(xyxy, im0, file=Path('im.jpg'), gain=1.1, pad=10, square=False, BGR=False, save=False)
        # 将裁剪的图片转为分类的大小及tensor类型
        im_crop = detect_img_2_classify_img(im_crop, classify_size, device)

        d = dict()
        # print(det)
        c = int(det[-1])
        label = detect_class[c]
        # 开始做具体分类
        if label == detect_class[0]:
            classify_predict = predict_classify(cfg.cat_weight, im_crop, device)
            classify_label = cfg.cat_class[int(classify_predict)]
        else:
            classify_predict = predict_classify(cfg.dog_weight, im_crop, device)
            classify_label = cfg.dog_class[int(classify_predict)]
        # print(classify_label)
        d['details'] = classify_label
        conf = round(float(det[-2]), 2)
        d['label'] = label+str(count)
        d['conf'] = conf
        ret_li.append(d)
        count += 1

    return ret_li, ret_bytes


def start_predict(img, save_dir=''):
    weights = cfg.detect_weight
    detect_class = cfg.detect_class
    device = cfg.device
    ret_li, ret_bytes = detect_img(img, device, weights, detect_class, save_dir)
    # print(ret_li)
    return ret_li, ret_bytes


if __name__ == '__main__':
    name = get_time_uuid()
    save_dir = f'./save/{name}.jpg'
    # path = r'./test_img/hashiqi20230312_00010.jpg'
    path = r'./test_img/hashiqi20230312_00116.jpg'
    # path = r'./test_img/kejiquan20230312_00046.jpg'
    f = open(path, 'rb')
    img = f.read()
    f.close()
    # print(img)
    # print(type(img))
    img_ret_li, img_bytes = start_predict(img, save_dir=save_dir)
    print(img_ret_li)

 

   我的工具文件tools_detect.py,主要封装了一些工具方法

  

import datetime
import os
import random
import sys
import time
from pathlib import Path

import torch
from PIL import Image
from torch import nn

from utils.augmentations import letterbox

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from utils.general import (cv2,
                           scale_boxes, xyxy2xywh)
from utils.plots import Annotator, colors
import numpy as np

def bytes_to_ndarray(byte_img):
    """
    图片二进制转numpy格式
    """
    image = np.asarray(bytearray(byte_img), dtype="uint8")
    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
    return image


def ndarray_to_bytes(ndarray_img):
    """
    图片numpy格式转二进制
    """
    ret, buf = cv2.imencode(".jpg", ndarray_img)
    img_bin = Image.fromarray(np.uint8(buf)).tobytes()
    # print(type(img_bin))
    return img_bin

def get_time_uuid():
    """
        :return: 20220525140635467912
        :PS :并发较高时尾部随机数增加
    """
    uid = str(datetime.datetime.fromtimestamp(time.time())).replace("-", "").replace(" ", "").replace(":","").replace(".", "") + str(random.randint(100, 999))
    return uid


def dataLoad(img, img_size, device, half=False):
    image = bytes_to_ndarray(img)
    # print(image.shape)
    im = letterbox(image, img_size)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous

    im = torch.from_numpy(im).to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim

    return image, im


def draw_box_and_save_img(pred, names, class_names, save_dir, im0, im):

    save_path = save_dir
    fontpath = "./simsun.ttc"
    for i, det in enumerate(pred):
        annotator = Annotator(im0, line_width=3, example=str(names), font=fontpath, pil=True)
        if len(det):
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            count = 1
            im0_arc = int(im0.shape[0]) * int(im0.shape[1])
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
            base_path = os.path.split(save_path)[0]
            file_name = os.path.split(save_path)[1].split('.')[0]
            txt_path = os.path.join(base_path, 'labels')
            if not os.path.exists(txt_path):
                os.mkdir(txt_path)
            txt_path = os.path.join(txt_path, file_name)
            for *xyxy, conf, cls in reversed(det):
                # 目标太小跳过
                xyxy_arc = (int(xyxy[2]) - int(xyxy[0])) * (int(xyxy[3]) - int(xyxy[1]))
                # print(im0.shape, xyxy, xyxy_arc, im0_arc, xyxy_arc / im0_arc)
                if xyxy_arc / im0_arc < 0.01:
                    continue
                # print(im0.shape, xyxy)
                c = int(cls)  # integer class
                label = f"{class_names[c]}{count} {round(float(conf), 2)}" #  .encode('utf-8')
                # print(xyxy)
                annotator.box_label(xyxy, label, color=colors(c, True))

                im0 = annotator.result()
                count += 1
                # print(im0)

                # print(type(im0))
                # im0 为 numpy.ndarray类型

                # Write to file
                # print('+++++++++++')
                xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                # print(xywh)
                line = (cls, *xywh)  # label format
                with open(f'{txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
    cv2.imwrite(save_path, im0)

    ret_bytes = ndarray_to_bytes(im0)
    return ret_bytes


def predict_classify(model_path, img, device):
    # im = torch.nn.functional.interpolate(img, (160, 160), mode='bilinear', align_corners=True)
    # print(device)
    if torch.cuda.is_available():
        model = torch.load(model_path)
    else:
        model = torch.load(model_path, map_location='cpu')
    # print(help(model))
    model.to(device)
    model.eval()
    predicts = model(img)
    _, preds = torch.max(predicts, 1)
    pred = torch.squeeze(preds)
    # print(pred)
    return pred


def detect_img_2_classify_img(img, classify_size, device):
    im_crop1 = img.copy()
    im_crop1 = np.float32(im_crop1)
    image = cv2.resize(im_crop1, (classify_size, classify_size))
    image = image.transpose((2, 0, 1))
    im = torch.from_numpy(image).unsqueeze(0)
    im_crop = im.to(device)
    return im_crop

 

  我的配置文件config.py,主要是一些可改配置

  

import torch
import os

base_path = r'E:\project\pet-web-api\predict_deploy\weights'

detect_weight = os.path.join(base_path, r'cat_dog_detect/best.pt')
detect_class = ['猫', '狗']

cat_weight = os.path.join(base_path, r'cat_predict/best.pt')
cat_class = ['东方短毛猫', '亚洲豹猫', '加菲猫', '安哥拉猫', '布偶猫', '德文卷毛猫', '折耳猫', '无毛猫', '暹罗猫', '森林猫', '橘猫', '奶牛猫', '狞猫', '狮子猫', '狸花猫', '玳瑁猫', '白猫', '蓝猫', '蓝白猫', '薮猫', '金渐层猫', '阿比西尼亚猫', '黑猫']

dog_weight = os.path.join(base_path, r'dog_predict/best.pt')
dog_class = ['中华田园犬', '博美犬', '吉娃娃', '哈士奇', '喜乐蒂', '巴哥犬', '德牧', '拉布拉多犬', '杜宾犬', '松狮犬', '柯基犬', '柴犬', '比格犬', '比熊', '法国斗牛犬', '秋田犬', '约克夏', '罗威纳犬', '腊肠犬', '萨摩耶', '西高地白梗犬', '贵宾犬', '边境牧羊犬', '金毛犬', '阿拉斯加犬', '雪纳瑞', '马尔济斯犬']

# device = 0
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
conf_thres = 0.5
iou_thres = 0.45

detect_img_size = 416
classify_img_size = 160

 

 

其中预测的结果不是很理想,因为我没有更多的数据。

以上就是项目大概内容,具体可看源码。

 

标签:detect,img,检测,分类,系统,im0,im,print,path
From: https://www.cnblogs.com/moon3496694/p/17365451.html

相关文章

  • 分类预测 | MATLAB实现WOA-CNN鲸鱼算法优化卷积神经网络数据分类预测
    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。......
  • Windows10系统检测不到声音输出设备,声音图标打叉,没声音的解决方法
    问题描述Windows10系统检测不到声音输出设备,声音图标打叉,没声音解决方案:点下轻松访问音频设置选项,再返回就可以了,至于具体是啥原因造成的,也不太清楚,什么逻辑,也不太清楚总之:<hrstyle="border:solid;width:100px;height:1px;"color=#000000size=1">......
  • 使用Gradio搭建AI演示系统
    简介在训练好模型之后,往往需要将其搭建为一个服务,使得他人能够进行调用。最常见的方案,可能就是借助flask、fastapi等配置较为容易web框架进行服务搭建。但是,根据需求,有时不仅会让我们搭建一个基本的服务,还需要进行前端样式配置,比如毕设的演示系统。笔者作为一名算法人员,前端的知......
  • 软考高项(信息系统项目管理师)—— 第 1 章 信息化发展——信息与信息化
    第1章信息化发展——信息与信息化一、概念 信息:information是物质、能量及其属性的标识的集合,是确定性的增加。它以物质介质为载体,传递和反映世界各种事物存在方式、运动状态等的表征。信息不是物质,也不是能力,它以一种普遍形式,表达物质运动规律,在客观世界中大量存在、产生......
  • 改手机串号技术原理能绕过APP检测
    随着智能手机的普及,应用程序的数量和种类也在不断增加。不同的应用程序可能需要不同的硬件和软件支持,导致一些应用程序无法在所有手机上运行。于是,一些用户开始探索绕过应用程序检测的方法,以使用这些应用程序。其中一个方法是修改手机串号。那么,改手机串号技术原理是否能够......
  • 软考高项(信息系统项目管理师)——前言
    前言信息系统项目管理师——第4版一、什么是信息系统项目开发? 。二、什么是信息系统项目管理? 综合运用相关只是、技能、工具和技术在一定的时间、成本、质量等要求下,为实现预定的系统目标而进行的管理计划、设计、开发、实施、运维等方面的活动称为信息系统项目管理。......
  • Win10系统命令行以管理员身份运行的几种方式
    在win10系统中运行许多命令需要使用管理员身份运行,如果直接按下win+R组合键呼出运行,键入cmd打开命令提示符输入命令执行的话会出现无法执行的现象。给大家分享下win10系统中几个以管理员身份运行的方法。方法一:1、在开始菜单上单击鼠标右键,在弹出的菜单中点击【命令提示符(......
  • COMP2006操作系统
    OperatingSystemsSemester-12023COMP2006-OperatingSystemsCURTINUNIVERSITYSchoolofElectricalEngineering,ComputingandMathematicalSciencesDisciplineofComputingCustomerQueueDueDate:4.00pmMonday8thMay2023Thegoalofthisassignmentisto......
  • 关于Linux操作系统OS账号最后一次登录时间的审计
    本文以RedHatEnterpriseLinuxrelease8.1(Ootpa)为例,应该也能适用于7.x版本的如果对操作系统中的账号审计,其中有一个项目可能会比较重要(尤其是对于个人账号),那就是最后一次登录的记录如果需要查看每一个OS账号的最后一次登录记录,可以使用lastlog命令[qq-5201351@localho......
  • 系统分析的一些经验
    做需求分析,我觉得最重要的任务是简化业务流程、规则、逻辑;丰富用户体验;  0.尽量将复杂的用户需求抽像成最简单的业务规则、数据库结构来实现。因为需求是不可能一下子就确定的,假设我们刚开始对核心需求的实现方式增加了一点点的复杂性,比如说多加了一个表,一个藕合字段,那么对于......