首页 > 其他分享 >PyTorch常用操作

PyTorch常用操作

时间:2023-04-29 17:34:14浏览次数:49  
标签:常用 nn torch PyTorch transforms 图像 操作 数据 加载

数据集加载

1. 网络数据集

加载数据集:https://pytorch.org/vision/stable/datasets.html

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据集变换(将图像转换为张量以及对图像进行归一化的操作)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载MNIST数据集
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

2. 自己构造数据集

dataset.ImageFolder 是 PyTorch 中 torchvision.datasets 模块中的一个函数,用于从一个文件夹中加载图像数据集。该文件夹包含子文件夹,每个子文件夹对应一种类别,每个文件对应一个样本。

dataset.ImageFolder 函数的参数包括:

  • root: 数据集存放的根目录。
  • transform: 对数据集进行的变换。如果不指定,则返回原始数据集。
  • target_transform: 对标签进行的变换。如果不指定,则返回原始标签。
  • loader: 加载数据集的方式,默认为 PIL 的 Image.open() 函数。
  • is_valid_file: 一个可调用对象,用于过滤不合法的文件。如果不指定,则默认所有文件均为合法文件。

使用dataset.ImageFolder函数可以通过以下代码加载数据集:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据集变换
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
dataset = datasets.ImageFolder(root='./data', transform=transform)

上述代码中,首先定义了一个数据集变换,其中包括将图像缩放到256×256大小、居中裁剪到224×224大小、将图像转换为张量以及对图像进行归一化的操作。然后使用 dataset.ImageFolder 函数加载 ./data 目录下的数据集,并应用上述变换。

# 定义加载函数
def loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

上述代码中,定义了一个加载函数loader,用于从文件路径中加载图像并返回一个PIL图像对象。该函数将图像转换为RGB格式,以便于应用数据集变换。

3. DataLoader

DataLoader 是 PyTorch 中用于加载数据的工具类。它可以将数据集封装成一个迭代器,用于在训练过程中按照指定的批次大小、随机打乱等方式加载数据。

  • dataset: Dataset类, 决定数据从哪读取以及如何读取
  • bathsize: 批大小
  • num_works: 是否多进程读取机制
  • shuffle: 每个epoch是否乱序
  • drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据

以下是一个使用 DataLoader 类加载数据集的示例代码:

import torch.utils.data as data
import torchvision.transforms as transforms

# 定义数据集变换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 加载图像数据集
train_set = data.ImageFolder(root='./data/train', transform=transform)
val_set = data.ImageFolder(root='./data/val', transform=transform)

# 定义数据加载器
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = data.DataLoader(val_set, batch_size=32, shuffle=False)

上述代码中,首先定义了一个数据集变换transform,并使用dataset.ImageFolder()类加载图像数据集。然后,使用data.DataLoader()类定义了两个数据加载器,分别表示训练集和验证集。其中,train_loaderval_loader分别表示训练集和验证集的数据加载器。在数据加载器中,使用batch_size参数指定了每个批次的大小,使用shuffle参数指定了是否随机打乱数据集。

data/
├── train/
│   ├── class1/
│   │   ├── image1.jpg
│   │   ├── image2.png
│   │   └── ...
│   ├── class2/
│   │   ├── image1.jpg
│   │   ├── image2.png
│   │   └── ...
│   └── ...
├── val/
│   ├── class1/
│   │   ├── image1.jpg
│   │   ├── image2.png
│   │   └── ...
│   ├── class2/
│   │   ├── image1.jpg
│   │   ├── image2.png
│   │   └── ...
│   └── ...
└── ...

需要注意的是,DataLoader类是一个迭代器,可以使用for循环遍历数据集中的数据。例如:

for inputs, labels in train_loader:
    # 训练过程
    ...

在上述代码中,使用for循环遍历训练集中的数据。每次迭代返回一个批次的数据,其中inputs表示输入数据,labels表示对应的标签数据。

4. 数据集展示

import matplotlib.pyplot as plt
figure = plt.figure(figsize=(10, 10))
cols, rows = 4, 4
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze())
    #plt.imshow(img.squeeze(), cmap="gray")
plt.show()

数据集变换

PyTorch中提供了多种数据集变换(Transform),可以对图像、文本等不同类型的数据进行预处理。以下是常用的数据集变换:

1. 图像数据变换

  • transforms.Resize(size)

