了解积分梯度如何帮助识别哪些输入特征对模型的预测贡献
在金融和医疗保健等受到高度监管的行业中使用 AI 模型肩负着关键责任:可解释性。您的模型预测准确是不够的。您应该能够解释您的模型做出特定预测的原因。例如,如果我们正在开发一个基于脑部 MRI 扫描的肿瘤检测模型,我们应该能够解释我们的模型使用哪些信息以及它如何处理这些信息并导致肿瘤识别。在这种情况下,监管机构或医生需要了解这些细节,以确保结果公正和准确。那么,您如何解释您的模型决策呢?手动解释它们并不容易,因为没有简单的 “if-else” 逻辑 — 深度学习模型通常有数百万个参数以非线性方式交互,因此无法追踪从输入到输出的路径。
当我们要了解某个特征(比如水果的颜色)对模型预测结果(比如水果的价格)的影响时,我们会从一个基线开始——这个基线是一个我们认为对模型输出没有影响的点,比如将所有特征值设为零或者平均值。然后,我们会逐步增加这个特征的重要性,直到达到实际输入的值。在这个过程中,我们会记录每次小幅度增加特征值时模型输出的变化。最后,我们将这些小变化加起来,就能得到这个特征对模型输出总贡献的一个估计。Integrated Gradients是一种帮助我们理解复杂模型工作原理的技术。它通过计算特征值从小到大的过程中模型输出的变化量,来量化每个特征对模型预测结果的影响。这种方法不仅适用于图像识别任务,还可以应用于文本分析、声音识别等多个领域。
在满足这一需求方面,我有实践经验的技术之一是积分梯度。它是由 Google 的研究人员于 2017 年推出的,这是一种强大的方法,通过整合从基线到实际输入的梯度来计算归因。在本文中,我将引导您完成一个图像分类用例,并向您展示集成梯度如何帮助我们了解哪些图像像素在决策中最重要。我将使用 Captum 库来计算归因,并使用预先训练的 ResNet 模型来预测图像。
环境设置
安装下面提到的必要软件包
pip install captum torch torchvision Matplotlib NumPy PIL
下载示例图像并将其命名为 image.jpg(您可以下载所需的任何图像)。
现在,我们将加载了解对象的预训练 ResNet 模型,使用 ResNet 模型对图像进行分类,使用集成梯度技术计算属性,并可视化结果,显示图像的哪些像素对模型的预测最重要。
以下是带有详细注释的完整实现。
import torch
import torchvision
import torchvision.transforms as transforms
from captum.attr import IntegratedGradients
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np# 加载我们刚刚下载的图像。如果你下载的图像名字不同,请更改图像名称。
image_to_predicted = Image.open('image.jpg')# 将图像调整为标准格式并转换为数字。
transformed_image = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])(image_to_predicted).unsqueeze(0)# 下载预训练的ResNet模型,使其准备好进行预测。
model = torchvision.models.resnet18(pretrained=True)
model.eval()# 进行预测并使用argmax函数找到概率最高的对象。
predicted_image_class = torch.nn.functional.softmax(model(transformed_image)[0], dim=0).argmax().item()# 创建IntegratedGradients对象并计算属性
integrated_gradients = IntegratedGradients(model)# 创建基线参考,IG计算从这里开始
baseline_image = torch.zeros_like(transformed_image)# 使用预测的图像类别、基线和转换后的图像计算归因
computed_attributions, delta = integrated_gradients.attribute(transformed_image, baseline_image, target=predicted_image_class, return_convergence_delta=True)# 将归因转换为numpy数组以进行可视化
attributions_numpy = np.abs(np.transpose(computed_attributions.squeeze().cpu().detach().numpy(), (1, 2, 0))) * 255 * 10
attributions_numpy = attributions_numpy.astype(np.uint8)# 可视化下载的图像和带有归因的图像
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img)
axes[0].axis('off')
axes[0].set_title('下载的图像')
axes[1].imshow(attributions_numpy, cmap='magma')
axes[1].axis('off')
axes[1].set_title('带有归因的图像')
plt.show()
这是我下载的示例图像的输出。您可以看到突出显示的区域显示每个像素对于模型预测的重要性。
结论
我们探讨了如何使用积分梯度来解释深度学习模型的预测。使用积分梯度,我们深入了解了图像的哪些像素对模型的预测最重要。这不是唯一可用于模型可解释性的技术。其他技术,如特征重要性、Shapley 加法解释 (SHAP),也可用于深入了解模型行为。
今天先到这儿,希望对云原生,技术领导力, 企业管理,系统架构设计与评估,团队管理, 项目管理, 产品管理,信息安全,团队建设 有参考作用 , 您可能感兴趣的文章:
构建创业公司突击小团队
国际化环境下系统架构演化
微服务架构设计
视频直播平台的系统架构演化
微服务与Docker介绍
Docker与CI持续集成/CD
互联网电商购物车架构演变案例
互联网业务场景下消息队列架构
互联网高效研发团队管理演进之一
消息系统架构设计演进
互联网电商搜索架构演化之一
企业信息化与软件工程的迷思
企业项目化管理介绍
软件项目成功之要素
人际沟通风格介绍一
精益IT组织与分享式领导
学习型组织与企业
企业创新文化与等级观念
组织目标与个人目标
初创公司人才招聘与管理
人才公司环境与企业文化
企业文化、团队文化与知识共享
高效能的团队建设
项目管理沟通计划
构建高效的研发与自动化运维
某大型电商云平台实践
互联网数据库架构设计思路
IT基础架构规划方案一(网络系统规划)
餐饮行业解决方案之客户分析流程
餐饮行业解决方案之采购战略制定与实施流程
餐饮行业解决方案之业务设计流程
供应链需求调研CheckList
企业应用之性能实时度量系统演变
如有想了解更多软件设计与架构, 系统IT,企业信息化, 团队管理 资讯,请关注我的微信订阅号:
作者:Petter Liu
出处:http://www.cnblogs.com/wintersun/
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。
该文章也同时发布在我的独立博客中-Petter Liu Blog。