t-SNE的相关理论可参见“t-SNE 算法”。本文使用PyTorch提供的预训练ResNet-50提取CIFAR-10图像的表征，并使用t-SNE对其进行降维可视化。
加载预训练ResNet-50
import torch
from torchvision.models import resnet50, ResNet50_Weights

# Load ResNet-50 with its default pretrained weights.
resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

# Drop the last child module (the fully connected classification layer) and keep
# everything up to and including global average pooling. For ResNet-50 the
# resulting feature extractor outputs tensors of shape (batch, 2048, 1, 1).
resnet_fe = torch.nn.Sequential(*list(resnet.children())[:-1])

# Use the GPU when one is available; fall back to the CPU instead of crashing
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet_fe.to(device)

# Evaluation mode: BatchNorm layers use their running statistics, so features
# do not depend on the composition of each batch.
resnet_fe.eval()
加载CIFAR-10数据集
from torchvision.datasets import CIFAR10
from torchvision import transforms

# The backbone is a frozen ImageNet-pretrained model, so inputs must be
# normalized with the same ImageNet channel statistics it was trained with
# (as documented for torchvision's pretrained classification weights),
# not with dataset-specific CIFAR-10 statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transformer = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    # NOTE: images are fed at native resolution; upsampling first
    # (e.g. transforms.Resize(224)) usually yields more discriminative
    # features at a much higher compute cost.
])

# CIFAR-10 training split, downloaded into ./data if not already present.
dataset = CIFAR10(root='./data', train=True, download=True, transform=transformer)
提取CIFAR-10表征
from torch.utils.data import DataLoader

# Order is irrelevant for feature extraction (labels travel with their samples),
# so do not shuffle; this also makes the extracted arrays deterministic.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

# Send each batch to whichever device the backbone lives on (GPU or CPU).
device = next(resnet_fe.parameters()).device

features = []  # one 1-D feature vector (numpy array) per image
labels = []    # one integer class index per image
with torch.no_grad():  # inference only: no autograd graph is needed
    for x, y in dataloader:
        feature = resnet_fe(x.to(device))          # (batch_size, 2048, 1, 1) for ResNet-50
        feature = torch.flatten(feature, 1).cpu()  # (batch_size, 2048)
        # Extend row-by-row so downstream code still gets one entry per image.
        features.extend(feature.numpy())
        labels.extend(y.numpy())
t-SNE降维
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np

# Stack the per-image lists into arrays: features (n_samples, n_features),
# labels (n_samples,).
features = np.array(features)
labels = np.array(labels)

# scikit-learn's TSNE documentation recommends first reducing high-dimensional
# dense data to ~50 dimensions with PCA: this suppresses noise and makes the
# neighbour search inside t-SNE tractable for large sample counts.
n_pca = min(50, features.shape[0], features.shape[1])
features_pca = PCA(n_components=n_pca, random_state=0).fit_transform(features)

# 2-D t-SNE embedding, shape (n_samples, 2); fixed seed for reproducibility.
tsne = TSNE(n_components=2, random_state=0).fit_transform(X=features_pca)
可视化
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# 2-D t-SNE coordinates.
x = tsne[:, 0]
y = tsne[:, 1]

# Show human-readable class names in the legend instead of bare indices 0-9;
# hue_order keeps the legend in the dataset's class-index order.
class_names = list(dataset.classes)
df = pd.DataFrame({'x': x, 'y': y, 'label': [class_names[i] for i in labels]})

# `sns.set` is a legacy alias; `set_theme` is the preferred interface.
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=df, x='x', y='y', hue='label', hue_order=class_names,
    palette='tab10', alpha=0.8,
    s=5, linewidth=0,  # small, edge-less markers so dense clusters stay readable
    rasterized=True,   # keep the SVG small: points are embedded as one image
    ax=ax,
)
# t-SNE axes carry no intrinsic meaning, so hide axis labels and ticks.
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticks([])
ax.set_yticks([])
# Place the legend outside the data area and enlarge its (tiny) markers.
ax.legend(title='Labels', markerscale=3, loc='center left', bbox_to_anchor=(1, 0.5))
fig.savefig('scatter_plot.svg', bbox_inches='tight')
plt.show()
参考文献
运行环境
jupyter 1.0.0 py312haa95532_9
matplotlib 3.8.0 py312haa95532_0
pandas 2.2.1 py312h0158946_0
pytorch 2.2.2 py3.12_cuda12.1_cudnn8_0 pytorch
scikit-learn 1.3.0 py312hc7c4135_2
seaborn 0.12.2 py312haa95532_0
标签:10,plt,feature,CIFAR,resnet,SNE,import
From: https://www.cnblogs.com/zh-jp/p/18139028