首页 > 其他分享 >Pytorch分类模型的训练框架

Pytorch分类模型的训练框架

时间:2024-04-15 16:26:41浏览次数:26  
标签:__ nn 框架 模型 dataset Pytorch import self size

Pytorch分类模型的训练框架

PhotoDataset数据集是自己定义的数据集,数据集存放方式为:

----image文件夹

--------0文件夹

--------------img1.jpg

--------------img2.jpg

--------1文件夹

--------------img1.jpg

--------------img2.jpg

....

如果是cpu训练的话,就把代码中的.cuda()改成.cpu()

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torchvision.models as models


# 定义数据集类
class PhotoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.files = self._load_files()

    def _load_files(self):
        files = []
        for label in range(10):
            label_dir = os.path.join(self.root_dir, str(label))
            for file_name in os.listdir(label_dir):
                file_path = os.path.join(label_dir, file_name)
                files.append((file_path, int(label)))
        return files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        file_path, label = self.files[index]
        image = torchvision.io.read_image(file_path, torchvision.io.ImageReadMode.RGB)
        image = image.float()
        if self.transform:
            image = self.transform(image)
        return image, label

# 数据集路径
data_dir = './photo2'

# 数据集预处理
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    #transforms.ToTensor(),
    #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 创建数据集
dataset = PhotoDataset(data_dir, transform=transform)

# 划分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
print(len(train_dataset))
print(len(val_dataset))
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


# 定义基于ResNet-18的分类器
class ResNet18Classifier(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet18Classifier, self).__init__()
        # 使用预训练的ResNet-18模型
        self.resnet = models.resnet18(pretrained=False)
        # 替换掉原有的fc层,以适应我们的分类任务
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        # 直接使用ResNet的forward函数
        x = self.resnet(x)
        return x

# 自定义模型,搭积木一样
class LightweightClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super(LightweightClassifier, self).__init__()
        # 定义卷积层
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        
        # 定义池化层
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 定义全连接层
        self.fc1 = nn.Linear(64 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # 通过卷积层和池化层
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        
        # 展平特征图以输入到全连接层
        x = x.view(-1, 64 * 32 * 32)
        
        # 通过全连接层
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x


# 创建模型实例
model = ResNet18Classifier() #LightweightClassifier()
model.cuda()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        #outputs = F.sigmoid(outputs)#二分类的话可以再加个sigmoid()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

    # 验证模型
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.cuda()
            labels = labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(f'Validation Accuracy: {100 * correct / total}%')

其他

其实在这个框架中还有许多要素可以加入,比如:

  1. Loss收敛趋势可视化,用tensorboard;【todo】
  2. 学习率的动态调整,用各种warnup策略;【todo】
  3. 对数据集的预处理,用albumentations库;【todo】
  4. 使用T-SNE对数据集分布进行可视化;【todo】
  5. 使用CAM可视化对模型提取到的特征进行可视化;【todo】
  6. 模型权重的自动化保存策略,例如保存前10次效果最优的模型;【todo】

标签:__,nn,框架,模型,dataset,Pytorch,import,self,size
From: https://www.cnblogs.com/lwp-nicol/p/18136202

相关文章

  • 爆火 AI 硬件遭差评,Ai Pin 上市即翻车;Grok 推出首个多模态模型丨 RTE 开发者日报 Vol.
      开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编辑......
  • LLM学习(1)——大模型简介
    1.1.1LLM的概念为了区分不同参数尺度的语言模型,研究界为大规模的PLM(例如,包含数百亿或数千亿个参数)创造了术语“大型语言模型”LLM。1.1.2LLM的能力与缩放定律LLM的能力涌现能力LLMs被正式定义为“在小型模型中不存在但在大型模型中出现的能力”,这是LLMs区别于以往PLM的最突......
  • 人工智能大模型的分类-来自智谱清言
    人工智能大模型可以根据不同的维度进行分类,以下是一些主要的分类方式:按照模型架构分类:深度神经网络(DNNs):包括多层感知机(MLPs)、卷积神经网络(CNNs)、循环神经网络(RNNs)、长短期记忆网络(LSTMs)和门控循环单元(GRUs)。Transformer模型:如BERT、GPT系列、Transformer-XL等,这些模型主要基......
  • 人工智能大模型的训练阶段和使用方式来分类
    是的,人工智能大模型也可以根据它们的训练阶段和使用方式来分类。以下是根据模型的阶段性来区分的一些类别:预训练模型:这些模型在大规模数据集上进行训练,以学习通用的特征表示。预训练可以是无监督的(如使用自编码或生成对抗网络),也可以是有监督的(如在大型标注数据集上进行训练)。......
  • 开箱即用的模型叫什么模型?有什么特点
    可以直接拿来就用的模型通常被称为“即用型模型”(Ready-to-UseModels)或“预训练模型”(Pre-TrainedModels)。这些模型已经被训练好了,用户可以直接下载并应用于自己的任务,而无需进行额外的训练。即用型模型通常具有以下特点:预训练:模型在大规模的数据集上进行了预训练,学习了通用的......
  • flask框架基础(1)
    flask基础一.开发模式flask是b/s(浏览器开发)开发模式二.flask七行代码fromflaskimportFlaskapp=Flask(_name_)@app.route("/")defindex():retun"打开此网页"if_name_=='_name':app.run()三.flask核心1.werkzeug负责后端2.jinja2负责前端......
  • css 盒子模型
    1.分类标准盒子模型content-box怪异盒子模型border-box2.示例代码<!DOCTYPEhtml><htmllang="en"><head><metacharset="UTF-8"><metaname="viewport"content="width=device-width,initial-scale=1......
  • 十款优质企业级Java微服务开源项目(开源框架,用于学习、毕设、公司项目,减少开发工作!)
     Java微服务开源项目前言一、pig二、zheng三、SpringBlade四、SOP五、matecloud六、mall七、jeecg-boot八、Cloud-Platform九、microservices-platform十、RuoYi-Cloud 前言这篇文章为大家推荐几款优质的Java开源项目框架,可以用于学习,毕业设计,公司项目......
  • Java微服务框架一览
    Java微服务框架一览微服务在开发领域的应用越来越广泛,因为开发人员致力于创建更大、更复杂的应用程序,而这些应用程序作为微小服务的组合能够更好地得以开发和管理。这些微小的服务可以组合在一起工作,并实现更大、应用更广泛的功能。现在出现了很多的工具来满足使用逐段法而不......
  • 时空图神经网络ST-GNN的概念以及Pytorch实现
    在我们周围的各个领域,从分子结构到社交网络,再到城市设计结构,到处都有相互关联的图数据。图神经网络(GNN)作为一种强大的方法,正在用于建模和学习这类数据的空间和图结构。它已经被应用于蛋白质结构和其他分子应用,例如药物发现,以及模拟系统,如社交网络。标准的GNN可以结合来自其他机器......