首页 > 其他分享 >神经网络特征图显示(matplotlib同一画布切换的方式)

神经网络特征图显示(matplotlib同一画布切换的方式)

时间:2023-01-23 20:22:06浏览次数:49  
标签:map plt get current feature matplotlib 画布 神经网络 num

在网上查了一圈也没找到合适的实现，搜到的基本都是互相复制粘贴的代码，下面贴出我自己写的代码。

绘图采用 matplotlib.pyplot。由于默认弹出的窗口自带工具栏，我就在工具栏里用 ttk 添加了切换图片的按钮（因此需要使用基于 Tk 的后端，例如 TkAgg）。下面给出两种用法：单独显示一个特征图，以及将两个特征图并排对比。

在内存不足的机器上，不要试图一次性显示所有特征图。

import matplotlib.pyplot as plt
from tkinter import ttk, IntVar
import torch

def show_feature_map(feature_map, show_size=(1024, 1024), cur_block_idx=0):
    """Show the channels of a feature map one at a time in a single figure window.

    Only the first sample of the batch is displayed. "Next"/"previous" buttons
    are packed into the figure window's toolbar with ``ttk``, so a Tk-based
    matplotlib backend (e.g. TkAgg, whose toolbar is a Tk widget) is required.

    Args:
        feature_map: Tensor of shape (N, C, H, W), or (C, H, W) for a single
            sample. It may live on any device and may require grad.
        show_size: (height, width) each displayed channel is bilinearly
            upsampled to.
        cur_block_idx: Figure number passed to ``plt.figure``; use distinct
            values to keep several windows apart.

    Raises:
        ValueError: if ``feature_map`` contains no elements.
    """
    # Detach first so the device copy is not recorded by autograd.
    feature_map = feature_map.detach().cpu()
    if feature_map.dim() == 3:
        # Treat a (C, H, W) tensor as a batch of one.
        feature_map = feature_map.unsqueeze(0)
    if feature_map.numel() == 0:
        raise ValueError('feature_map is empty, nothing to show')
    feature_map = feature_map[:1]
    feature_map_num = feature_map.shape[1]
    upsample = torch.nn.UpsamplingBilinear2d(size=show_size)
    current = 0

    fig = plt.figure(cur_block_idx)

    def _channel_image(index):
        # Upsample only the channel being shown. Bilinear interpolation works on
        # each channel independently, so this equals slicing a fully upsampled
        # stack, without keeping C * show_size floats in memory at once.
        return upsample(feature_map[:, index:index + 1])[0, 0].numpy()

    def _draw():
        # Draw on `fig` explicitly instead of pyplot's "current figure", which
        # can be another window when several figures are open.
        fig.clf()
        fig.suptitle('feature channel number : {}/{}'.format(current + 1, feature_map_num))
        fig.add_subplot(1, 1, 1).imshow(_channel_image(current), cmap='gray')
        # Request a repaint; plt.show() is for starting the GUI event loop,
        # not for refreshing a figure from inside one of its callbacks.
        fig.canvas.draw_idle()

    def _step(delta):
        # Move to a neighbouring channel, clamped to [0, feature_map_num - 1];
        # redraw only when the channel actually changes.
        nonlocal current
        target = min(max(current + delta, 0), feature_map_num - 1)
        if target != current:
            current = target
            _draw()

    toolbar = fig.canvas.manager.toolbar
    next_btn = ttk.Button(toolbar, text='下一张图片', command=lambda: _step(1))
    back_btn = ttk.Button(toolbar, text='上一张图片', command=lambda: _step(-1))
    next_btn.pack()
    back_btn.pack()

    _draw()
    plt.show()

def show_feature_map_compare(feature_map, seg_feature_map, show_size=(1024, 1024), cur_block_idx=0):
    """Show two feature maps side by side, one channel pair at a time.

    Channel ``i`` of ``feature_map`` is drawn on the left and channel ``i`` of
    ``seg_feature_map`` on the right. Only the first sample of each batch is
    used. "Next"/"previous" buttons are packed into the figure window's toolbar
    with ``ttk``, so a Tk-based matplotlib backend (e.g. TkAgg) is required.

    Args:
        feature_map: Tensor of shape (N, C1, H, W), or (C1, H, W).
        seg_feature_map: Tensor of shape (N, C2, H', W'), or (C2, H', W').
            If C1 != C2, only the first min(C1, C2) channel pairs are shown.
        show_size: (height, width) each displayed channel is bilinearly
            upsampled to, so both panels share the same size.
        cur_block_idx: Figure number passed to ``plt.figure``.

    Raises:
        ValueError: if either tensor contains no elements.
    """
    upsample = torch.nn.UpsamplingBilinear2d(size=show_size)

    def _prepare(tensor):
        # Detach before copying to CPU, add a batch axis to (C, H, W) input and
        # keep only the first sample.
        tensor = tensor.detach().cpu()
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        if tensor.numel() == 0:
            raise ValueError('feature map is empty, nothing to show')
        return tensor[:1]

    feature_map = _prepare(feature_map)
    seg_feature_map = _prepare(seg_feature_map)
    # Index both tensors with the same channel number, so never step past the
    # smaller one (otherwise the button callback raises IndexError).
    feature_map_num = min(feature_map.shape[1], seg_feature_map.shape[1])
    current = 0

    fig = plt.figure(cur_block_idx)

    def _channel_image(tensor, index):
        # Upsample just the requested channel. Bilinear interpolation is
        # per-channel, so this matches slicing a fully upsampled stack while
        # avoiding C * show_size floats of memory per tensor.
        return upsample(tensor[:, index:index + 1])[0, 0].numpy()

    def _draw():
        # Draw on `fig` explicitly rather than pyplot's "current figure".
        fig.clf()
        fig.suptitle('feature channel number : {}/{}'.format(current + 1, feature_map_num))
        fig.add_subplot(1, 2, 1).imshow(_channel_image(feature_map, current), cmap='gray')
        fig.add_subplot(1, 2, 2).imshow(_channel_image(seg_feature_map, current), cmap='gray')
        # Request a repaint instead of calling plt.show() from a GUI callback.
        fig.canvas.draw_idle()

    def _step(delta):
        # Move to a neighbouring channel pair, clamped to the valid range;
        # redraw only when the index actually changes.
        nonlocal current
        target = min(max(current + delta, 0), feature_map_num - 1)
        if target != current:
            current = target
            _draw()

    toolbar = fig.canvas.manager.toolbar
    next_btn = ttk.Button(toolbar, text='下一张图片', command=lambda: _step(1))
    back_btn = ttk.Button(toolbar, text='上一张图片', command=lambda: _step(-1))
    next_btn.pack()
    back_btn.pack()

    _draw()
    plt.show()

运行效果如下:

 

 

 

 

标签:map,plt,get,current,feature,matplotlib,画布,神经网络,num
From: https://www.cnblogs.com/lzqdeboke/p/17065472.html

相关文章