首页 > 编程问答 >如何实现 Grad-CAM 在 TensorFlow ResNet152V2 上查看激活图/热图以进行图像分类

如何实现 Grad-CAM 在 TensorFlow ResNet152V2 上查看激活图/热图以进行图像分类

时间:2024-07-20 23:35:05浏览次数:21  
标签:python tensorflow conv-neural-network resnet image-classification

您好,我正在使用 ResNet152V2 做一个关于 TensorFlow 图像分类的小项目。

我编写了一个 Train-Predict.py 脚本,它能够训练 trained_weights.hdf5 文件以成功预测自闭症和非自闭症人士的图像。

此处。是脚本:

#Import Libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras.applications import ResNet152V2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dropout, Dense, BatchNormalization, GlobalAveragePooling2D, Conv2D
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from PIL import Image
import cv2





os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
img_size = 224
batchsize = 128
epochs = 50
lrate = 0.01
lrate_reduction_factor = 0.5
training = False


#Variable setup and Detect images/Classes in dataset folders
traindatadir="Train"
testdatadir="Test"

#Data Augmentation
datagen = ImageDataGenerator(
    rescale = 1./255,
    horizontal_flip = True,
    vertical_flip = True,
    rotation_range=15,
    shear_range=0.1,
    zoom_range=0.2,
    width_shift_range=0.1,
    height_shift_range=0.1
)

#Preprosessing Data
train_datagen=datagen.flow_from_directory(
    traindatadir,
    target_size = (img_size, img_size),
    color_mode = 'rgb',
    batch_size = batchsize,
    shuffle = True,
    seed = 123,
    class_mode= 'categorical'
)

test_datagen=datagen.flow_from_directory(
    testdatadir,
    target_size = (img_size, img_size),
    color_mode = 'rgb',
    batch_size = batchsize,
    shuffle = True,
    seed = 123,
    class_mode= 'categorical'
)


#ResNet Model with custom node tuning (Reduced number of Nodes this time)
resnet = ResNet152V2( 
    include_top = False,
    weights = 'imagenet',
    input_shape = (img_size, img_size, 3)
)
 
resnet.trainable = False

model = Sequential()
 
model.add(resnet)
model.add(Conv2D(512, kernel_size=(3,3), activation="relu"))  
model.add(Conv2D(512, kernel_size=(3,3), activation="relu"))  
model.add(Conv2D(512, kernel_size=(3,3), activation="relu"))  
model.add(GlobalAveragePooling2D())
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dense(64, activation = 'relu'))
model.add(Dropout(0.2))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(128, activation='relu'))    
model.add(Dense(2, activation='softmax'))
 
model.compile(optimizer=Adam(learning_rate=lrate), loss = 'categorical_crossentropy', metrics = ['accuracy'])
#model.compile(loss='categorical_crossentropy', optimizer="Adam", metrics=['accuracy'])

model.summary()


#Use ReduceLR to reduce learning rate when metric not improving
earlystop = EarlyStopping(patience=20)
learning_rate_reduction = ReduceLROnPlateau(monitor='val_loss', 
                                            patience=10, 
                                            verbose=1, 
                                            factor=lrate_reduction_factor, 
                                            min_lr=0.000000000000001) 
callbacks = [earlystop, learning_rate_reduction]

if training:
    history = model.fit(train_datagen,epochs=epochs,batch_size=batchsize,validation_data=test_datagen,callbacks=callbacks)
    model.save('trained_weights.hdf5')
    
    #Plot
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,6))
    ax1.plot(history.history['loss'], color='b', label="Training loss")
    ax1.plot(history.history['val_loss'], color='r', label="validation loss")
    ax1.set_xticks(np.arange(0, epochs, (epochs/10)))
    ax1.legend()
    
    ax2.plot(history.history['accuracy'], color='b', label="Training accuracy")
    ax2.plot(history.history['val_accuracy'], color='r',label="Validation accuracy")
    ax2.set_xticks(np.arange(0, epochs, (epochs/10)))
    ax2.legend()
    
    legend = plt.legend(loc='best', shadow=True)
    plt.tight_layout()
    plt.show()





#Model Prediction
model = models.load_model('trained_weights.hdf5', compile = True)

predict_path = "Train"

datagen = ImageDataGenerator(
    rescale = 1./255,
)

predict_data = datagen.flow_from_directory(
    predict_path,
    target_size = ((img_size,img_size)), 
)

ci = predict_data.class_indices
classes = {v: k for k, v in ci.items()}

#path = "C:/Users/J.A.X/Desktop/TempEnv/Test/autistic/1028.jpg"
path = "C:/Users/J.A.X/Desktop/TempEnv/Train/non_autistic/0001.jpg"
#path = input('Please enter path of image to classify: \n')

