首页 > 其他分享 > mnist数据集使用torch进行卷积训练

mnist数据集使用torch进行卷积训练

时间: 2022-10-30 13:12:13  浏览次数: 48
标签:evl nn 卷积 torch label model data mnist

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets.mnist as mnist
import torchvision.transforms as T
import torchvision.models as models
import collections
import math


# Preprocessing applied to every MNIST sample:
#   1. Grayscale(3): convert to grayscale and replicate it to 3 channels so the
#      images match a 3-input-channel first convolution.
#   2. ToTensor(): PIL image / ndarray (H, W, C) -> float tensor (C, H, W);
#      uint8 pixel values are scaled into [0, 1].
#   3. Normalize: y = (x - mean) / std with mean 0.5 and std 1; the
#      single-element mean/std are broadcast across the 3 channels.
transform = T.Compose(
    [
        T.Grayscale(num_output_channels=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[1]),
    ]
)

# Training split first, then validation split; both are downloaded into
# "dataset" if missing and share the same preprocessing.
train_set, val_set = (
    mnist.MNIST("dataset", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


class Model(nn.Module):
    """Small convolutional classifier for 3-channel image batches.

    Args:
        num_class: number of output classes, i.e. the length of the logit
            vector produced for each sample.

    The classification head is an ``nn.LazyLinear``, so its input size is
    inferred on the first forward pass. For 28x28 inputs the flattened
    feature size is 64 * 3 * 3 = 576.
    """

    def __init__(self, num_class):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, 3),   # 28 -> 26 (spatial size for a 28x28 input)
            nn.ReLU(),
            nn.Conv2d(16, 32, 3),  # 26 -> 24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # 24 -> 12
            nn.Conv2d(32, 32, 3),  # 12 -> 10
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # 10 -> 5
            # NOTE(review): this ReLU is a no-op (its input is a max-pool of
            # ReLU outputs, hence already non-negative). It is kept so the
            # Sequential indices / state_dict keys of later layers don't shift.
            nn.ReLU(),
            nn.Conv2d(32, 64, 3),  # 5 -> 3
            nn.Flatten(),
            # LazyLinear's first argument is out_features. The former
            # LazyLinear(64, num_class) produced 64 logits (num_class landed in
            # the `bias` slot), so the model could predict classes that don't exist.
            nn.LazyLinear(num_class),
        )

    def forward(self, x):
        """Return per-class logits of shape (batch, num_class).

        ``x`` is presumably an image batch of shape (batch, 3, H, W); the
        first Conv2d requires 3 input channels.
        """
        return self.backbone(x)

# Training configuration: data loader, model, optimizer, loss, epoch count.
dataloder = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=0)
glr = 1e-3  # Adam learning rate
model = Model(10).cuda()
# The model's head is an nn.LazyLinear. Run one sample through it (a "dry run")
# so the lazy parameters are materialised on the GPU before the optimizer
# collects model.parameters().
with torch.no_grad():
    sample_image, _ = train_set[0]
    model(sample_image.unsqueeze(0).cuda())
optim = torch.optim.Adam(model.parameters(), glr)
fu_loss = nn.CrossEntropyLoss()
epchos = 10
def evl_model(model, evl_data):
    """Return the classification accuracy of ``model`` on ``evl_data``.

    Args:
        model: an ``nn.Module`` that maps an image batch to per-class scores.
            The prediction is the argmax over dim 1.
        evl_data: a map-style dataset of ``(image, label)`` pairs. It is
            wrapped in a ``DataLoader`` (batch_size=256, no shuffling).

    Returns:
        The fraction of samples whose predicted class equals the label, as a float.

    Raises:
        ValueError: if ``evl_data`` is empty (accuracy is undefined).

    The model's train/eval mode is restored before returning. A caller that
    trains between evaluations is therefore not left in eval mode.
    """
    total_num = len(evl_data)
    if total_num == 0:
        raise ValueError("evl_data is empty; accuracy is undefined")

    # Evaluate on the device holding the model's parameters rather than a
    # hard-coded CUDA device; for a CUDA model this is the same device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    loader = torch.utils.data.DataLoader(evl_data, batch_size=256, shuffle=False, num_workers=0)

    was_training = model.training
    model.eval()
    right_data = 0
    try:
        with torch.no_grad():
            for image, label in loader:
                image = image.to(device)
                label = label.to(device)
                predict = model(image).argmax(dim=1)
                # .item() keeps the running count a plain Python int.
                right_data += (predict == label).sum().item()
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)
    return right_data / total_num
# Train for `epchos` epochs. After each epoch, report the last batch's loss and
# the validation accuracy.
for epcho in range(epchos):
    # Enter training mode at the start of every epoch, so any mode change made
    # by the evaluation call below cannot leak into the next epoch's training.
    model.train()
    for image, label in dataloder:
        image = image.cuda()
        label = label.cuda()
        predict = model(image)
        loss = fu_loss(predict, label)
        optim.zero_grad()
        loss.backward()
        optim.step()
    acc = evl_model(model, evl_data=val_set)
    # .item() prints a plain number instead of the tensor repr (device, grad_fn).
    print(f"epcho = {epcho} loss = {loss.item()} acc = {acc}")

 

标签:evl,nn,卷积,torch,label,model,data,mnist
From: https://www.cnblogs.com/xiaoruirui/p/16841055.html

相关文章