首页 > 其他分享 >深度学习基础框架通用模板 (Pytorch Template) - cifar10 图片分类为例,深度学习模板

深度学习基础框架通用模板 (Pytorch Template) - cifar10 图片分类为例,深度学习模板

时间:2025-01-08 23:04:49浏览次数:3  
标签:10 训练 为例 py cifar10 CIFAR 深度 数据 模板

文章目录

项目简介

项目链接:https://github.com/tangpan360/pytorch_template

本项目是一个基于 PyTorch 的深度学习基础框架,旨在帮助用户快速实现自己的训练模型。通过替换数据集和数据预处理等模块,用户可以专注于模型开发和实验,而无需花费大量时间在基础功能的实现上,比如:

  • 可视化(loss 和 acc 变化曲线)
  • 模型早停机制(Early Stopping)
  • 随机种子设置
  • 数据加载和预处理
  • 训练日志记录

框架结构清晰、模块化设计,便于扩展和复用,同时包含了一些常用的深度学习工具和方法。既适合新手快速上手,也适合高级用户构建自己的实验框架。


运行结果展示

  1. 终端输出:
    在这里插入图片描述
  2. 损失和准确率可视化实时更新:
    在这里插入图片描述
    或者:
    在这里插入图片描述
  3. 参数配置: 在这里插入图片描述

文件和目录结构说明

pytorch_template
├── checkpoints
│   └── best_model.pth                 # 模型训练后的最佳权重文件
├── dataset
│   ├── processed
│   │   └── cifar-10
│   │       ├── test_data/             # 处理后的测试集数据
│   │       ├── train_data/            # 处理后的训练集数据
│   │       ├── val_data/              # 处理后的验证集数据
│   │       ├── full_train_annotations.csv  # 完整训练集的标签文件
│   │       ├── test_annotations.csv   # 测试集的标签文件
│   │       ├── train_annotations.csv  # 训练集的标签文件
│   │       └── val_annotations.csv    # 验证集的标签文件
│   └── raw
│       ├── cifar-10-batches-py        # CIFAR-10 原始数据解压后的目录
│       │   ├── batches.meta
│       │   ├── data_batch_1
│       │   ├── data_batch_2
│       │   ├── data_batch_3
│       │   ├── data_batch_4
│       │   ├── data_batch_5
│       │   ├── readme.html
│       │   └── test_batch
│       └── cifar-10-python.tar.gz     # CIFAR-10 数据的原始压缩包
├── logs
│   ├── train_log.txt                  # 训练日志
│   └── training_metrics.jsonl         # 训练过程中的指标记录(JSON 行格式)
├── models
│   ├── AlexNet.py                     # AlexNet 模型定义
│   ├── LeNet.py                       # LeNet 模型定义
│   └── VGGNet.py                      # VGGNet 模型定义
├── preprocess_scripts
│   ├── convert_cifar10_to_image.py    # 脚本:将 CIFAR-10 数据转换为图片格式
│   ├── download_cifar10.py            # 脚本:下载 CIFAR-10 数据集
│   └── generate_annotations.ipynb     # 脚本:生成训练/验证/测试集的标签文件
├── utils
│   ├── __init__.py                    # 工具模块的初始化文件
│   ├── cifar10_dataset.py             # 自定义 CIFAR-10 数据集加载工具
│   ├── early_stopping.py              # 早停机制的实现
│   ├── seed_utils.py                  # 随机种子设置工具
│   ├── time_utils.py                  # 时间处理工具
│   └── trainer.py                     # 训练流程封装工具
├── visualization
│   └── visualization_loss_acc.ipynb   # 可视化脚本:展示 loss 和 acc 曲线
├── requirements.txt
├── README.md
├── main.py                            # 主入口:训练脚本
└── run_main.sh                        # 运行训练的 Shell 脚本

功能模块详解

1. 数据相关

  • dataset/raw:存放原始数据集文件(如 CIFAR-10 的压缩包)。
  • dataset/processed:存放预处理后的数据集文件,包括训练集、测试集、验证集和对应的标签文件。