inp = Image.open(path)
img = inp.resize((img_size,img_size))
img = np.array(img)/255.0
img = np.reshape(img, [1,img_size,img_size,3])

predictions = model.predict(img)

top_values, top_indices = tf.nn.top_k(predictions, k=2)

values = np.array(top_values)
indices = np.array(top_indices)

#print('Input Image: \n\n\n')
#inp.show()

print('Probabilities: \n')
#print(values)
#print(indices)

for i in range(2):
    print(classes[indices[0][i]] + " : ", end = "")
    print(values[0][i] * 100)
    print()


image = cv2.imread(path)
# Convert the image from BGR to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (img_size, img_size))

# Expand dimensions to match the expected input shape (1, 224, 224, 3)
image = np.expand_dims(image, axis=0)

# Convert image to float32 and normalize
image = image.astype(np.float32) / 255.0

# checking how it looks
plt.imshow(image[0])  # Note: image[0] because image now has a batch dimension
plt.show()

print(image.shape) # Print Shape

i = np.argmax(predictions[0])
print(i) # 0 is autistic and 1 is non austistic

输出:

Found 2054 images belonging to 2 classes.
Found 882 images belonging to 2 classes.
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 resnet152v2 (Functional)    (None, 7, 7, 2048)        58331648  
                                                                 
 conv2d (Conv2D)             (None, 5, 5, 512)         9437696   
                                                                 
 conv2d_1 (Conv2D)           (None, 3, 3, 512)         2359808   
                                                                 
 conv2d_2 (Conv2D)           (None, 1, 1, 512)         2359808   
                                                                 
 global_average_pooling2d (G  (None, 512)              0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense (Dense)               (None, 512)               262656    
                                                                 
 dense_1 (Dense)             (None, 64)                32832     
                                                                 
 dropout_1 (Dropout)         (None, 64)                0         
                                                                 
 dense_2 (Dense)             (None, 256)               16640     
                                                                 
 batch_normalization (BatchN  (None, 256)              1024      
 ormalization)                                                   
                                                                 
 dense_3 (Dense)             (None, 128)               32896     
                                                                 
 dense_4 (Dense)             (None, 2)                 258       
                                                                 
=================================================================
Total params: 72,835,266
Trainable params: 14,503,106
Non-trainable params: 58,332,160
_________________________________________________________________

enter image description here

Found 2054 images belonging to 2 classes.
1/1 [==============================] - 3s 3s/step
Probabilities: 

non_autistic : 81.39994740486145

autistic : 18.60005408525467

(1, 224, 224, 3)
1

我希望创建一个热图来可视化卷积层中发生激活的位置。

类似这样的东西: enter image description here

但是在在线跟踪数十个 Grad-Cam 示例之后并编写几十个函数,例如:

# Grad-CAM implementation
def get_img_array(img_path, size):
    img = load_img(img_path, target_size=size)
    array = np.expand_dims(np.array(img), axis=0)
    return array / 255.0

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

def save_and_display_gradcam(img_path, heatmap, cam_path="cam.jpg", alpha=0.4):
    img = cv2.imread(img_path)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = heatmap * alpha + img
    cv2.imwrite(cam_path, superimposed_img)

    img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

img_array = get_img_array(path, size=(img_size, img_size))

# Generate Grad-CAM heatmap
last_conv_layer_name = "conv5_block3_out"
heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)

# Display Grad-CAM
save_and_display_gradcam(path, heatmap)

每个不同的示例都会导致不同的 错误 各种类型,但大多数都指向 resnet 嵌套在实际模型中...

根据我读过的示例表示使用模型的最后一个卷积层。不太确定它们到底意味着 resnet 中的 "conv5_block3_out" 层还是 resnet 之外的模型中的我的 conv2d_2 卷积层...

欢迎对编写和显示正确的热图有任何帮助。我愿意在这一点上尝试任何事情。

虚拟环境是 minicoda

设置虚拟环境的说明:

conda create --name Cuda_Python3.8 python=3.8 -y
conda activate Cuda_Python3.8

conda install cudatoolkit=11.2 cudnn=8.1 -c=conda-forge -y

pip install tensorflow-gpu==2.10.1
pip install spyder==5.5.5
pip install Pillow==10.4.0
pip install matplotlib==3.7.5
pip install opencv-python==4.10.0.84

你遇到的问题在于你的 Grad-CAM 实现尝试访问 ResNet152V2 模型内部的层,而你之前已经冻结了这些层的训练。 为了解决这个问题,你需要在 Grad-CAM 中使用你的自定义模型的最后一个卷积层,而不是 ResNet 的。

以下是更新后的代码,其中包含必要的更改:

# ... (你之前的代码) ...

