本节视频主要内容
如何把数据集和transform结合在一起,毕竟因为不可能只对一张图片进行处理,所以会讲到在科研中需要使用的标准数据集该如何下载、组织、查看、使用。(也就是.dataset和.transforms如何进行联合使用)
torchvision的几个模块
进入PyTorch官网,点击官方文档(DOCS),看到不同的块,选择torchvision,看到API文档,在写代码时只需指定相应的数据集并设置参数,它就能自己去下载和使用这些标准数据集。
例如COCO数据集常用于目标检测、语义分割,MNIST手写文字,CIFAR常用于物体识别
torchvision除了.dataset这个模块,还有:
.io模块,但不常用;
.models模块,提供一些比较常见的神经网络,比较重要也比较常用,例如分类、语义分割、视频分类;
.ops模块提供少见操作;
.transforms是之前讲解的内容;
.utils提供一些小工具,如TensorBoard。
什么是CIFAR
CIFAR-10 是一个经典的图像分类数据集,包含 10 个类别的 60,000 张彩色图像(每张图像 32x32 大小),其中有 50,000 张用于训练,10,000 张用于测试。这个数据集已经被很多计算机视觉的研究人员用于深度学习模型的训练和测试。
torchvision.datasets.CIFAR10 是 torchvision
库中的一个内置数据集,用户可以很方便地通过它来加载 CIFAR-10 数据集(如代码所示)。
target_transform的作用
target_transform
参数允许你对数据集的标签(也就是 target)进行变换。它的作用是在返回数据之前,对目标值进行某些变换操作。
举例说明:
如果 CIFAR-10 中的标签是 0 到 9 的整数,对应类别如飞机、汽车等,你可以通过 target_transform
将这些整数转换为 one-hot 编码(一个向量),用以适应不同的模型需求。例如:
from torchvision import transforms
import torch
def one_hot_transform(target):
one_hot = torch.zeros(10)
one_hot[target] = 1
return one_hot
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, target_transform=one_hot_transform)
设置断点,以及 Debug 与 Run 的区别
-
Run(运行): 在不设置断点的情况下,运行程序意味着代码会从头到尾顺序执行,并在没有错误的情况下完成。
-
Debug(调试): 进行 Debug 的目的是帮助开发者找出代码中的错误或问题。设置断点后,程序会在运行到断点时暂停,允许你逐步检查变量的值、执行的逻辑等。这样可以实时观察程序执行过程中每一步的状态。
- 在代码中设置断点(红点处)并选择 Debug,而不是直接运行的目的是让程序在那一行代码执行之前暂停,你可以通过查看变量状态、函数调用等,找到问题或者验证程序的执行逻辑。
Debug窗口中 Threads & Variables
栏目
1. Threads(线程): 显示当前 Python 程序正在运行的线程。每个线程代表程序的一个执行路径。对于多线程程序,这里会列出所有线程的状态。
2. Variables(变量): 这个部分显示的是在当前调试上下文中,所有的局部和全局变量的状态。
-
classes 这行对应的是 CIFAR-10 中的类别和具体数字的映射关系。
classes
的内容为一个列表,其中列出了 CIFAR-10 的所有类别,如飞机(airplane)、汽车(automobile),这些类别会对应数字 0 到 9。图中classes
下的内容正是这个映射关系。 -
其他可下拉内容:
- data:这是数据集本身的内容,即图片数据。
- targets:这是数据集的目标标签(target),即每张图片对应的类别标签。
- meta:这是一些元数据(如文件名、MD5 校验等)。
- transform:显示了用于对数据进行转换的操作(如将图像转为 Tensor)。
train_set 与 test_set
我们有如下两行代码:
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
挨个解释其中的参数:
root="./dataset":表示数据集存储的根目录。如果该目录不存在,程序会自动下载数据集。
train=True/False:决定是否加载训练集(train=True
)或测试集(train=False
)。也就是说:
-
train=True
:表示加载 训练集,也就是用于模型训练的数据。在 CIFAR-10 中,训练集有 50,000 张图像。 -
train=False
:表示加载 测试集,也就是用于模型评估的数据。CIFAR-10 的测试集中包含 10,000 张图像。
transform=dataset_transform:对图像数据进行预处理转换的操作。在这个例子中,dataset_transform
将图像数据转为 Tensor(张量),使其可以用于 PyTorch 模型的训练。
download=True:如果数据集未下载过,会自动下载。
Python中for循环的写法
标准写法
for i in range(10):
print(i)
这种写法会遍历从 0 到 9 的数值,循环 10 次。
遍历列表
for item in [1, 2, 3, 4]:
print(item)
这种写法会遍历一个列表中的每个元素。
带索引遍历
for index, item in enumerate([10, 20, 30]):
print(index, item)
这种写法会同时获取元素的索引和值。
enumerate()
是 Python 内置的一个函数,用于遍历一个可迭代对象(如列表、元组等)时,同时获取元素的索引和值。
举个例子,假设你有一个列表 [10, 20, 30]
,使用 enumerate()
函数可以同时得到列表中每个元素的索引和对应的值。
下面是一个简单的示例:
my_list = [10, 20, 30]
for index, value in enumerate(my_list):
print(index, value)
输出结果:
0 10
1 20
2 30
index
:表示当前元素的索引(位置),比如 0, 1, 2。value
:表示当前索引对应的元素的值,比如 10, 20, 30。