预处理脚本

  • preprocess_scripts/download_cifar10.py:下载并解压 CIFAR-10 数据集。
  • preprocess_scripts/convert_cifar10_to_image.py:将 CIFAR-10 数据集转换为图片格式。
  • preprocess_scripts/generate_annotations.ipynb:生成训练/验证/测试集的标签文件。

2. 模型相关

  • models:存放常见深度学习模型的定义文件。
    • AlexNet.py:AlexNet 模型的实现。
    • LeNet.py:LeNet 模型的实现。
    • VGGNet.py:VGGNet 模型的实现。

您可以在该目录下添加或修改自己的模型文件。

3. 工具函数

  • utils:存放训练与辅助功能的实现,包括:
    • cifar10_dataset.py:自定义数据集类,用于加载和处理 CIFAR-10 数据。
    • early_stopping.py:实现 Early Stopping,用于防止过拟合。
    • seed_utils.py:随机种子设置工具,确保实验结果可重复。
    • time_utils.py:时间工具,用于规范化地输出或计算时间。
    • trainer.py:封装训练流程的工具,用于简化训练代码。

4. 可视化

  • visualization:存放可视化相关脚本。
    • visualization_loss_acc.ipynb:通过 Jupyter Notebook 可视化训练过程中的 loss 和 acc 曲线。

5. 训练和日志

  • logs:存放训练过程中生成的日志和指标记录文件。
    • train_log.txt:记录训练过程中的日志信息(损失、精度等)。
    • training_metrics.jsonl:以 JSON 行格式保存的训练指标记录。
  • checkpoints:存放模型训练过程中保存的权重文件(如最佳模型 best_model.pth)。

6. 主程序

  • main.py:框架主入口,负责整体训练流程。您可以根据需要修改或扩展该文件。
  • run_main.sh:Shell 脚本,用于一键运行 main.py

使用方法

1. 克隆项目

git clone https://github.com/tangpan360/pytorch_template.git
cd pytorch_template

2. 创建并激活 Python 3.9 虚拟环境

conda create -n pytorch python=3.9
conda activate pytorch

3. 安装依赖

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

4. 安装 Jupyter 及相关依赖

如果需要在 Jupyter Notebook 中执行相关脚本,可以安装:

conda install jupyter

若需在 Notebook 中管理多个 conda 环境,可安装:

conda install nb_conda_kernels

2. 数据准备

  1. 下载并解压 CIFAR-10 数据集:
    cd preprocess_scripts
    python download_cifar10.py
    
  2. 将 CIFAR-10 数据集转换为图片格式:
    python convert_cifar10_to_image.py
    
  3. 生成训练/验证/测试集的标签文件:
    # 打开并执行 generate_annotations.ipynb
    

3. 开始训练

进入pytorch_template文件夹根目录,直接运行主程序:

python main.py

或通过 Shell 脚本运行:

bash run_main.sh

4. 可视化结果

在 Jupyter Notebook 中查看训练过程的 Loss 和 Accuracy 曲线:

# 进入 visualization 文件夹,打开并执行 visualization_loss_acc.ipynb

快速替换自己的数据集

  1. 将自有数据放入 dataset/raw 目录,并根据情况修改预处理脚本。
  2. 替换 cifar10_dataset.py 中的数据加载逻辑。
  3. 调整 main.py 中的训练和验证流程,使其适配新的数据集。

TODO

  • 增加更多预训练模型支持(如 ResNet、Transformer 等)。
  • 支持多 GPU 训练。
  • 增加更多数据增强功能。
  • 支持 TensorBoard 可视化。

参考与鸣谢

标签:10,训练,为例,py,cifar10,CIFAR,深度,数据,模板
From: https://blog.csdn.net/weixin_51524504/article/details/144975949

