首页 > 编程语言 >推荐一个计算Grad-CAM的Python库

推荐一个计算Grad-CAM的Python库

时间:2024-04-18 16:13:05浏览次数:37  
标签:CAM img Python grad cam model Grad tensor

前言

类激活图CAM(class activation mapping)用于可视化深度学习模型的感兴趣区域,增加了神经网络的可解释性。现在常用Grad-CAM可视化,Grad-CAM基于梯度计算激活图,对比传统的CAM更加灵活,且不需要修改模型结构。

虽然计算grad-cam并不复杂,但是本着能导包就导包的原则,想着去用现成的库。

pip install grad-cam

官方文档开源仓库

简单试用

  1. 加载模型和预训练权重

这里使用PyTorch官方提供的在ImageNet上预训练的Resnet50。注意:这里使用现成的模型参数,也需要用它们提供的图片预处理方式

from torchvision.models import resnet50, ResNet50_Weights

# 加载ResNet模型和预训练权重
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

preprocess = weights.transforms() # 图片预处理方法
  1. 简单读入一张图片

bird

from PIL import Image

src = 'bird.jpg'
img = Image.open(src)
print(f'The Image size:{img.size}')
img_tensor = preprocess(img)
print(f'The shape of image preprocessed: {img_tensor.shape}')

Output

The Image size:(474, 315)
The shape of image preprocessed: torch.Size([3, 224, 224])
  1. 计算Grad-CAM
rom pytorch_grad_cam import GradCAM

grad_cam = GradCAM(model=model, target_layers=[model.layer4[-1]])   
cam = grad_cam(input_tensor=img_tensor.unsqueeze(0)) # 输入的Shape: B x C x H x W

print(f'Cam.shape: {cam.shape}')
print(f'Cam.max: {cam.max()}, Cam.min: {cam.min()}')

Output

Cam.shape: (1, 224, 224)
Cam.max: 0.9999998807907104, Cam.min: 0.0

这里可以看到计算的CAM值的区间是[0, 1],一些处理长尾数据的图像增强的方法,通过CAM的值与原图像相乘,得到图像的主体或背景(上下文)。

  1. 可视化
from pytorch_grad_cam.utils.image import show_cam_on_image
import uuid
import numpy as np
import torch

def vis_cam(cam: np.ndarray, input_tensor: torch.Tensor):
    def normalization(x: np.ndarray, scale=1):   # 归一化
        x_min = np.min(x)
        x_max = np.max(x)
        return (x - x_min) / (x_max - x_min) * scale 
    
    # 底层是cv2实现的所以要求图像形状为 H x W x C
    input_tensor= input_tensor.permute(1, 2, 0).numpy()
    norm_img = normalization(input_tensor)
    
    # 可视化不支持batch,所以要取cam第一个
    vis = show_cam_on_image(norm_img, cam[0], use_rgb=True)
    
    vis_img = Image.fromarray(vis)
    vis_img.save(f'cam_{uuid.uuid1()}.jpg')
    return vis

vis1 = vis_cam(cam, img_tensor)

结果如下,由于图像经过了预处理,size变味了224x224,所以CAM的大小也是这个尺寸。

另外,这个库也提供了其他CAM方法,如GradCAMElementWise,与Grad-CAM相似,将激活值与梯度逐元素相乘,然后在求和之前应用 ReLU 运算。但是简单使用后,肉眼没有察觉差异:

from pytorch_grad_cam import GradCAMElementWise
grad_cam = GradCAMElementWise(model=model, target_layers=[model.layer4[-1]])
cam = grad_cam(input_tensor=img_tensor.unsqueeze(0)) # 输入的Shape: B x C x H x W
vis2 = vis_cam(cam, img_tensor)


将它们做一个横向对比,从左至右分别是原图、GradCAMGradCAMElementWise

img_hstack = np.hstack([img.resize(size=(224, 224)), vis1, vis2])
Image.fromarray((img_hstack).astype(np.uint8)).save('cam_compare.jpg')            



其他

有一点很重要,但是文中并没有使用,关于ClassifierOutputTarget的使用,文档中它的一种用法:

cam = GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda)

targets = [ClassifierOutputTarget(281)]

grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

输入的参数是图片对应的target,也就是one-hot标签里面的1的下标,但由于使用的是预训练模型,所以不知道具体的标签。而当cam这里的targets=None时,会自动选择得分最高的类。

