首页 > 其他分享 >PyTorch项目实战09——CIFAR10数据的读取和展示

PyTorch项目实战09——CIFAR10数据的读取和展示

时间:2023-07-01 22:01:18浏览次数:56  
标签:CIFAR10 标签 09 0.5 transform PyTorch 图像 True

1 CIFAR10

cifar10 的官方网址:<http://www.cs.toronto.edu/~kriz/cifar.html>

是由32\*32像素的60000张图片组成的数据集,50000张图片用于训练,10000张图片用于测试,其中有10个类别,每个类别有6000张图片,

分类之间彼此独立,不会重叠,因此是一个单标签多分类的问题。

2 读取CIFAR10数据

首先在对图像做预处理,将图像的RGB均值都设置为0.5。

transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

CIFAR10(root="./data", train=True, download=True, transform=transform)

  • 使用 torchvision 下载CIFAR10的图像,因为之前本地没有CIFAR10的图像内容,第一次使用时需要进行下载,download=True
  • root 图像所在的位置,如果之前没有图像,这里也是图像的下载位置
  • train 是否用于训练,True 为训练,False 为测试
  • transform 图像数组

DataLoader(trainset, batch\_size=4, shuffle=True, num\_workers=2)

对图像进行训练

  • trainset 原始图像
  • batch\_size 每批次数据量,这里是每个批次4张图像
  • shuffle 使用乱序,因为是进行训练,这里是True,如果是用于测试,可以设置为 False
  • num\_workers 线程数,加快运行速度,因 pytorch 在 windows 平台有bug,最好设置为0
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)

3 将图像展示

将训练图像获取到后,展示到前端界面上。

对一个批次的图像转换为迭代器,并将其尺寸缩小一半后,显示在界面上。

# 获取训练图片
dataiter = iter(trainloader)
images, labels = dataiter.__next__()


def imgshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 显示图像
imgshow(torchvision.utils.make_grid(images))

4 打印出训练图像标签

将图像标签(或者叫类别)定义在一个元组中,使用迭代器,将上一步中展示出的图像对应的标签打印到控制台。

classes = (
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
)

print(" ".join('%5s' % classes[labels[l]] for l in range(4)))

5 运行

5.1 首先下载图像到本地

PyTorch项目实战09——CIFAR10数据的读取和展示_迭代器

5.2 随机展示一组图像

PyTorch项目实战09——CIFAR10数据的读取和展示_迭代器_02

5.3 打印该组图像所属标签

关闭上一步展示出的图像,程序继续执行,将上一步展示的图片所属标签,打印到控制台中。

PyTorch项目实战09——CIFAR10数据的读取和展示_迭代器_03


标签:CIFAR10,标签,09,0.5,transform,PyTorch,图像,True
From: https://blog.51cto.com/u_113754/6601659

相关文章

  • 03常用pytorch剪枝工具
    常用剪枝工具pytorch官方案例importtorch.nn.utils.pruneaspruneimporttorchfromtorchimportnnimporttorch.nn.utils.pruneaspruneimporttorch.nn.functionalasFprint(torch.__version__)device=torch.device("cuda"iftorch.cuda.is_available()els......
  • pytorch保存单通道灰度图片
    前言importtorchimporttorchvision.transformsastransformsfromtorchvision.utilsimportsave_imageimage=torch.randn(1,256,256)#示例,随机生成一个单通道图像#将图像张量保存为文件save_image(image,"single_channel_image.png",normalize=True)pytorch中......
  • 怎样导入pytorch gpu版本?
    1.下载anaconda2.在anaconda里创建环境create-npytorch_gpu#激活环境condaactivatepytorch_gpu3.在环境里install修改镜像接下来就是关键一步了,把-cpytorch表示的pytorch源,更改为国内的镜像。https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/先浏......
  • 光脚丫学LINQ(009):选择各个源元素的子集
    视频演示:http://u.115.com/file/f2d7193f3a 选择源序列中的各个元素的子集有两种主要方法:1、若要只选择源元素的一个成员,请使用点运算。在下面的示例中,假定Customer对象包含几个公共属性,其中包括名为City的字符串。在执行此查询时,此查询将生成字符串输出序列。NorthwindDataCo......
  • 使用numpy实现bert模型,使用hugging face 或pytorch训练模型,保存参数为numpy格式,然后使
     之前分别用numpy实现了mlp,cnn,lstm,这次搞一个大一点的模型bert,纯numpy实现,最重要的是可在树莓派上或其他不能安装pytorch的板子上运行,推理数据本次模型是随便在huggingface上找的一个新闻评论的模型,7分类看这些模型参数,这并不重要,模型占硬盘空间都要400+Mbert.embeddings.w......
  • 机器学习之pytorch环境配置以及cuda安装
     关于conda环境下安装cuda配置和pytorch安装cuda查看显卡型号 (进入cmd环境下) nvidia-smi 下载对应的cudaCUDA Toolkit Archive | NVIDIA Developer)选择与cuda相匹配的版本(版本尽量靠近些电脑的)建议使用迅雷下载,网站下载会限速正式安装安装路径的选择,......
  • AI_Pytorch_损失函数
    数据和向量损失函数数据的归一化Z-score均值方差归一化(standardization):把所有数据归一化到均值为0方差为1的分布中。适用于数据分布没有明显的边界,有可能存在极端的数据值。 数据符合正态分布,消除离群点的影响min-max标准化最值归一化(Normalizati......
  • 利用Pytorch实现Faster R-CNN
    代码解析: Pytorchtorchvision构建Faster-rcnn(一)----coco数据读取Pytorchtorchvision构建Faster-rcnn(二)----基础网络Pytorchtorchvision构建Faster-rcnn(三)----RPNPytorchtorchvision构建Faster-rcnn(四)----ROIHead训练模型:BaiduCloud 附加Pytorch源码:https://github.com/chen......
  • pta第三部分总结oop训练集09-11
    一,前言:oop09:7-1统计Java程序中关键词的出现次数:对Java中字符串,元字符,正则表达式的应用。oop10:7-1容器-HashMap-检索:对Java程序中HashMap的特性对输入内容进行检索的应用。7-2容器-HashMap-排序:对Java升序中HashMap的无序性的应用将其排序。7-3课程成绩......
  • 「路飞项目09」redis
    1Redis介绍和安装#Redis:软件,存储数据的,速度非常快,redis是一个key-value存储系统(没有表的概念),cs架构的软件-服务端客户端(python作为客户端,java,go,图形化界面,命令窗口的命令)#es:存数据的地方#关系型数据库和非关系型数据库-关系型:mysql,PostgreSQL【PG】,oracle,sqlserver,db......