首页 > 其他分享 >U2-Net 预测函数

U2-Net 预测函数

时间:2023-04-28 13:35:15浏览次数:41  
标签:info 函数 img pred cv2 U2 time Net model

包含单个图片检测以及视频检测

import os
import time

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import subprocess
from torchvision.transforms import transforms

from src import u2net_full
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()

# 获取GPU相关信息
def get_gpu_info():
    try:
        cmd_out = subprocess.check_output('nvidia-smi --query-gpu=name,memory.used,memory.total --format=csv,noheader',
                                          shell=True)
        gpu_info = cmd_out.decode().strip().split('\n')
        gpu_info = [info.split(', ') for info in gpu_info]
        return gpu_info
    except subprocess.CalledProcessError as e:
        print("Error while invoking nvidia-smi: ", e)
        return None

# 打印 GPU 型号及占用情况
def print_gpu_usage():
    gpu_info = get_gpu_info()
    if gpu_info:
        total_memory = 0
        used_memory = 0
        for name, used, total in gpu_info:
            used_memory += int(used.strip().split()[0])
            total_memory += int(total.strip().split()[0])
            memory_usage_percent = round(used_memory / total_memory * 100, 2)
            print(f"GPU: {name.strip()}, Memory used: {used.strip()}, Memory total: {total.strip()}"
                  f", Memory usage: {memory_usage_percent}%")

# 将原图像与分割后的图像混合
def Image_Blend(src, res):
    info = res.shape
    height = info[0]
    width = info[1]
    dst = np.zeros((height, width, 3), np.uint8)
    # 分割后的图换色
    mask = ~(res == [0, 0, 0]).all(axis=2)
    res[mask] = [0, 0, 255]
    dst = res
    # 2.图像混合
    img = cv2.addWeighted(src, 0.8, dst, 0.2, 0, dtype=cv2.CV_8UC3)
    return img

# 读取 Gpu 信息
def gpu_info() -> str:
    info = ''
    for id in range(torch.cuda.device_count()):
        p = torch.cuda.get_device_properties(id)
        info += f'CUDA:{id} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n'
    return info[:-1]

# 单单检测图片
def pic_predict(threshold, device, data_transform, origin_img, model):
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        # init model
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        # 推理
        pred = model(img)
        pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
        pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
        pred_mask = np.where(pred > threshold, 1, 0)
        origin_img = np.array(origin_img, dtype=np.uint8)
        seg_img = origin_img * pred_mask[..., None]

    img_res = Image_Blend(origin_img,seg_img)
    cv2.imwrite("result/pred_result11.png", cv2.cvtColor(img_res.astype(np.uint8), cv2.COLOR_RGB2BGR))

# 视频检测
def video_pre(threshold, device, data_transform, origin_img, model):
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        # init model
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        # 推理
        pred = model(img)

        # 打印GPU占用信息
        print_gpu_usage()

        pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
        pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
        pred_mask = np.where(pred > threshold, 1, 0)
        origin_img = np.array(origin_img, dtype=np.uint8)
        seg_img = origin_img * pred_mask[..., None]

    img_res = Image_Blend(origin_img, seg_img)
    return img_res

def main():
    weights_path = "model_best.pth"
    img_path = "test/video.mp4"
    threshold = 0.5

    # 判断图片路径是否正确
    assert os.path.exists(img_path), f"image file {img_path} dose not exists."
    # 判断 Gpu 是否可用
    if torch.cuda.is_available():
        print(gpu_info())
    # 设置硬件 根据Gpu 是否可用,来选择用GPU 还是 CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(320),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    # 载入模型
    model = u2net_full()
    weights = torch.load(weights_path, map_location='cpu')
    if "model" in weights:
        model.load_state_dict(weights["model"])
    else:
        model.load_state_dict(weights)
    model.to(device)
    model.eval()

    str = os.path.splitext(img_path)[-1]
    if str == ".mp4":
        print("视频读入")
        # 获取视频部分
        cap = cv2.VideoCapture(img_path)
        i = 0
        # 2、获取图像的属性(宽和高),并将其转化为整数
        frame_width = int(cap.get(3))
        frame_height = int(cap.get(4))
        # 3、创建保存视频的对象,设置编码格式、帧率、图像的宽高等
        out = cv2.VideoWriter('result/OutPut2.avi', cv2.VideoWriter_fourcc(*'FFV1'), 30,
                              (frame_width, frame_height))

        while (cap.isOpened()):
            # 4、获取每一帧图像
            ret, frame = cap.read()
            img = frame
            i += 1
            start_time = time.time()  # 开始处理一帧图片的时间
            img_res = video_pre(threshold, device, data_transform, img, model)
            # 5、将每一帧图像写入到输出文件中
            if ret == True:
                out.write(img_res)
            else:
                break
            end_time = time.time()
            cost_time = end_time - start_time
            print("检测第 {} 帧花了 {:.8f}s 。".format(i, cost_time))
        cap.release()
        out.release()
        cv2.destroyAllWindows()
    elif str == ".jpg":
        print("图片读入")
        start_time = time.time()  # 开始处理一帧图片的时间
        origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        pic_predict(threshold, device, data_transform, origin_img, model)
        end_time = time.time()
        cost_time = end_time - start_time
        print("检测一张图片花了 {:.8f}s 。".format(cost_time))
    else:
        print("请重新读入图片或者视频")

