首页 > 其他分享 >pytorch数据集MNIST训练与测试实例

pytorch数据集MNIST训练与测试实例

时间:2024-02-05 17:23:30浏览次数:25  
标签:loss torch 28 pytorch 实例 import input model MNIST

 

 

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F


BATCH_SIZE = 128
TEST_BATCH_SIZE = 516
#1、准备数据集
def get_dataloader(train=True,batch_size=BATCH_SIZE):
    transform_fn = Compose([ToTensor(),Normalize(mean=(0.1307,),std=(0.3081,))]) #mean和std的形状和通道数相同
    dataset = MNIST(root='./data',train=train,download=False,transform=transform_fn)
    data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
    return data_loader
    # for i in enumerate(data_loader):
    #     print(i)

#2.构建模型
class MnisModel(nn.Module):
    def __init__(self):
        super(MnisModel,self).__init__()
        self.fc1 = nn.Linear(1*28*28,28)
        self.fc2 = nn.Linear(28,10)

    def forward(self,input):
        """
        input:[batch_size,1,28,28]
        """
        #1.修改形状
        x = input.view([input.size(0),1*28*28]) # input.view(-1,1*28*28)
        #2.进行全连接的操作
        x = self.fc1(x)
        #3.进行激活函数处理,形状不会发生变化
        x = F.relu(x)
        #4.输出层
        out = self.fc2(x)
        return F.log_softmax(out,dim=-1)

# 1.实例化模型
model = MnisModel()
#2.实例优化器类
optimizer = Adam(model.parameters(),lr=0.001)
if os.path.exists("./model/model.pt"):
    model.load_state_dict(torch.load("./model/model.pt"))  #加载模型
    optimizer.load_state_dict(torch.load("./model/optimizer.pt"))  #加载优化器

def train(epoch):
    """
    实现训练过程
    """
    #3.加载数据集,遍历
    data_loader = get_dataloader()
    for idex,(input,target) in enumerate(data_loader):
        optimizer.zero_grad()  #4.梯度置为0
        output = model(input)  #5.调用模型,得到预测值
        loss = F.nll_loss(output,target)  #6.计算损失
        loss.backward()  #7.反向传播
        optimizer.step()  #8.梯度的更新
        if idex % 100 == 0:
            print(loss.item())

        if idex % 100 == 0:
            torch.save(model.state_dict(),"./model/model.pt")  #保存模型参数
            torch.save(optimizer.state_dict(),"./model/optimizer.pt") #保存优化器参数

def test():  #测试数据
    loss_list = []
    acc_list = []
    test_dataloader = get_dataloader(False,batch_size=TEST_BATCH_SIZE)  #获取测试数据集
    for idx,(input,target) in enumerate(test_dataloader):
        # print(idx,target,input)
        # break
        with torch.no_grad():
            output = model(input)
            cur_loss = F.nll_loss(output,target)
            loss_list.append(cur_loss)
            #计算准备率
            #output [batch_size,10] target:[batch_size]
            pred = output.max(dim = -1)[-1]
            cur_acc = pred.eq(target).float().mean()
            acc_list.append(cur_acc)

    print("平均准确率,平均损失",np.mean(acc_list),np.mean(loss_list))


if __name__ == '__main__':
    # for i  in range(3):  #训练三轮
    #     train(i)
    test()

 

标签:loss,torch,28,pytorch,实例,import,input,model,MNIST
From: https://www.cnblogs.com/handsomeziff/p/18008513

相关文章

  • 【csh】makefile实例
    makefile实例:if(-e$1)thenforeachcell(`awk'{print}'$1`)make-f./makefileCELLNAME=$cell$argv[2-]endelsemake-f./makefileCELLNAME=$1$argv[2-]endif 重点是makefile文件可以串行提升效率: LAY_LIB="AA"SCH_LIB="BB&quo......
  • pytorch gather函数
    转载于:https://www.zhihu.com/question/562282138/answer/2947708508?utm_id=0官方文档链接:https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gathertorch.gather()的定义非常简洁:在指定dim上,从原tensor中获取指定index的数据,看到这个核心定义,我们很容易......
  • JS——常用实例
    对话框输入,获取,计算,输出。<!DOCTYPEhtml><htmllang="en"><head><metacharset="UTF-8"><title>JS:操作HIML对象</title></head><body><!--1.两个输人框和一个输出框--><labelfor="1......
  • PyTorch下,使用list放置模块,导致计算设备不一的报错
    报错在复现Transformer代码的训练阶段时,发生报错:RuntimeError:Expectedalltensorstobeonthesamedevice,butfoundatleasttwodevices,cuda:0andcpu!解决方案通过next(linear.parameters()).device确定model已经在cuda:0上了,同时输入model.forward()的......
  • PyTorch 2.2 中文官方教程(十六)
    介绍torch.compile原文:pytorch.org/tutorials/intermediate/torch_compile_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0注意点击这里下载完整的示例代码作者:WilliamWentorch.compile是加速PyTorch代码的最新方法!torch.compile通过将PyTorch代码JIT编译成优化的......
  • PyTorch 2.2 中文官方教程(十八)
    开始使用完全分片数据并行(FSDP)原文:pytorch.org/tutorials/intermediate/FSDP_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:HamidShojanazeri,YanliZhao,ShenLi注意在github上查看并编辑本教程。在大规模训练AI模型是一项具有挑战性的任务,需要大量的计算能力和资源......
  • PyTorch 2.2 中文官方教程(十九)
    使用RPC进行分布式管道并行原文:pytorch.org/tutorials/intermediate/dist_pipeline_parallel_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:ShenLi注意在github中查看并编辑本教程。先决条件:PyTorch分布式概述单机模型并行最佳实践开始使用分布式RPC框......
  • PyTorch 2.2 中文官方教程(二十)
    移动设备在iOS上进行图像分割DeepLabV3原文:pytorch.org/tutorials/beginner/deeplabv3_on_ios.html译者:飞龙协议:CCBY-NC-SA4.0作者:JeffTang审阅者:JeremiahChung介绍语义图像分割是一种计算机视觉任务,使用语义标签标记输入图像的特定区域。PyTorch语义图像分割De......
  • PyTorch 2.2 中文官方教程(十一)
    使用PyTorchC++前端原文:pytorch.org/tutorials/advanced/cpp_frontend.html译者:飞龙协议:CCBY-NC-SA4.0PyTorchC++前端是PyTorch机器学习框架的纯C++接口。虽然PyTorch的主要接口自然是Python,但这个PythonAPI坐落在一个庞大的C++代码库之上,提供了基础数据......
  • PyTorch 2.2 中文官方教程(十二)
    自定义C++和CUDA扩展原文:pytorch.org/tutorials/advanced/cpp_extension.html译者:飞龙协议:CCBY-NC-SA4.0作者:PeterGoldsboroughPyTorch提供了大量与神经网络、任意张量代数、数据处理和其他目的相关的操作。然而,您可能仍然需要更定制化的操作。例如,您可能想使用在论......