将图像的大小调整为指定的大小。size可以是一个整数,表示将图像的最短边调整为该大小,另一边按比例缩放;也可以是一个二元组,表示将图像的大小调整为指定的宽度和高度。

  • transforms.CenterCrop(size)

对图像进行中心裁剪,将图像裁剪为指定的大小。size可以是一个整数,表示将图像的宽度和高度都裁剪为该大小;也可以是一个二元组,表示将图像的宽度和高度分别裁剪为指定的宽度和高度。

  • transforms.RandomCrop(size)

对图像进行随机裁剪,将图像随机裁剪为指定的大小。size的含义与transforms.CenterCrop(size)相同。

  • transforms.RandomHorizontalFlip(p=0.5)

以指定的概率随机对图像进行水平翻转。p为翻转的概率,取值范围为[0, 1]

  • transforms.ToTensor()

将图像转换为Tensor格式。

  • transforms.Normalize(mean, std)

对图像进行标准化处理。meanstd分别为均值和标准差,可以是一个列表或元组,表示RGB三个通道的均值和标准差。

2. 文本数据变换

  • transforms.ToTensor()

将文本转换为Tensor格式。

  • transforms.Lambda(lambda)

自定义变换,使用lambda函数对文本进行处理。

Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))  #one-hot编码

3. 其他数据变换

  • transforms.Compose(transforms)

将多个变换组合起来使用。transforms为一个很多tansformes列表,表示需要组合使用的变换。

预训练模型与权重

https://pytorch.org/vision/stable/models.html

PyTorch中提供了很多预训练的模型,包括图像分类、目标检测、分割等领域的模型。这些预训练模型已经在大规模数据集上进行了训练,可以直接用于特定任务的微调或特征提取。

预训练模型通常包含两个部分:模型结构和权重。模型结构定义了模型的网络结构和参数,权重则包含了预训练模型的参数值。在PyTorch中,可以使用torchvision.models模块中的函数加载预训练模型及其权重。

以下是加载预训练模型及其权重的示例代码:

import torch.nn as nn
import torchvision.models as models

# 加载预训练模型
model = models.resnet18(pretrained=True)

# 将最后一层全连接层替换为自定义的层
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# 加载预训练模型的权重
model.load_state_dict(torch.load('resnet18.pth'))

在上述代码中,使用models.resnet18(pretrained=True)函数加载预训练的ResNet18模型,其中pretrained=True表示加载预训练模型的权重。然后,将最后一层的全连接层替换为自定义的层。最后,使用torch.load()函数加载预训练模型的权重。

建立神经网络

https://pytorch.org/docs/stable/nn.html

1. 卷积层

nn.Conv2d()

2. 反卷积层

nn.ConvTranspose2d()

3. Pooling 层

nn.MaxPool2d()

nn.AvgPool2d()

nn.FractionalMaxPool2d()

nn.LPPool2d()

nn.AdaptiveMaxPool2d()

nn.AdaptiveAvgPool2d()

4. Padding 层

5. 激活层

nn.ReLU()

nn.Softmax()

nn.Tanh()

nn.Sigmoid()

nn.LeakyReLU()

6. 线性层

nn.Identity()

nn.Linear()

7. Sequential

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

8. Block示例

构造块有助于编写重复复杂的网络

import torch
import torch.nn as nn


# 残差模块,将输入加到输出上
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# 上采样(反卷积)
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

# 下采样
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

优化神经网络

1. 定义损失函数

https://pytorch.org/docs/stable/nn.html

在深度学习中,损失函数是评估模型性能的重要指标。PyTorch提供了多种常用的损失函数,可以根据具体情况选择合适的损失函数。

下面介绍几种常用的损失函数:

from torch import nn

# L1损失
nn.L1Loss()

# 均方误差损失函数(Mean Squared Error Loss)
nn.MSELoss()

# 交叉熵损失函数(Cross-Entropy Loss)
nn.CrossEntropyLoss()

# KL散度损失函数(Kullback-Leibler Divergence Loss)
nn.KLDivLoss(reduction='batchmean')

# 二元交叉熵损失函数
nn.BCELoss()

# Huber损失
nn.HuberLoss()

# Hinge损失
nn.HingeEmbeddingLoss()

2. 定义优化器

https://pytorch.org/docs/stable/optim.html

PyTorch中提供了许多优化器(Optimizer)用于优化神经网络的参数,使得损失函数最小化。以下是一些常用的优化器:

from torch import optim

