PyTorch数据处理工具箱
torch.utils.data
- Dataset
抽象类,其他数据集类定义时需继承自该类,并覆写两个方法:getitem__和__len - DataLoader
定义一个新的迭代器,实现批量batch读取,打乱shuffle数据和并行加速等功能 - random_split
将数据集随机拆分成给定长度的非重叠的新数据集 - *sample
多种采样函数
torch.utils.data.Dataset抽象类,自定义数据集需继承这个类,并实现两个函数__len__和__getitem__
torch.utils.data.DataLoader定义数据集迭代器,实现batch读取
class TestDataset(data.Dataset):
def __init__(self):
self.Data= np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
self.Label= np.asarray([0,1,0,1,2])
def __getitem__(self, index):
txt= torch.from_numpy(self.Data[index]) # 将numpy转换成Tensor
label= torch.tensor(self.Label[index])
return txt,label
def __len__(self):
return len(self.Data)
# 使用DataLoader对数据集Dataset进行批量batch处理,同时进行shuffle和并行加速等操作
DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None, # 样本抽取
batch_sampler=None,
num_workers=0, # 多进程加载
collate_fn=<function default_collate at 0x7f108ee01620>, # 多个样本拼接成一个batch的拼接方式
pin_memory=False, # 是否将数据保存在pin memory区,加速加载到GPU
drop_last=False, # 将多出的不足一个batch_size的数据丢弃
timeout=0,
worker_init_fn=None,
)
test= TestDataset()
test_loader=data.DataLoader(test, batch_size=2, shuffle=False, num_workers=2)
for i,traindata in enumerate(test_loader):
data,label= traindata
# 可以像使用迭代器一样使用test_loader,不过由于他不是迭代器,使用iter()将其转换成迭代器并用next()遍历
iter(test_loader)
next(dataiter)
一般使用data.Dataset处理同一目录下的数据,若数据在不同目录下(不同目录表示不同类别),此时可用torchvision处理数据
torchvision
- datasets
继承自torch.utils.data.Dataset,提供常用数据集Mnist,Cifar10/100,ImageNet和COCO - models
提供经典的网络结构和模型pretrained=True
,如AlexNet,VGG,ResNet,Inception系列 - transforms
常用的数据预处理操作,主要对Tensor和PIL Image类型操作,当预处理有多个函数时,可用transforms.Compose
将其组合 - utils
含有两个函数:make_grid
将多张图像拼接在一个网格中,save_img
将Tensor保存成图像
transforms 对 PIL.Image 常见操作
- Scale/Resize 调整尺寸,保持长宽比不变
- CenterCrop,RandomCrop,RandomSizeCrop 裁减图像
- Pad 填充
- ToTensor 将取值范围为[0,255]的
PIL.Image
或形状为(H,W,C)的ndarray
转换成(C,H,W),取值范围为[0,1.0]的torch.FloatTensor - RandomHorizontalFlip 图像随机水平翻转,翻转概率为0.5
- RandomVerticalFlip 图像随机垂直翻转
- ColorJitter 修改图像亮度,对比度和饱和度
transforms对 Tensor 常见操作
- Normalize 标准化
- ToPILImage 将Tensor转换成PIL.Image
transforms.Lambda()
使用自定义lambda表达式,如每个像素加10:transforms.Lambda(lambda x:x.add(10))
当对数据集进行多个操作时,可通过Compose()
将这些操作拼接,类似于nn.Sequential
transforms.Compose({
# 将给定的PIL.Image进行中心切割,size可以是tuple或Integer
transforms.CenterCrop(10),
transforms.RandomCrop(20,padding=0),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))
})
函数ImageFolder
为torchvision.datasets
中成员
当文件依据标签存储在不同文件夹下时,可以使用其直接构造出Dataset,ImageFolder
会将文件夹名自动转换成序列
my_trans=transforms.Compose({
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(224),
transforms.ToTensor()
})
train_data= torchvision.datasets.ImageFolder("",transforms=my_trans)
train_loader= torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)
for i_batch,img in enumerate(train_loader):
if i_batch==0:
print(img[1])
fig= plt.figure()
grid= torchvision.utils.make_grid(img[0])
plt.imshow(grid.numpy().transpose((1,2,0)))
plt.show()
utils.save_image(grid,'te.png')
标签:__,utils,torch,batch,transforms,数据处理,工具,data
From: https://www.cnblogs.com/sgqmax/p/18522291