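Reference on GitHub: https://github.com/sixitingting/CAM/blob/master/pytorch_CAM.py
This is the code provided by the original authors of Class Activation Mapping (CAM). If you want the theory, read the paper; this post focuses on practice.
CAM result: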
top1 prediction: mountain bike, all-terrain bike, off-roader
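A note up front: many readers still find this post through search, but it is a note I wrote when I was just starting out, and a lot of it is incomplete. I have since written about a newer method that is simpler and more practical. If you are interested, see my other post "CAM (class activation mapping), convolution visualization, neural network visualization: one library does it all, it really could not be simpler". I believe that post is more helpful, so take a look if you need it.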
----------------------------------------------------------Hands-on starts here-------------------------------------------------------
# simple implementation of CAM in PyTorch for the networks such as ResNet, DenseNet, SqueezeNet, Inception
import io
import requests
from PIL import Image
import torch
from torchvision import models, transforms
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np
import cv2
import json
# input image
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
IMG_URL = 'http://media.mlive.com/news_impact/photo/9933031-large.jpg'
# jsonfile = r'D:\python\Camtest\labels.json'
# with open(jsonfile, 'r') as load_f:
#     load_json = json.load(load_f)
# networks such as googlenet, resnet, densenet already use global average pooling at the end,
# so CAM could be used directly.
model_id = 1
if model_id == 1:
    net = models.squeezenet1_1(pretrained=False)
    # weights downloaded manually (see note 2 below); change to your own path
    pthfile = r'E:\anaconda\app\envs\luo\Lib\site-packages\torchvision\models\squeezenet1_1.pth'
    net.load_state_dict(torch.load(pthfile))
    finalconv_name = 'features'  # this is the last conv layer of the network
elif model_id == 2:
    net = models.resnet18(pretrained=False)  # no weights loaded here: load them as for model_id == 1
    finalconv_name = 'layer4'
elif model_id == 3:
    net = models.densenet161(pretrained=False)  # no weights loaded here: load them as for model_id == 1
    finalconv_name = 'features'
net.eval()
print(net)
# hook the feature extractor
features_blobs = []
def hook_feature(module, input, output):
    features_blobs.append(output.data.cpu().numpy())
net._modules.get(finalconv_name).register_forward_hook(hook_feature)
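# get the softmax weight: params[-1] is the bias of the final classifier layer and params[-2] its weight,
# squeezed to shape (num_classes, num_channels)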
params = list(net.parameters())
weight_softmax = np.squeeze(params[-2].data.numpy())
def returnCAM(feature_conv, weight_softmax, class_idx):
    # generate the class activation maps upsample to 256x256
    size_upsample = (256, 256)
    bz, nc, h, w = feature_conv.shape  # assumes a batch size (bz) of 1
    output_cam = []
    for idx in class_idx:
        # weighted sum of the nc channel maps, using the class's classifier weights
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))
        cam = cam.reshape(h, w)
        cam_img = (cam - cam.min()) / (cam.max() - cam.min())  # normalize to [0, 1]
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))
    return output_cam
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])
response = requests.get(IMG_URL)
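img_pil = Image.open(io.BytesIO(response.content)).convert('RGB')  # ensure 3 channels for Normalize
img_pil.save('test.jpg')
img_tensor = preprocess(img_pil)
img_variable = Variable(img_tensor.unsqueeze(0))  # add a batch dimension; Variable is a no-op wrapper since PyTorch 0.4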
logit = net(img_variable)
# download the imagenet category list
classes = {int(key): value for (key, value)
           in requests.get(LABELS_URL).json().items()}
# classes = {int(key): value for (key, value)
#            in load_json.items()}
# there are 1000 classes; sort them by probability in descending order
h_x = F.softmax(logit, dim=1).data.squeeze()
probs, idx = h_x.sort(0, True)
probs = probs.numpy()
idx = idx.numpy()
# output the top-5 predictions
for i in range(0, 5):
    print('{:.3f} -> {}'.format(probs[i], classes[idx[i]]))
# generate class activation mapping for the top1 prediction
CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[0]])
# render the CAM and output
print('output CAM.jpg for the top1 prediction: %s'%classes[idx[0]])
img = cv2.imread('test.jpg')
height, width, _ = img.shape
heatmap = cv2.applyColorMap(cv2.resize(CAMs[0],(width, height)), cv2.COLORMAP_JET)
result = heatmap * 0.3 + img * 0.5
cv2.imwrite('CAM.jpg', result)
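For reference, here is the per-class map that returnCAM computes, written as a formula. f_k(x, y) is channel k of the hooked feature map (feature_conv, with n_c = nc channels on an h × w grid), and w_k^c is entry k of row c of weight_softmax. The floor reflects the np.uint8 cast of non-negative values. cv2.resize then upsamples the result to 256 × 256, and again to the image size when rendering.

```latex
\begin{aligned}
M_c(x, y) &= \sum_{k=1}^{n_c} w_k^c \, f_k(x, y) \\
\mathrm{CAM}_c(x, y) &= \left\lfloor 255 \cdot \frac{M_c(x, y) - \min_{x', y'} M_c(x', y')}{\max_{x', y'} M_c(x', y') - \min_{x', y'} M_c(x', y')} \right\rfloor
\end{aligned}
```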
Output:
0.678 -> mountain bike, all-terrain bike, off-roader
0.088 -> bicycle-built-for-two, tandem bicycle, tandem
0.042 -> unicycle, monocycle
0.038 -> horse cart, horse-cart
0.019 -> lakeside, lakeshore
output CAM.jpg for the top1 prediction: mountain bike, all-terrain bike, off-roader
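The reason a map built from the last conv layer says something about the class score is this. When the head is global average pooling followed by a linear layer (for example ResNet-18's avgpool + fc, hooked at layer4), the class score equals the spatial mean of the unnormalized CAM plus the bias. SqueezeNet's classifier puts a ReLU before the pooling, so the identity is not exact there. The following is a minimal NumPy sketch that checks the identity on random data. The toy sizes and the helper name raw_cam (the unnormalized part of returnCAM) are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
nc, h, w, num_classes = 8, 7, 7, 5                 # arbitrary toy sizes (assumption)
features = rng.standard_normal((1, nc, h, w))     # stands in for features_blobs[0]
weight = rng.standard_normal((num_classes, nc))   # stands in for weight_softmax
bias = rng.standard_normal(num_classes)

# Path 1: what the network does, global average pooling then a linear layer
pooled = features.mean(axis=(2, 3))               # shape (1, nc)
scores = pooled @ weight.T + bias                 # shape (1, num_classes)

# Path 2: the CAM route, weighted channel sum first, spatial average afterwards
def raw_cam(feature_conv, weights, class_idx):
    """Hypothetical helper: returnCAM without normalization and resizing."""
    _, c, hh, ww = feature_conv.shape
    return weights[class_idx].dot(feature_conv.reshape((c, hh * ww))).reshape(hh, ww)

cam_scores = np.array([raw_cam(features, weight, c).mean() + bias[c]
                       for c in range(num_classes)])

print(np.allclose(scores[0], cam_scores))  # True: both orders give the same score
```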
That is all the code. Here are explanations of some problems you are likely to run into:
1. Network problems prevent the JSON label file (or the test image) from being downloaded, for example:
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='media.mlive.com', port=80):
Fix: download the file manually, put it in the same project, and use the commented-out code:
jsonfile = r'D:\python\Camtest\labels.json'  # replace with your own path
with open(jsonfile, 'r') as load_f:
    load_json = json.load(load_f)
# then uncomment the following
classes = {int(key): value for (key, value)
           in load_json.items()}
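You can avoid toggling the commented-out lines by hand by combining the two sources: try the local file first and fall back to the URL. The following is a minimal sketch. The function name load_labels and the fallback order are my own choices, not part of the original code.

```python
import json
import os


def load_labels(jsonfile, url=None):
    """Hypothetical helper: return {class_index: label} from a local labels.json,
    falling back to downloading from `url` if the file does not exist."""
    if os.path.exists(jsonfile):
        with open(jsonfile, 'r') as load_f:
            raw = json.load(load_f)
    elif url is not None:
        import requests  # only needed when falling back to the network
        raw = requests.get(url, timeout=10).json()
    else:
        raise FileNotFoundError(jsonfile)
    # JSON object keys are always strings, so convert them back to int indices
    return {int(key): value for key, value in raw.items()}

# usage with the names from the article:
# classes = load_labels(r'D:\python\Camtest\labels.json', LABELS_URL)
```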
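2. Because of network problems, the pretrained network could not be downloaded automatically (it failed for me). So I set pretrained=False, downloaded the pretrained parameters separately, and loaded them with net.load_state_dict(torch.load(pthfile)) (see model_id == 1).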
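For model_id == 2 or 3, the script as written builds the network with pretrained=False but never loads any weights, so the CAM would come from random parameters. The following is a minimal sketch of loading a manually downloaded state dict into any of the models. The helper name load_local_weights is hypothetical, and map_location='cpu' is an added assumption so the file also loads on a machine without a GPU. For DenseNet, torchvision's own loader renames some old-style keys in the official checkpoints, so a plain load_state_dict may report mismatched keys for model_id == 3.

```python
import torch


def load_local_weights(net, pthfile):
    """Hypothetical helper: load a manually downloaded state dict into `net`
    and switch it to evaluation mode."""
    state_dict = torch.load(pthfile, map_location='cpu')  # tensors land on the CPU
    net.load_state_dict(state_dict)  # strict by default: key names must match the model
    net.eval()
    return net

# usage with the article's models (the path is a placeholder):
# net = load_local_weights(models.resnet18(pretrained=False), r'path\to\resnet18.pth')
```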
3. If you are not sure how hooking intermediate features works, look up register_forward_hook.
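Here is a minimal illustration of what register_forward_hook does in the script. The hook is called right after the module's forward pass with (module, input, output), and anything it stores can be read afterwards. The toy network below is an assumption for demonstration only. The handle returned by register_forward_hook has a remove() method that detaches the hook. The original script never calls it, which is harmless there because it runs only one forward pass.

```python
import torch
from torch import nn

# toy stand-in for a CNN: conv features followed by a GAP + linear head (assumption)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 10),
)

features_blobs = []


def hook_feature(module, input, output):
    # called right after module.forward; `output` is that layer's activation
    features_blobs.append(output.detach().cpu().numpy())

# submodule '1' is the ReLU after the conv, i.e. the "last conv layer" output here
handle = net._modules.get('1').register_forward_hook(hook_feature)

with torch.no_grad():
    logit = net(torch.randn(1, 3, 8, 8))

handle.remove()  # detach the hook so later forward passes are not recorded
print(features_blobs[0].shape)  # (1, 4, 8, 8)
```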
4. Only the top-5 predictions are printed.
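If you only need the first k classes, torch.topk returns them directly instead of sorting all 1000 with h_x.sort(0, True). The following is a minimal sketch. The helper name top_k_predictions, the toy logits, and the label dictionary are made up for illustration.

```python
import torch
from torch.nn import functional as F


def top_k_predictions(logit, classes, k=5):
    """Hypothetical helper: return [(probability, label), ...] for the k most
    likely classes of a (1, num_classes) logit tensor."""
    h_x = F.softmax(logit, dim=1).squeeze(0)
    probs, idx = h_x.topk(k)  # sorted in descending order by default
    return [(p.item(), classes[i.item()]) for p, i in zip(probs, idx)]

# toy example: 4 classes with made-up labels
classes = {0: 'cat', 1: 'dog', 2: 'bike', 3: 'car'}
logit = torch.tensor([[0.1, 2.0, 3.0, -1.0]])
for p, label in top_k_predictions(logit, classes, k=2):
    print('{:.3f} -> {}'.format(p, label))
```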