# SGD(Stochastic Gradient Descent)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Adam(Adaptive Moment Estimation)
optimizer = optim.Adam([var1, var2], lr=0.0001)

# Adagrad(Adaptive Gradient)

# Adadelta

# RMSprop(Root Mean Square Propagation)

3. 使用示例

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

保存和加载模型

在PyTorch中,我们可以使用torch.save()torch.load()函数来保存和加载模型。

通常,我们将模型的参数保存到一个文件中,可以是.pt.pth格式的文件。

# 保存模型参数
torch.save(model.state_dict(), 'model_weights.pth')
model.load_state_dict(torch.load('model_weights.pth'))

# 保存整个模型
torch.save(model, 'model.pth')
model = torch.load('model.pth')

调用GPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

标签:常用,nn,torch,PyTorch,transforms,图像,操作,数据,加载
From: https://www.cnblogs.com/jijunhao/p/17364240.html

相关文章

  • C# 序列化操作类
    usingSystem;usingSystem.Collections.Generic;usingSystem.IO;usingSystem.Linq;usingSystem.Runtime.Serialization.Formatters.Binary;usingSystem.Text;usingSystem.Threading.Tasks;namespaceEIMS{[Serializable]publicstaticclass_Serial......
  • C# MD5 加密操作类
    usingSystem;usingSystem.Collections.Generic;usingSystem.Text;usingSystem.Globalization;usingSystem.Security.Cryptography;usingSystem.IO;usingSystem.ServiceModel;namespaceEIMS{publicstaticclassMD5Helper{#regionMD5加密......
  • 对文件的操作
    /*ifstream读文件ofstream写文件fstream读写文件这个三个的头文件是fstreamofstreamoutfile;*/写文件 ofstream h1; /fstreamh1h1.open("user.txt");h1<<name<<"\t";//"\t"换行h1<<age<<endl; //endl表示换行h1.clos......
  • git与github(结合clion操作)
    对自己学习git的一个记录,由于刚开始接触git,所以没有对于git做深入解释和说明,仅供参考,如有理解不对的地方或者需要改进的地方敬请指出。 用到的git命令:gitinit//初始化gitadd.//添加所有文件gitadd文件名//添加指定文件git......
  • Shell列表操作
    字符串列表定义方法已空格分割a=(1234)输出列表所有元素echo${a[*]}输出列表下标echo${!a[*]}输出列表长度echo${#a[*]}列表循环foriin${a[*]}doecho$idone使用列表实现数值排序#冒泡算法a=(1345078974)#获取列表长度len=${#a[@]}echo......
  • 使用findIndex查找并做一些操作
    1.查找指定数据并删除letfindIndex=arrItemsApprover.findIndex(item=>item.zusrid===oObject.zusrid);if(findIndex!==-1){ arrItemsApprover.splice(findIndex,1);}2.查找指定数据并添加属性arrData.forEach(item=>{if(selectApproveUserData.findIndex(i=>item......
  • Python 基于win32com客户端实现Excel操作
    测试环境Python3.6.2代码实现非多线程场景下使用新建并保存EXCELimportwin32com.clientfromwin32apiimportRGBdefsave_something_to_excel(result_file_path):excel_app=win32com.client.Dispatch('Excel.Application')excel_app.Visible=False#设......
  • Pytorch2 如何通过算子融合和 CPU/GPU 代码生成加速深度学习
    动动发财的小手,点个赞吧!PyTorch中用于图形捕获、中间表示、运算符融合以及优化的C++和GPU代码生成的深度学习编译器技术入门计算机编程是神奇的。我们用人类可读的语言编写代码,就像变魔术一样,它通过硅晶体管转化为电流,使它们像开关一样工作,并允许它们实现复杂的逻辑——这......
  • pip和conda的源管理相关操作
    一、pip使用pip默认的镜像在国外,网络连接较差,下载速度比较慢D:\pythonProject3\Django>pipinstallDjango==2.1.3CollectingDjango==2.1.3DownloadingDjango-2.1.3-py3-none-any.whl(7.3MB)|█████████████|3.0MB15kB/set......
  • DML操作
    外键概念:外键作用:创建外键建表时指定外键约束建表后修改删除外键操作:删除具有主键关系的表示,要先删字表,后删除主表DML语言添加数据INSERT命令修改数据update命令where条件子句删除数据DELETE命令TRUNCATE命令外键1.概念:如果公共关键字在......