首页 > 其他分享 >如何使用深度学习框架(PyTorch)来训练——147913张图像的超大超详细垃圾分类数据集,并附上详细训练代码和步骤。这个数据集包含4大类,345小类

如何使用深度学习框架(PyTorch)来训练——147913张图像的超大超详细垃圾分类数据集,并附上详细训练代码和步骤。这个数据集包含4大类,345小类

时间:2024-11-06 18:44:26浏览次数:3  
标签:loss 训练 torch loader 147913 running train 详细 model

超大超详细垃圾分类数据集(分类,分类),共4大类,345小类,147913张图,已全部分类标注完成,共12GB。

厨余垃圾 76小类 35058张
可回收物 195类 86116张
其他垃圾 53类 16156张
有害垃圾 18小类 10583张

 

如何使用深度学习框架(如PyTorch)来训练一个包含147913张图像的超大超详细垃圾分类数据集,并附上详细的训练代码和步骤。这个数据集包含4大类,345小类,已全部分类标注完成,总大小为12GB。

数据集描述

  • 数据量:147913张图像
  • 类别
    • 厨余垃圾:76小类,35058张
    • 可回收物:195小类,86116张
    • 其他垃圾:53小类,16156张
    • 有害垃圾:18小类,10583张
  • 总大小:12GB
  • 任务类型:图像分类

数据集组织

假设你的数据集目录结构如下:

garbage_classification_dataset/
├── train/
│   ├── food_waste/
│   │   ├── class1/
│   │   ├── class2/
│   │   └── ...
│   ├── recyclable/
│   │   ├── class1/
│   │   ├── class2/
│   │   └── ...
│   ├── other_waste/
│   │   ├── class1/
│   │   ├── class2/
│   │   └── ...
│   └── hazardous_waste/
│       ├── class1/
│       ├── class2/
│       └── ...
├── valid/
│   ├── food_waste/
│   │   ├── class1/
│   │   ├── class2/
│   │   └── ...
│   ├── recyclable/
│   │   ├── class1/
│   │   ├── class2/
│   │   └── ...
│   ├── other_waste/
│   │   ├── class1/
│   │   ├── class2/
│   │   └── ...
│   └── hazardous_waste/
│       ├── class1/
│       ├── class2/
│       └── ...
└── test/
    ├── food_waste/
    │   ├── class1/
    │   ├── class2/
    │   └── ...
    ├── recyclable/
    │   ├── class1/
    │   ├── class2/
    │   └── ...
    ├── other_waste/
    │   ├── class1/
    │   ├── class2/
    │   └── ...
    └── hazardous_waste/
        ├── class1/
        ├── class2/
        └── ...

安装依赖

确保你已经安装了必要的依赖库:

pip install torch torchvision matplotlib

数据加载和预处理

使用torchvision中的ImageFolder来加载数据集,并进行预处理:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.ImageFolder(root='./garbage_classification_dataset/train', transform=transform)
valid_dataset = datasets.ImageFolder(root='./garbage_classification_dataset/valid', transform=transform)
test_dataset = datasets.ImageFolder(root='./garbage_classification_dataset/test', transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

模型定义

使用预训练的ResNet50模型进行迁移学习:

import torch.nn as nn
import torchvision.models as models

# 加载预训练的ResNet50模型
model = models.resnet50(pretrained=True)

# 替换最后一层全连接层
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 345)  # 345个类别

# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

损失函数和优化器

定义损失函数和优化器:

import torch.optim as optim

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练模型

编写训练循环:

def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=100):
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects.double() / len(train_loader.dataset)

        # 验证阶段
        model.eval()
        running_loss = 0.0
        running_corrects = 0

        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        val_loss = running_loss / len(valid_loader.dataset)
        val_accuracy = running_corrects.double() / len(valid_loader.dataset)

        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}')

        # 保存最佳模型
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')

# 开始训练
train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=100)

模型评估

训练完成后,可以使用以下代码评估模型在测试集上的表现:

def evaluate_model(model, test_loader, criterion):
    model.load_state_dict(torch.load('best_model.pth'))
    model.eval()

    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

    test_loss = running_loss / len(test_loader.dataset)
    test_accuracy = running_corrects.double() / len(test_loader.dataset)

    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}')

# 评估模型
evaluate_model(model, test_loader, criterion)

模型预测

你可以使用训练好的模型对新图像进行预测:

import matplotlib.pyplot as plt

def predict_image(image_path, model, transform, class_names):
    model.load_state_dict(torch.load('best_model.pth'))
    model.eval()

    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)
        _, predicted = torch.max(outputs, 1)

    predicted_class = class_names[predicted.item()]
    plt.imshow(image)
    plt.title(f'Predicted: {predicted_class}')
    plt.show()

# 获取类别名称
class_names = [item[0].split('/')[-1] for item in train_loader.dataset.class_to_idx.items()]

# 预测新图像
predict_image('path/to/your/image.jpg', model, transform, class_names)

