首页 > 其他分享 >深度学习(输出模型中间特征)

深度学习(输出模型中间特征)

时间:2024-09-30 22:33:56浏览次数:6  
标签:输出 plt torchvision img extractor 模型 feature 深度 import

   

深度学习骨干网络一般会包含很多层,这里写了一个脚本,可以保存骨干网络的所有特征图。

代码主要用了get_graph_node_names和create_featrue_extractor这两个函数。

get_graph_node_names是得到所有特征节点名字。

create_featrue_extractor是提取对应节点输出的特征tensor。

下面以resnet18为例,一共得到15491个特征图。

import torchvision
from PIL import Image
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.models.feature_extraction import create_feature_extractor,get_graph_node_names

toTensor = transforms.ToTensor()

model = torchvision.models.resnet18(pretrained=True)
#model = torchvision.models.efficientnet_b0(pretrained=True)

feature_nodes, _ = get_graph_node_names(model)
features = create_feature_extractor(model, return_nodes=feature_nodes)

img = Image.open("1.jpg")
img = toTensor(img).unsqueeze(0)
out = features(img) 

print(list(out))

count = 0
for feature_name in list(out):
    feature = out[feature_name]
    if len(feature.shape)==4:
        B,C,H,W = feature.shape
        if H >1 and W>1:
            for c in range(C):   
                fig = plt.figure(1)
                plt.axis('off')
                print(feature.shape)
                plt.imshow(feature[0][c].detach().numpy())
                plt.savefig('./output/'+str(count)+'_'+feature_name+'_'+str(c)+'.png',bbox_inches='tight',pad_inches=0)
                count +=1
                plt.clf()

所有输出保存成图像,这里用plt输出保存,可以保证特征图像素一样多。

没选择cv2或PIL保存图像的原因是这两个库会保存原始图像,而后面的特征图越来越小,不好直观的看出区别。

下面是一些保存的图像:

标签:输出,plt,torchvision,img,extractor,模型,feature,深度,import
From: https://www.cnblogs.com/tiandsp/p/18432264

相关文章

  • LSTM模型改进实现多步预测未来30天销售额
    关于深度实战社区我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万+粉丝,拥有2篇国家级人工智能发明专利。社区特色:深度实战算法创新获取全部完整项目......
  • 【机器学习】揭秘反向传播:深度学习中神经网络训练的奥秘
      目录......
  • CNN模型实现CIFAR-10彩色图片识别
    关于深度实战社区我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万+粉丝,拥有2篇国家级人工智能发明专利。社区特色:深度实战算法创新获取全部完整项目......
  • 券商股大涨,至少17家券商已入局AI人工智能金融大模型
    大家好,我是Shelly,一个专注于输出AI工具和科技前沿内容的AI应用教练,体验过300+款以上的AI应用工具。关注科技及大模型领域对社会的影响10年+。关注我一起驾驭AI工具,拥抱AI时代的到来。最近,券商股价的大涨成为了财经新闻的热门话题。背后的原因,除了市场整体环境的改善,更重要的是......
  • 65结构体-结构体数组。在C++中,结构体的定义是什么呢?如何新建一个结构体呢?新建好的结构
    问题描述:根据下列代码和结果回答下列问题。//Createdby黑马程序员.#include"iostream"usingnamespacestd;#include<string>//结构体定义structstudent{//成员列表stringname;//姓名intage;//年龄intscore;//分数}stu3;/......
  • 联邦学习中的模型异构 :知识蒸馏
    目录 联邦学习中的模型异构 一、定义与背景:揭开模型异构的神秘面纱二、模型异构的挑战:智慧与技术的双重考验三、解决策略与方法:智慧与技术的巧妙融合四、实际应用与前景:智慧与技术的无限可能举例说明异构模型的具体表现模型异构的挑战与解决方案实际应用案例 联......
  • 深度学习系列之1----直观解释Transformer
    Abstract这个系列主要用来记录我自己这种的AI小白的学习之路,通过将所学所知总结下来,记录下来。之前总喜欢记录在笔记本上,或者ipad上,或者PC端的Typora上,但总是很难回头检索到一些系统的知识,因此我觉得博客是一个不错的选择,因为时不时我就会登录网站翻看过去的痕迹,我觉得这是一种很......
  • 【大模型指令微调: 从零学会炼丹】第二章: 数据集预处理
    大模型指令微调:从零学会炼丹系列目录第一章:微调数据集构建第二章:数据集预处理第二章:数据集预处理环境准备pipinstalldatasetstransformerspandasduckdbfunctools导入包fromdatasetsimportDatasetfromtransformersimport(AutoTokenizer,......
  • 【大模型指令微调: 从零学会炼丹】第一章: 微调数据集构建
    大模型指令微调:从零学会炼丹系列目录第一章:微调数据集构建文章目录大模型指令微调:从零学会炼丹系列目录第一章:微调数据集构建Alpaca格式编写Instructioninstruction-key读取本地数据定义format函数第一章:微调数据集构建Alpaca格式Alpaca格式是一......
  • 基于ARIMA回归模型的股票价格预测
    一、ARIMA回归模型ARIMA回归模型是一种用于时间序列数据预测的统计模型。ARIMA 是AutoRegressiveIntegratedMovingAverage 的缩写,中文意思是“自回归差分移动平均模型”。它是一种结合了自回归(AR)、差分(I)和移动平均(MA)三种方法的模型。自回归(AR):指的是模型会考虑过去的......