首页 > 其他分享 >feature map-opencv实现特征热力图可视化

feature map-opencv实现特征热力图可视化

时间:2023-03-25 14:11:19浏览次数:67  
标签:map plt img format self feature opencv savepath model

上代码

绿色底纹 部分 代表 单个通道 热力图生成 代码;

import cv2
import time
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

savepath = r'features_heat'
if not os.path.exists(savepath):
    os.mkdir(savepath)

def draw_features(width, height, x, savename):
    fig = plt.figure(figsize=(16, 16))
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
    for i in range(width * height):
        plt.subplot(height, width, i + 1)
        plt.axis('off')
# 归一化每个通道 img = x[0, i, :, :] # x:[b c h w] pmin = np.min(img) pmax = np.max(img) img = ((img - pmin) / (pmax - pmin + 0.000001)) * 255 # float在[0,1]之间,转换成0-255 img = img.astype(np.uint8) # 转成unit8 print(img.shape) # (14, 14) img = cv2.applyColorMap(img, cv2.COLORMAP_JET) # 生成heat map print(img.shape) # (14, 14, 3) img = img[:, :, ::-1] # 注意cv2(BGR)和matplotlib(RGB)通道是相反的 plt.imshow(img) print("完成plot {}/{}".format(i, width * height)) fig.savefig(savename, dpi=100) fig.clf() plt.close() class ft_net(nn.Module): def __init__(self): super(ft_net, self).__init__() model_ft = models.resnet50(pretrained=True) print([i for i in model_ft.children()]) self.model = model_ft def forward(self, x): if True: # draw features or not x = self.model.conv1(x) draw_features(8, 8, x.cpu().numpy(), "{}/f1_conv1.png".format(savepath)) x = self.model.bn1(x) draw_features(8, 8, x.cpu().numpy(), "{}/f2_bn1.png".format(savepath)) x = self.model.relu(x) draw_features(8, 8, x.cpu().numpy(), "{}/f3_relu.png".format(savepath)) x = self.model.maxpool(x) draw_features(8, 8, x.cpu().numpy(), "{}/f4_maxpool.png".format(savepath)) x = self.model.layer1(x) draw_features(16, 16, x.cpu().numpy(), "{}/f5_layer1.png".format(savepath)) x = self.model.layer2(x) draw_features(16, 32, x.cpu().numpy(), "{}/f6_layer2.png".format(savepath)) x = self.model.layer3(x) draw_features(32, 32, x.cpu().numpy(), "{}/f7_layer3.png".format(savepath)) x = self.model.layer4(x) draw_features(32, 32, x.cpu().numpy()[:, 0:1024, :, :], "{}/f8_layer4_1.png".format(savepath)) draw_features(32, 32, x.cpu().numpy()[:, 1024:2048, :, :], "{}/f8_layer4_2.png".format(savepath)) x = self.model.avgpool(x) plt.plot(np.linspace(1, 2048, 2048), x.cpu().numpy()[0, :, 0, 0]) plt.savefig("{}/f9_avgpool.png".format(savepath)) plt.clf() plt.close() x = x.view(x.size(0), -1) x = self.model.fc(x) plt.plot(np.linspace(1, 1000, 1000), x.cpu().numpy()[0, :]) plt.savefig("{}/f10_fc.png".format(savepath)) plt.clf() plt.close() else: print(44444444444444444444444444444444) x = self.model.conv1(x) x = self.model.bn1(x) x = self.model.relu(x) x = self.model.maxpool(x) x = self.model.layer1(x) x = self.model.layer2(x) x = self.model.layer3(x) x = self.model.layer4(x) x = self.model.avgpool(x) x = x.view(x.size(0), -1) x = self.model.fc(x) return x model = ft_net()#.cuda() model.eval() img = cv2.imread('image1.jpg') img = cv2.resize(img, (224, 224)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) img = transform(img)#.cuda() img = img.unsqueeze(0) with torch.no_grad(): out = model(img) result = out ind = np.argsort(result, axis=1) for i in range(5): print("predict:top {} = cls {} : score {}".format(i + 1, ind[0, 1000 - i - 1], result[0, 1000 - i - 1])) print("done")

