特征图可视化
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
# import os
# os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
class Test(nn.Module):
    """Minimal network: a single 3x3 convolution from 3 input channels to 1.

    The layer uses ``stride=1`` and ``padding=1``, so an input of shape
    ``(N, 3, H, W)`` produces an output of shape ``(N, 1, H, W)``.
    """

    def __init__(self):
        super().__init__()
        # Attribute name kept as ``Conv1`` so existing state_dict keys match.
        self.Conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Apply the convolution to ``x`` and return the feature map."""
        return self.Conv1(x)
# Visualise the feature map produced by one (randomly initialised) convolution.
path = "00.png"

# ToTensor scales pixels to [0, 1]; Normalize then maps each channel to [-1, 1].
# A transforms.Resize((224, 224)) step can be inserted here if a fixed input
# size is wanted.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# PNG files are often RGBA, grayscale or palette based. Both Normalize and the
# convolution expect exactly three channels, so force RGB before transforming.
with Image.open(path) as img:
    x = trans(img.convert("RGB"))
x = torch.unsqueeze(x, 0)  # add a leading batch dimension -> (1, 3, H, W)

modol = Test()
# Pure inference: no autograd graph is needed for visualisation.
with torch.no_grad():
    y = modol(x)

yy = y.squeeze(0)  # drop the batch dimension -> (1, H, W)

# The raw convolution output is unbounded. Rescale it to [0, 1] before display;
# converting it straight to an 8-bit image would wrap out-of-range values.
feature_map = yy[0]
lo, hi = feature_map.min(), feature_map.max()
feature_map = (feature_map - lo) / (hi - lo).clamp_min(1e-12)

plt.imshow(feature_map.cpu().numpy(), cmap="gray")
plt.show()
代码无注释,哪句有问题,欢迎留言,顺便给个关注。
标签:plt,nn,self,0.5,杂侩,transforms,import,代码
From: https://www.cnblogs.com/QIAN-ONE/p/16817479.html