# Grad-CAM implementation
def get_img_array(img_path, size):
    img = load_img(img_path, target_size=size)
    array = np.expand_dims(np.array(img), axis=0)
    return array / 255.0

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

def save_and_display_gradcam(img_path, heatmap, cam_path="cam.jpg", alpha=0.4):
    img = cv2.imread(img_path)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = heatmap * alpha + img
    cv2.imwrite(cam_path, superimposed_img)

    img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

img_array = get_img_array(path, size=(img_size, img_size))

# **修改:使用你自定义模型的最后一个卷积层的名称**
last_conv_layer_name = "conv2d_2"  
heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)

# Display Grad-CAM
save_and_display_gradcam(path, heatmap)

解释:

  1. 修改: 我们将 last_conv_layer_name 从 "conv5_block3_out" 更改为 "conv2d_2"。这是因为 "conv2d_2" 是你自定义模型中最后一个卷积层的名称。

  2. 其他的 Grad-CAM 函数保持不变,因为它们现在将使用正确的卷积层来生成热图。

通过这个修改,你的 Grad-CAM 实现应该可以正常工作,并生成你期望的热图,突出显示图像中对分类决策重要的区域.

标签:python,tensorflow,conv-neural-network,resnet,image-classification
From: 78773167

相关文章

  • 让 cpython 优化恒定条件
    我正在用Python编写需要尽可能高效运行的代码,但有时我需要深入挖掘调试语句。不要注释这些输入或输出(或者使用外部预处理器来处理代码,就像这里建议的那样Python相当于#ifdefDEBUG或这里如何在python中实现“#ifdef”?|||)我想在模块的开头定义一个变量......
  • 如何使用 for 循环在 python jupyter 笔记本中创建动态图?
    我正在学习本课关于用Python求解热方程。该课程指出,在求解热方程后,我们可以通过在循环中简单地调用pyplot.plot()来可视化解的动画图,其中下面的代码将动态绘制每次每个点的温度,从而得到一个动画情节(课程帖子中提供了动画情节的示例)。importnumpyfrommatplotlibi......
  • Python:动态爱心代码
    importrandomfrommathimportsin,cos,pi,logfromtkinterimport*CANVAS_WIDTH=640CANVAS_HEIGHT=480CANVAS_CENTER_X=CANVAS_WIDTH/2CANVAS_CENTER_Y=CANVAS_HEIGHT/2IMAGE_ENLARGE=11HEART_COLOR="#FF99CC"defcenter_......
  • 如何在 PYTHON 中查找输入数字的千位、百位、十位和个位中的数字?例如:256 有 6 个一、5
    num=int(input("Pleasegivemeanumber:"))print(num)thou=int((num//1000))print(thou)hun=int((num//100))print(hun)ten=int((num//10))print(ten)one=int((num//1))print(one)我尝试过这个,但它不起作用,我被困住了。代码几乎是正确的,但需......
  • ModuleNotFoundError:没有名为“pyaes”的模块 python 虚拟机
    在此处输入图像描述当我在启动python项目的虚拟机上构建某个工具时,几秒钟后会出现此消息。我已经尝试重新安装pyaes但无济于事。谁能帮我?非常感谢我已经尝试重新安装pyaes但无济于事,我搜索了tepyaes模块的十个路径,但我没有找到它,而我在另一台虚拟机上完成了......
  • 使用 Python 操作 Splunk
    使用Python操作Splunk目录使用Python操作Splunk1参考文档2安装PythonSplunk-SDK3连接splunk4配置查询5参考1参考文档SplunkGithub地址:GitHub-splunk/splunk-sdk-python:SplunkSoftwareDevelopmentKitforPythonSplunk开发者文档地址:Pythontools|......
  • Python:如何通过请求帖子对评论进行投票?
    我对评论进行投票的代码无法正常工作。它返回一个http500错误。我有一个使用用户登录的Python程序,它应该自动对评论进行投票。我的代码如下:frombs4importBeautifulSoupimportrequestslogin_url="https://xxxxxxxxxxx/auth/login"login_url_post="http......
  • python_day7(补1)
    数据类型​ 之前为列表类型​ 插入一个元组的介绍 之后还有字典,三者区别为括号方式()[]{}元组类型(tuple)使用:先定义一个元组数据​ vegetable_tuple='(tomato','corn','cucumber','carrot','corn','pumpkin)'与列表类型格式很像,不过只能取不能改,需要特......
  • 在 python 中写入 %appdata% 时出现奇怪的行为
    我试图将一些数据写入%appdata%。一切似乎都像Script1的输出中所示的那样工作。正在创建新目录并保存文件,并且也成功检索数据。但尝试查看文件资源管理器中的数据时,该文件夹不存在!CMD也找不到文件和目录。后来我手动创建了文件,检查了一下,发生了什么。CMD现在可以找到该文......