输入图像

conv1 [1,64,112,112]

bn1_relu [1,64,112,112]

maxpool [1,64,56,56]

layer1 [1,256,56,56]

 

layer2 [1,512,28,28]

layer3 [1,1024,14,14]

layer4 [1,2048,7,7]

avgpool [1,2048]

fc [1,1000]

 

其中:横轴是类别编号,纵轴是输出的类别得分(没有经过softmax) 

 

 

 

 

 

 

 

 

 

 

 

参考资料:

https://blog.csdn.net/weixin_40500230/article/details/93845890  Pytorch自带Resnet50特征图heat map热力图可视化

标签:map,plt,img,format,self,feature,opencv,savepath,model
From: https://www.cnblogs.com/yanshw/p/17054745.html

相关文章

  • Bitmap读取本地高分辨率图片报内存不足的解决方案
    1#regiongetThumImage生成缩略图2///<summary>3///生成缩略图4///</summary>5///<paramname="sourceFile">原始图......
  • Arcmap出现拓扑无效问题怎么解决
    在ArcMap中出现拓扑无效错误通常是由于要素类之间存在空间关系不一致或拓扑错误导致的。以下是几种可能的解决方案:运行“检查几何”工具,以确定是否存在几何错误。如果有几......
  • opencv对比两张图片的相似度
    OpenCV提供了两种计算图像相似度的方法:结构相似性(SSIM)和均方误差(MSE)。其中,SSIM是一种更加准确的方法,它不仅考虑了像素之间的差异,还考虑了人眼对图像的感知。而MSE则只是简......
  • Java中Map类型数据使用LinkedHashMap保留数据的插入顺序
    场景Vue中JS遍历后台JAVA返回的Map数据,构造对象数组数据格式:Vue中JS遍历后台JAVA返回的Map数据,构造对象数组数据格式_BADAO_LIUMANG_QIZHI的博客在上面构造以时间为Key,以......
  • javascript 高级编程系列 - Set集合与Map集合
    ES6中新增的Set集合类型是一种有序列表,其中含有一些相互独立的非重复值,通过Set集合可以快速访问其中的数据,更有效地追踪各种离散值。1.创建Set集合并添加元素调用newS......
  • Bitmap、RoaringBitmap原理分析
    作者:京东科技 曹留界在人群本地化实践中我们介绍了人群ID中所有的pin的偏移量可以通过Bitmap存储,而Bitmap所占用的空间大小只与偏移量的最大值有关系。假如现在要向Bitma......
  • 跟着chatgpt学mmap
    以前对Linux的了解比较少,现在跟着chatgpt来学学,很好玩。比搜索那堆垃圾博客好太多了,而且咱chatgpt中文也很好哦chatgpt给的mmap解释mmap是一种UNIX和类UNIX操作系统......
  • Nmap 的使用教程
    Nmap是一个网络侦测和安全审计工具。它可以用于发现网络上的主机和服务,并提供广泛的信息,其中包括操作系统类型和版本、应用程序和服务的详细信息等。在本文中,我们将介绍如何......
  • Go语言并发编程(3):sync包介绍和使用(上)-Mutex,RWMutex,WaitGroup,sync.Map
    一、sync包简介在并发编程中,为了解决竞争条件问题,Go语言提供了sync标准包,它提供了基本的同步原语,例如互斥锁、读写锁等。sync包使用建议:除了Once和WaitGroup......
  • JS中的 map, forEach 无法跳出循环, return和 break不起作用,可以使用every 和 some方法
    JS中的map,forEach无法跳出循环,return和break不起作用,可以使用every和some方法敲代码的TKP于2022-09-0115:52:47发布1711收藏1分类专栏:javaScriptes6文......