关于grad-cam还有许多功能,这里仅仅介绍了计算cam和可视化的部分。

运行环境

grad-cam                  1.5.0                    pypi_0    pypi
pytorch                   2.2.2           py3.12_cuda12.1_cudnn8_0    pytorch

标签:CAM,img,Python,grad,cam,model,Grad,tensor
From: https://www.cnblogs.com/zh-jp/p/18143700

相关文章

  • Effective Python:第6条 把数据结构直接拆分到多个变量里,不要专门通过下标访问
    使用拆分(unpacking),就可以把元组里面的元素分别赋给多个变量。优点:1,通过unpacking来赋值要比通过下标去访问元组内的元素更清晰,而且这种写法所需的代码量通常比较少。2,便于原地交换两个变量;tb=[1,2]tb[0],tb[1]=tb[1],tb[0]print(tb)3,for循环或者类似的结构(例如推......
  • react-native-camera 安装
    npmi react-native-camera--save或yarnadd react-native-camera 在android/app/build.gradle中添加:missingDimensionStrategy'react-native-camera','general'implementationproject(path:':react-native-camera') 在MainApplicatio......
  • python包:matplotlib
    1):matplotlib是一个python2D绘图库,利用它可以画出许多高质量的图像。只需几行代码即可生成直方图,条形图,饼图,散点图等。Matplotlib是整个包,pyplot是Matplotlib中的一个模块,并且pylab是一个安装在一起的模块。 https://matplotlib.org/2:使用https://zhuanlan.zhihu.com/p......
  • python --多个叠加装饰器
    defdeco1(func1):defwrapper1(*args,**kwargs):print("运行deco1_wrapper1")res1=func1(*args,**kwargs)returnres1returnwrapper1defdeco2(func2):defwrapper2(*args,**kwargs):print("运行deco2_wra......
  • blender python api 使用脚本修改动画关键帧的属性值
    1.代码1-将动画关键帧中的所有Y轴都设置为1.0,代码:importbpy#设置重置到的Y坐标值reset_to=1.0#遍历所有当前选中的对象forobjectinbpy.context.selected_objects:#如果对象没有动画,我们也应该重置其Y坐标object.location.y=reset_to#检......
  • 对大量ip:port进行批量telnet检测的python脚本
    对大量ip:port进行批量telnet检测的python脚本telnet_test.py#导入socket模块,用于网络通信importsocket#定义一个函数,用于测试Telnet连接是否成功deftest_telnet(ip,port):try:#尝试创建到指定IP和端口的连接socket.create_connection((ip,po......
  • python学习第一天
    学习一门技术,我们都要知道三个问题。为什么要学习?优点缺点是?怎么入门?python有很强的,就业性,学习完全是为了未来ai趋势做迎合,同样也是基于爱好站在了人工智能和大数据的风口上,站在风口上,猪都能飞起来。优点:简单上手,功能强大,库多缺点:速度慢,代码不能加密2、python的第一个......
  • 基于python的文件seek和tell实例解析
    一概念AF.seek(偏移量,whence=相对位置)偏移量大于0的数代表向文件末尾方向移动的字节数小于0的数代表向文件头方向中移动的字节数相对位置0代表从文件头开始偏移1代表从文件当前读写位置开始偏移2代表从文件尾开始偏移Btell函数能够返回指针......
  • 【Python微信机器人】写一个监控采集公众号文章的插件
    原文链接:https://mp.weixin.qq.com/s/f8zbM6wMld3koqjaFbCuxw前言弄了个视频号下载后,同一个问题每天都会被问,回答的有点烦了。想了想根本原因还是缺少一个交流平台,微信群的话,刚进群的看不到之前的聊天记录。想整个知识星球,发现只能弄个收费的,免费的需要激活码才能创建。而有......
  • blender python api 使用脚本批量对obj物体进行渲染(obj所在目录要有与之对应的mtl文件
     代码:importbpy#导入Blender的PythonAPI接口importpathlib#导入pathlib模块,用于操作文件路径#设置OBJ文件所在的目录路径obj_root=pathlib.Path('D:\\ceshi')#注意Windows路径中的斜杠需要转义#取消选择场景中的所有物体,以便导入时不会与已选择的物体冲......