相关文章

  • 代码精简之路-模板模式
    1.前言程序员怕重复CRUD,总是做一些简单繁琐的事情。“不要重复造轮子”,“把基础功能提炼出来封装成工具类”我喜欢把这些话挂在嘴边,写起来常不知从何下手。下面拆解一个项目中的功能。记录从复制粘贴到对业务抽象、实现功能分层的详细过程。如何着手提升代码重构优化能力,拿到......
  • 深度学习目标检测中_构建一个基于YOLOv8的道路裂缝检测系统来处理道路裂缝数据集 4类
    道路裂缝数据集数据集共21041张道路图像,涉及3000+道路损坏实例,数据集包含四种损伤类别的注释:纵向裂缝D00、横向裂缝D10、鳄鱼裂缝D20和坑洞D40;已标注yolo格式、voc格式,可直接用于训练;标签类别及标签个数:D00(6592)、D10(4446)、D20(8381)、D40(5627)构建一个基于YOLOv8的道......
  • 深度学习目标检测使用YOLOv8来训练航拍遥感飞机数据集 yolo
    航拍遥感飞机数据集Yolo格式标注深度学习目标检测使用YOLOv8来训练航拍遥感飞机数据集。以下是详细的步骤和代码示例,包括环境部署、模型训练、指标可视化展示以及PyQt5界面设计。文章代码仅供参考:数据集结构假设你的数据集已经准备好,并且是以YOLO格式存储的。以下......
  • 玩转LangChain:从模型调用到Prompt模板与输出解析的完整指南
    系列文章目录01-玩转LangChain:从模型调用到Prompt模板与输出解析的完整指南文章目录系列文章目录前言一、LangChain环境搭建与初始配置1.1安装依赖1.2环境变量加载1.2.1具体步骤1.2.2注意事项1.3初始化模型客户端二、基础示例:与模型交互2.1直接调用模型2.1.1......
  • 2025 GitCode 开发者冬日嘉年华:AI 与开源的深度交融之旅
    在科技的浪潮中,AI技术与开源探索的火花不断碰撞,催生出无限可能。2025年1月4日,由GitCode联合CSDNCOC城市开发者社区精心打造的开年首场开发者活动:冬日嘉年华在北京中关村•鼎好DH3-A座22层盛大举行,为AI技术爱好者和开源世界探索者带来了一场别开生面的交流派......
  • 【02】优雅草央央逆向技术篇之逆向接口协议篇-以小红书为例-python逆向小红书将用户名
    【02】优雅草央央逆向技术篇之逆向接口协议篇-以小红书为例-python逆向小红书将用户名转换获得为uid-优雅草央千澈背景本次学习逆向是针对小红书的用户转uid学习使用,逆向工程应当在合法和道德的范围内进行,尊重他人的知识产权和隐私权。要在小红书(Red)中将用户名转换为用户ID(U......
  • Zabbix安装,配置模板监控主机(在线安装和离线安装)
    安装环境组件版本LinuxRockyLinux8ZabbixZabbix6.0MySql8.0.30(根据Zabbix6搭建时的经验,MySql版本太低的话会不兼容Zabbix,但当时的报错找不到了,所以未能在此展示)IP监控192.168.88.1&&被监控192.168.88.2安装步骤#离线安装https://repo.zabbix.com/zabbix/7.0/rocky/......
  • Java HashMap 深度解析:底层原理、源码剖析与面试必备知识
    1.HashMap概述HashMap是Java集合框架中最常用的数据结构之一,基于哈希表(HashTable)实现。它以键值对(Key-Value)存储数据,允许null键和null值,且无序。1.1HashMap的特性基于哈希表(HashTable)实现允许null键和null值非线程安全默认初始容量16,负载因子0.75JDK1......
  • 使用LangChain模板在Amazon Bedrock上配置Anthropic‘s Claude作为聊天机器人
    文章目录概要整体架构流程技术名词解释技术细节小结概要提示:这里可以添加技术概要例如:openAI的GPT大模型的发展历程。整体架构流程提示:这里可以添加技术整体架构例如:在语言模型中,编码器和解码器都是由一个个的Transformer组件拼接在一起形成的。技术......
  • 【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速水平集演化结合的方
    【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速水平集演化结合的方式?附代码(二)【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速水平集演化结合的方式?附代码(二)文章目录【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速......