首页 > 其他分享 >使用t-SNE可视化CIFAR-10的表征

使用t-SNE可视化CIFAR-10的表征

时间:2024-04-16 19:47:00浏览次数:33  
标签:10 plt feature CIFAR resnet SNE import

t-SNE理论相关理论可参见t-SNE 算法。本文通过PyTorch提供的预训练Resnet50提取CIFAR-10表征,并使用t-SNE进行可视化。

加载预训练Resnet50

import torch
from torchvision.models import resnet50, ResNet50_Weights

# 加载ResNet模型
resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

# 移除最后一层全连接层
resnet_fe = torch.nn.Sequential(*(list(resnet.children())[:-1]))
resnet_fe.cuda()
resnet_fe.eval()

加载CIFAR-10数据集

from torchvision.datasets import CIFAR10
from torchvision import transforms

transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

dataset = CIFAR10(root='./data', train=True, download=True, transform=transformer)

提取CIFAR-10表征

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
features = []
labels = []
for i, (x, y) in enumerate(dataloader):
    x = x.cuda()
    with torch.no_grad():
        feature = resnet_fe(x)  # feature shape: (batch_size, 512, 1, 1)
    feature = feature.view(feature.size(0), -1).cpu()  # feature shape: (batch_size, 512)
    for f,l in zip(feature,y):
        features.append(f.numpy())
        labels.append(l.numpy())

训练t-SNE

from sklearn.manifold import TSNE
import numpy as np

features = np.array(features)
labels = np.array(labels)
tsne = TSNE(n_components=2, random_state=0).fit_transform(X=features)

可视化

import seaborn as sns
import matplotlib.pyplot as plt

# 提取 x 和 y 坐标
x = tsne[:, 0]
y = tsne[:, 1]

# 创建 DataFrame
import pandas as pd
df = pd.DataFrame({'x': x, 'y': y, 'label': labels})

# 创建散点图
plt.figure(figsize=(8, 6))
sns.set(style="whitegrid")
sns.scatterplot(data=df, x='x', y='y', hue='label', palette='tab10', alpha=0.8)
plt.xlabel('')
plt.ylabel('')
plt.xticks([])
plt.yticks([])
plt.legend(title='Labels')
plt.savefig('scatter_plot.svg')
plt.show()

参考文献

  1. Representation Learning — CIFAR-10
  2. t-SNE 算法
  3. sklearn.manifold.TSNE

运行环境

jupyter                   1.0.0           py312haa95532_9
matplotlib                3.8.0           py312haa95532_0
pandas                    2.2.1           py312h0158946_0
pytorch                   2.2.2           py3.12_cuda12.1_cudnn8_0    pytorch
scikit-learn              1.3.0           py312hc7c4135_2
seaborn                   0.12.2          py312haa95532_0

标签:10,plt,feature,CIFAR,resnet,SNE,import
From: https://www.cnblogs.com/zh-jp/p/18139028

相关文章

  • 特斯拉宣布 10%大裁员;刘强东数字人开启直播首秀丨 RTE 开发者日报 Vol.185
       开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编......
  • 4月10日每日总结
    SpringBoot与安全今天我学习了如何在SpringBoot应用中实现安全控制,保护应用的资源和用户数据。我首先了解了SpringSecurity框架的基本原理和核心组件,包括用户认证、授权管理、表单登录等功能。随后,我学习了如何在SpringBoot应用中集成SpringSecurity,并配置基本的安全策略和......
  • FR107-ASEMI快恢复二极管FR107
    编辑:llFR107-ASEMI快恢复二极管FR107型号:FR107品牌:ASEMI封装:DO-41最大平均正向电流(IF):1A最大循环峰值反向电压(VRRM):1000V最大正向电压(VF):1.20V工作温度:-55°C~150°C反向恢复时间:50ns芯片个数:1芯片尺寸:mil引脚数量:2正向浪涌电流(IFMS):30A包装方式:50/管1000/盘3000/箱F......
  • 10.函数对象 与 闭包函数
    【一】函数对象Python中一切皆对象可直接被引用定义一个函数,用一个新的变量名来存,来调用可作为元素被储存功能字典中函数的地址可作为参数传递给另一个函数将函数的内存地址作为参数函数的返回值可以是函数直接将函数的内存地址返回【二】闭包函数1)......
  • CF1097F Alex and a TV Show 题解
    题目链接点击打开链接题目解法很牛的套路啊!看到集合并,且只要求奇偶性的问题,第一个想到\(bitset\)\(1,2,4\)操作都是好维护的,关键是第\(3\)个操作看到$\gcd$,首先想到莫反令\(c_{x,i}\)为集合\(x\)中数\(i\)的出现次数则\(c_{x,i}=\sum\limits_{i|j}\sum\limit......
  • ResNet50算法
    ResNet(Residualnet)是残差网络的通用概念,而ResNet50是一个具体的网络结构,其由50个卷积层组成。ResNet50是指包含了50个卷积层(包括卷积层、池化层、全连接层等)的ResNet网络。ResNet50是基于ImageNet数据集上的训练所提出的一个具体网络结构。ResNet核心:在最终输出中,除了......
  • kylin V10 SP3将openssl降级到1.0.2u
    前言kylinV10SP3的openssl版本较高OpenSSH_9.5p1,OpenSSL1.1.1f 31Mar2020,程序编译环境必须是OpenSSL1.0.2,故需要降级,记录一下。下载安装包下载openssl1.0.2u安装#解压tarzxvfopenssl-1.0.2u.tar.gz#安装cdopenssl-1.0.2u./config--prefix=/usr/local/sslmak......
  • 华为实习4.10机考第二题C++代码
    考的是简单的并查集这道题考法就是并查集,若两个图片相似度大于0,则将他们放到一个家族中,同时维护家族的相似度总和。注意M矩阵是对称矩阵,所以需要避免重复维护相似度,因此可以只针对M矩阵的下三角矩阵或上三角矩阵中的连接块,计算相似度总和;或考虑整个M矩阵,然后相似度总和除......
  • 1025 反转链表
    我看其他博客用的reverse,但是下标我真的有点糊涂,以下是参考某位dalao的。#include<bits/stdc++.h>usingnamespacestd;structnode{ intsno; intdata; intnext;}s[100010];intmain(){ intstart,cnt,fz;//start cin>>start>>cnt>>fz; for(inti=0;i<cnt......
  • P10288 [GESP样题 八级] 区间
    原题链接题解本题的优化真的很重要!!把所有元素出现的下标用map套vector存起来,然后二分查找code#include<bits/stdc++.h>usingnamespacestd;map<int,vector<int>>mp;intmain(){ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);//缺一不可intt;ci......