首页 > 其他分享 >学习笔记9:卷积神经网络实现MNIST分类(GPU加速)

学习笔记9:卷积神经网络实现MNIST分类(GPU加速)

时间:2024-06-04 09:22:37浏览次数:19  
标签:acc loss 卷积 self torch epoch test GPU MNIST

转自:https://www.cnblogs.com/miraclepbc/p/14345342.html

相关包导入

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import datasets, transforms
%matplotlib inline

设置device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

如果cuda是可用的,那么就使用"cuda:0",否则使用"cpu"

数据加载

transformation = transforms.Compose([
    transforms.ToTensor(),       ## 转化为一个tensor, 转换到0-1之间, 将channnel放在第一位
])

train_ds = datasets.MNIST(
    'E:/datasets2/1-18/dataset/daatset',
    train = True,
    transform  =transformation,
    download = True
)

test_ds = datasets.MNIST(
    'E:/datasets2/1-18/dataset/daatset',
    train = False,
    transform = transformation,
    download = True
)

train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True)
test_dl = DataLoader(test_ds, batch_size = 258)

模型定义

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        #参数分别为n_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True
        self.pool = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.linear_1 = nn.Linear(16 * 4 * 4, 256)
        self.linear_2 = nn.Linear(256, 10)
    def forward(self, input):
        x = F.relu(self.conv1(input))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # print(x.size())
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.linear_1(x))
        x = self.linear_2(x)
        return x

loss_func = torch.nn.CrossEntropyLoss()

这里需要注意一点是,卷积、池化之后是不知道数据的shape的,因此可以采用print的方法,测试一下
具体来说,就是先在全连接层的维度那里随便设置值,然后打印一下
在输出框里,会出现正确的值,这时再将之前随便设置的值修正过来即可

模型训练

def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim = 1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()

    epoch_acc = correct / total
    epoch_loss = running_loss / len(trainloader.dataset)
    
    test_correct = 0
    test_total = 0
    test_running_loss = 0
    
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            y_pred = torch.argmax(y_pred, dim = 1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    epoch_test_acc = test_correct / test_total
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    
    print('epoch: ', epoch, 
          'loss: ', round(epoch_loss, 3),
          'accuracy: ', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy: ', round(epoch_test_acc, 3))
    
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

model = Model()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
epochs = 20

train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

这里需要注意的地方是,如果要调用gpu,那么需要将模型和数据都转移到gpu上
因此,需要调用.to(device)方法进行转移

训练结果

标签:acc,loss,卷积,self,torch,epoch,test,GPU,MNIST
From: https://www.cnblogs.com/gongzb/p/18230128

相关文章