if __name__ == '__main__':
    main()

标签:info,函数,img,pred,cv2,U2,time,Net,model
From: https://www.cnblogs.com/femme/p/17361837.html

相关文章

  • Kubernetes 设置命令行的命名空间
    在较新版本的Kubernetes中,kubectl的默认命名空间已经不再是default,而是用户的当前命名空间。这是因为Kubernetes强烈建议您在不同的命名空间中隔离应用程序和资源,因此kubectl默认使用用户当前的命名空间来提高生产力。您可以使用以下命令来查看当前所在的命名空间:arduin......
  • 【Dotnet 工具箱】JIEJIE.NET - 强大的 .NET 代码混淆工具
    你好,这里是Dotnet工具箱,定期分享Dotnet有趣,实用的工具和组件,希望对您有用!JIEJIE.NET-强大的.NET代码混淆工具JIEJIE.NETJIEJIE.NET是一个使用C#开发的开源.NET代码加密工具。很多.NET开发人员担心他们的软件被破解,版权受到侵犯,所以他们使用一些工具来混淆IL......
  • 关于开环传递函数的理解
    困惑许久,直到在知乎上看见回答:https://www.zhihu.com/question/450172398 我以前一直以为开环传函是指把反馈回路断开,输出和输入的比值,但后来才知道是指将中间的环路任意位置断开,环路本身的传递函数,也就是说上图中的开环传函不是A(s),而是A(s)B(s)在此记录,以防遗忘 ......
  • .netcore 使用Quartz定时任务
    这是一个使用.NETCore和Quartz.NET实现定时任务的完整示例。首先确保已经安装了.NETCoreSDK。接下来按照以下步骤创建一个新的控制台应用程序并设置定时任务:创建一个新的.NETCore控制台应用程序:dotnetnewconsole-nQuartzDemocdQuartzDemo通过NuGet添加......
  • 私有继承派生类使用基类的成员函数
    按要求完成下面的程序:1、定义一个Animal类,成员包括:(1)整数类型的私有数据成员m_nWeightBase,表示Animal的体重;(2)整数类型的保护数据成员m_nAgeBase,表示Animal的年龄;(3)公有函数成员set_weight,用指定形参初始化数据成员m_nWeightBase;(4)公有成员函数get_weight,返回数据成员m_nWeightBase的......
  • EXPLORING MODEL-BASED PLANNING WITH POLICY NETWORKS
    发表时间:2020(ICLR2020)文章要点:这篇文章说现在的planning方法都是在动作空间里randomlygenerated,这样很不高效(其实瞎扯了,很多不是随机的方法啊)。作者提出在modelbasedRL里用policy网络来做onlineplanning选择动作,提出了model-basedpolicyplanning(POPLIN)算法。作者提出......
  • ubuntu22.04取消开机输入密码(实测)
    打开终端sudonano/etc/gdm3/custom.conf在文件的[daemon]部分中添加以下两行代码:[daemon]AutomaticLoginEnable=TrueAutomaticLogin=username保存并关闭,注意usename值的是你自己登录的用户名第二步sudonano/etc/pam.d/gdm-password将下面一行注释掉authrequi......
  • Kubernetes 1.3 从入门到进阶 安装篇:minikube
    Kubernetes单机运行环境一直是一个没有得到重视的问题。现在我们有了minikube,一个用go语言开发的可以在本地运行kubernetes的利器,不过目前应该只是支持kubernetes1.3。如果你只有一台机器或者虚拟机又想试验一下Kubernetes的新的功能,或者作kubernetes上开发的本地环境,minikube可能......
  • 函数重载
    函数形参不同:            intadd(intx,inty);            float add(floatx,floaty);形参个数不同:            intadd(intx,inty);            intadd(intx,i......
  • exec函数族
      /*exec函数族加载并运行可执行目标文件fork调用一次,返回两次exec调用一次,从不返回,只有出现错误时,才会返回-1到调用程序fork后相同程序,不同进程;execve后相同进程,不同程序。因此,通常fork一个子进程,然后再使用exec......