首页 > 其他分享 >手撸代码:从零开始的 AlexNet 图像分类(PyTorch框架)

手撸代码:从零开始的 AlexNet 图像分类(PyTorch框架)

时间:2024-02-03 14:22:50浏览次数:25  
标签:从零开始 nn self loader PyTorch train out AlexNet size

摘要:

本文在 PyTorch 框架下搭建了 AlexNet ,并在 CIFAR10 上完成了图片分类。同时,更正了一些原论文中的小错误(如:输入图像尺寸)。由于 CIFAR10 没有验证集,本文将训练集的 10% 当作验证集。

完整代码已上传至 GitHub:https://github.com/TiezhuXing01/AlexNet_in_PyTorch


1. 引入库

import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

SubsetRandomSampler 是 PyTorch 中的一个采样器(sampler)。
具体可以看这篇文章:SubsetRandomSampler 是什么?


2. 选择设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

通常情况下,我们都会选择在GPU上训练网络模型,因为神经网络的训练需要大量的计算,而英伟达的GPU提供了CUDA(一个加速计算库)。但如果你的电脑显卡是AMD的,那么有很大概率不支持使用CUDA,此时只能用CPU训练。但在CPU上训练模型是十分缓慢的。如果你暂时没法换电脑,那我建议你去租一个服务器。或者使用阿里云、百度飞桨、谷歌Colab等平台。


3. 加载数据集

CIFAR-10 是一个经典的计算机视觉数据集,用于图像分类任务。它包含了来自 10 个不同类别的 60,000 张彩色图像,每个类别有 6,000 张图像。数据集被分为训练集和测试集,其中训练集包含 50,000 张图像,测试集包含 10,000 张图像。本文拿出训练集的 10% 作为验证集。

3.1 定义获取训练集和验证集的数据加载器

def get_train_val_loader(data_dir, batch_size, augment,
                         random_seed, valid_size = 0.1, shuffle = True):

  # ------------- 设置图像变换 ------------- #
  # (1) 归一化
  normalize = transforms.Normalize(mean = [0.4914, 0.4822, 0.4465],
                                   std = [0.2023, 0.1994, 0.2010])
  # (2) 验证集图像变换
  val_transform = transforms.Compose([transforms.Resize(227),
                                   transforms.ToTensor(),
                                   normalize])
  # (3) 训练集是否数据增强
  if augment:
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(227),
                                      transforms.ToTensor(),
                                      normalize])
  else:
    train_transform = transforms.Compose([transforms.Resize(227),
                                       transforms.ToTensor(),
                                       normalize])
  # ---------- 

标签:从零开始,nn,self,loader,PyTorch,train,out,AlexNet,size
From: https://www.cnblogs.com/xing9/p/18002038/AlexNet-PyTorch

相关文章

  • KubeEdge EdgeMark 测试环境从零开始搭建
    https://blog.csdn.net/u010549795/article/details/132557648 EdgeMark测试环境从零开始搭建KubeEdge也提供了类似KubeMark的模拟大规模集群的工具,值得注意的是目前EdgeMark只能模拟edgecore,无法模拟edgemesh,所以如果是对网络方面的测试,还是建议老老实实装虚拟机环境配置使用v......
  • PyTorch神操作:一图秒懂Tensor变形记!
    亲爱的码农小伙伴们,你们是否还在为Tensor的各种变换头大如斗?别怕,今天给大家送上一张超实用的PyTorch变换秘籍图,让你的Tensor操作如行云流水,CPU和GPU之间的切换如穿梭自如!......
  • 【Docker】从零开始:9.Docker命令:Push推送仓库(Docker Hub,阿里云)
    【Docker】从零开始:9.Docker命令:Push推送仓库(DockerHub,阿里云):https://blog.csdn.net/sinat_36528886/article/details/134575054?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-1-134575054-blog-132139578.235^v43^pc_blog_bo......
  • Python中用PyTorch机器学习神经网络分类预测银行客户流失模型|附代码数据
    阅读全文:http://tecdat.cn/?p=8522最近我们被客户要求撰写关于神经网络的研究报告,包括一些图形和统计输出。分类问题属于机器学习问题的类别,其中给定一组特征,任务是预测离散值。分类问题的一些常见示例是,预测肿瘤是否为癌症,或者学生是否可能通过考试在本文中,鉴于银行客户的某些......
  • tacotron2:深度学习语音合成模型--pytorch
    https://www.python100.com/html/83067.html 一、tacotron2环境搭建如要安装tacotron2环境,需要完成以下步骤:1、安装CUDA。CUDA是Nvidia开发的并行计算平台和编程模型,需要前往官网下载并安装对应版本的CUDA,同时保证显卡支持CUDA。2、安装cuDNN。cuDNN是针对深度神经网络加速......
  • pytorch的模型推理:TensorRT的使用
    相关教程视频:TRTorch真香,一键启用TensorRT图片来源:https://www.bilibili.com/video/BV1TY411h7xC/图片来源:https://www.bilibili.com/video/BV1TY411h7xC/......
  • 华为显卡已经支持pytorch计算框架
    相关链接:https://support.huawei.com/enterprise/zh/doc/EDOC1100079287/a21c08dehttps://www.zhihu.com/question/624955377/answer/3240350483https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies/pies_00004.htmlAscend/pytorch项目地址:https:......
  • PyTorch中实现Transformer模型
    前言关于Transformer原理与论文的介绍:详细了解Transformer:AttentionIsAllYouNeed对于论文给出的模型架构,使用PyTorch分别实现各个部分。引入的相关库函数:importcopyimporttorchimportmathfromtorchimportnnfromtorch.nn.functionalimportlog_softmax......
  • 如何将PyTorch模型迁移到昇腾平台
    https://bbs.huaweicloud.com/blogs/399602?utm_source=cnblog&utm_medium=bbs-ex&utm_campaign=other&utm_content=content如何将PyTorch模型迁移到昇腾平台举报 昇腾CANN 发表于2023/04/1809:54:50  5k+  0  1 【摘要】本文介绍将PyTorch网络模型迁移到昇......
  • 从零开始教你手动搭建幻兽帕鲁私服( CentOS 版)
    哈喽大家好,我是咸鱼。想必上网冲浪的小伙伴最近都被《幻兽帕鲁》这款游戏刷屏了。(文中图片均来自网络,侵删)幻兽帕鲁是Pocketpair打造的一款开放世界的生存建造游戏。在游戏中,玩家捕捉各种各样的“帕鲁”。“帕鲁”在玩家支配下,完成不同的工作、任务,像极了现实中的打工人(......