首页 > 其他分享 >PyTorch内置模型detection的resnet50使用,使用本地的权重文件

PyTorch内置模型detection的resnet50使用,使用本地的权重文件

时间:2022-10-26 16:38:02浏览次数:59  
标签:plt resnet50 ## cv2 boxes detection PyTorch import myimg

 

 1         ##完全使用本地权重,识别时根据识别准确率来确定是否绘制
 2         import matplotlib.pyplot as plt
 3         import torch
 4         import torchvision.transforms as T
 5         import torchvision
 6         import cv2
 7         from torchvision.io.image import read_image
 8         from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
 9 
10         import warnings
11         warnings.filterwarnings("ignore",category=ResourceWarning)
12         warnings.filterwarnings("ignore",category=DeprecationWarning)
13 
116         img_path = "./jupyterlab/doc/ccc.jpg"        ##骑着自行车的美女,任选
17         img = read_image(img_path)##用pytorch提供的io函数
18 
19         weights_info = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
20         ##读本地权重文件,权重文件到pytorch网站下载
21         model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=None, progress=False, weights_backbone=None)
22         myweights = torch.load('E:/study_2022/working_python/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth')
23         model.load_state_dict(myweights)
24         model.eval()##识别工作模式
25         
26         preprocess = weights_info.transforms()
27         batch = [preprocess(img)]
28         prediction = model(batch)[0]
29         labels = [weights_info.meta["categories"][i] for i in prediction["labels"]]
30         boxes = [i for i in prediction["boxes"]]
31         scores = [i for i in prediction["scores"]]
32 
33         myimg = cv2.imread(img_path)
35         myimg = cv2.cvtColor(myimg, cv2.COLOR_BGR2RGB)
36         for i,score in enumerate(scores):
37             if score.item() < 0.9 : continue##舍弃准确率90%以下的
38             myimg = cv2.addWeighted(myimg, alpha=0.5, src2=myimg, beta=0.5, gamma=1)
39             ##注意:cv2这里只接受整型坐标值
40             start_point = (int(boxes[i][0]), int(boxes[i][1]))
41             end_point = (int(boxes[i][2]), int(boxes[i][3]))
42             cv2.rectangle(myimg, start_point, end_point, color = (255,0,0), thickness=3)
43             cv2.putText(myimg, labels[i], start_point, cv2.FONT_HERSHEY_SIMPLEX, 2, color = (255,0,0), thickness=3)
44         plt.figure(figsize=(7, 5))
45         plt.imshow(myimg)
46         plt.xticks([])
47         plt.yticks([])
48         plt.show()

 

标签:plt,resnet50,##,cv2,boxes,detection,PyTorch,import,myimg
From: https://www.cnblogs.com/ace007/p/16828883.html

相关文章

  • pytorch+Unet图像分割:将图片中的盐体找出来
    向AI转型的程序员都关注了这个号????????????机器学习AI算法工程  公众号:datayx 什么是图像分割问题呢?简单的来讲就是给一张图像,检测是用框出框出物体,而图像分割分出一......
  • 安装pytorch遇到的OS电脑注册表报错 及 解决办法
    今天在单位安装pytorch的时候,遇到了一个OS报错问题。  我安装的是在CPU上的,虽然我安装了anaconda,但是我还是习惯性的选择用pip安装。所以我就直接去pytorch的官网h......
  • 学习pytorch day02
    NumPy数组数组对象是NumPy中最核心的组成部分,这个数组叫做ndarray,是“N-dimensionalarray”的缩写。其中的N是一个数字,指代维度,例如你常常能听到的1-D数组、2-......
  • 学习PyTorch Day01
    PyTorch设计得更科学,无需像TensorFlow那样,要在各种API之间切换,操作更加便捷。PyTorch能够帮你快速实现模型与算法的验证,快速完成深度学习模型部署,提供高并发服务,还......
  • 《PyTorch 深度学习实践 》-刘二大人 第十三讲
    同样的参数,CPU跑15min,GPU2min43s1#根据地名分辨国家2importmath3importtime4importtorch5#绘图6importmatplotlib.pyplotasplt7impo......
  • 《PyTorch 深度学习实践 》-刘二大人 第十二讲
    1'''2inputhello3outputohloluseRNNcell4'''5importtorch67input_size=48hidden_size=49batch_size=110#准备数据11idx2char=['e'......
  • 《PyTorch 深度学习实践 》 刘二大人 第十讲
    课堂练习:1importtorch2fromtorchvisionimporttransforms3fromtorchvisionimportdatasets4fromtorch.utils.dataimportDataLoader5importtorch.......
  • 《PyTorch深度学习实践》-刘二大人 第九讲
    课堂练习,课后作业不想做了……1importtorch2fromtorchvisionimporttransforms3fromtorchvisionimportdatasets4fromtorch.utils.dataimportDataLoa......
  • PyTorch (1) | PyTorch的安装与简介
    本文已收录于Pytorch系列专栏:​​Pytorch入门与实践​​专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下......
  • PyTorch 深度学习实践
    1importnumpyasnp2importtorch3importmatplotlib.pyplotasplt4importos5os.environ['KMP_DUPLICATE_LIB_OK']='True'67#1preparedataset......