注意事项

  • 数据集质量:确保数据集的质量,包括清晰度、标注准确性等。
  • 模型选择:可以选择更强大的模型版本(如ResNet101、EfficientNet等)以提高性能。
  • 超参数调整:根据实际情况调整超参数,如批量大小(batch-size)、学习率(lr)等。
  • 监控性能:训练过程中监控损失函数和准确率指标,确保模型收敛。

通过上述步骤,你可以使用PyTorch来训练一个超大超详细的垃圾分类数据集,并使用训练好的模型进行预测。

标签:loss,训练,torch,loader,147913,running,train,详细,model
From: https://blog.csdn.net/2401_88440984/article/details/143421880

相关文章

  • 使用YOLOv5来训练——井盖状态检测数据集,并使用训练好的模型进行预测井盖状态检测数据
    井盖状态检测数据集yolo格式五种类别:broke(井盖破损),good(完好),circle(边圈破损),lose(井盖丢失),uncovered(井盖位移/未覆盖全)训练数据已划分,配置文件稍做路径改动即可训练。训练集:1217验证集:108 使用YOLOv5来训练一个包含1217张训练图像和108张验证图像的井盖状态检......
  • 【安全运维】检测即代码(DAC) 详细步骤
    原创Zafkie1SecLink安全空间引言DAC(DetectionAsCode),检测即代码是一种战略方法,可将安全检测机制无缝集成到软件开发生命周期中。通过将安全控制视为代码,组织可以在整个SIEM运维过程中自动部署、配置和维护安全措施。或许很多人听说过DAC的概念,但是并没有一步步地实现过......
  • CDGP|数据治理如何落地?多角度详细探讨
    数据治理是一个长期且复杂的体系化工程,它通过一系列流程规范、制度、IT能力以及持续运营等机制来保障治理工作的持续推进。落地数据治理需要从多个方面入手,本文将从组织建设、流程规范、IT平台以及持续运营等角度详细探讨。一、建立数据治理组织数据治理需要打破企业内部壁......
  • 使用 【Java】 集成 【Elasticsearch】:详细教程
    Elasticsearch是一个开源的分布式搜索引擎,它能够快速地存储、搜索和分析大量的文本数据。它基于ApacheLucene构建,广泛应用于日志分析、全文搜索、推荐系统等场景。本文将详细介绍如何在Java项目中集成Elasticsearch,包括如何配置、索引文档、查询数据、以及与Elasticsea......
  • 渗透测试入门教程(非常详细),从零基础入门到精通,看完这一篇就够了_渗透测试教程
    什么是渗透测试渗透测试就是模拟真实黑客的攻击手法对目标网站或主机进行全面的安全评估,与黑客攻击不一样的是,渗透测试的目的是尽可能多地发现安全漏洞,而真实黑客攻击只要发现一处入侵点即可以进入目标系统。一名优秀的渗透测试工程师也可以认为是一个厉害的黑客,也可以被......
  • Windows系统搭建ELK日志收集(详细版)
    一、ELK是什么?ELK是由Elasticsearch、Logstash、Kibana这3个软件的首字母缩写。ELK的大致工作顺序:应用程序产生log日志-->Logstash收集日志-->Logstash整理输出到Elasticsearch-->通过Kibana展示。ELK(Elasticsearch,Logstash,Kibana)是一个强大的开源数据分析和可视化平台,......
  • W外链如何创建短链接?详细操作步骤。
    根据搜索结果,创建微信外链并将长链接转换为短链接,可以通过使用W外链工具来实现。以下是使用W外链进行长链接转短链接的步骤:注册与登录:打开W外链平台的官方网站,注册一个账号,通常需要提供手机号、用户名、密码等信息。注册成功后,使用用户名和密码登录到平台。生成短链接:准备......
  • 一周搞定模电!(2) 超详细!!新手小白必看!
     目录稳压二极管整流二极管开关二极管电容1、什么是电容2、电容的作用2.1旁路的作用2.2去耦(退耦)电容的作用2.3滤波和储能3.电容在电路中的连接问题稳压二极管嵌入式系统,作为一种专用计算机系统,被广泛应用于各种设备和装置中,从智能手机到汽车,从家用电器到医......
  • 软件著作权申请教程(超详细)(2024新版)软著申请
                   目录一、注册账号与实名登记二、材料准备三、申请步骤1.办理身份2.软件申请信息3.软件开发信息4.软件功能与特点5.填报完成一、注册账号与实名登记    首先我们需要在官网里面注册一个账号,并且完成实名认证,一般是注册【个人......
  • 人工智能AI 产品经理与传统产品经理工作到底有什么不同?非常详细收藏我这一篇就够了
    一、AI产品经理的定义及职责范围AI产品经理是直接应用或间接涉及了AI技术,进而完成相关AI产品的设计、研发、推广、产品生命周期管理等工作的产品经理。具体来说,狭义AI产品经理直接应用了语义、语音、计算机视觉和机器学习这4个领域的AI技术,例如语义类AI产品......