首页 > 其他分享 >torch--minst手写体识别

torch--minst手写体识别

时间:2024-10-06 09:03:58浏览次数:7  
标签:loss plt -- minst torch 28 import out

utils.py

import torch
import matplotlib.pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i+1)
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

main.py

"""
手写体数字识别MNIST
"""

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot


batch_size = 512

# step1. load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # w*x + b
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(x * w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 * w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2 * w3 + b3
        x = self.fc3(x)

        return x


net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28]    y: 512
        # [b, 1, 28, 28] =>   [b, 784]
        x = x.view(x.size(0), 28*28)
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num

print('test acc:', acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')


标签:loss,plt,--,minst,torch,28,import,out
From: https://www.cnblogs.com/jackchen28/p/18448802

相关文章

  • 冲刺CSP联训模拟2
    冲刺CSP联训模拟2A.挤压考虑把一个数写成二进制,不妨记为$\sums_i\times2^i,s=0或1$,设其概率为$p_k$,则期望值:\[p_k\times(\sum_{i=0}^{29}s_i)^2=p_k\times\sum_{i=0}^{29}\sum_{j=0}^{29}s^i\timess^j\times2^{i+j}\]设$dp[i][j]$为异或后......
  • 读数据湖仓08数据架构的演化
    1. 数据目录1.1. 需要将分析基础设施放置在数据目录(DataCatalogue)的结构中1.1.1. 元数据1.1.2. 数据模型1.1.3. 本体1.1.4. 分类标准1.2. 数据目录类似于图书馆的图书检索目录1.2.1. 先通过图书馆的图书检索目录进行查找,以便快速找到所需的图书......
  • 博客园救园事件之反思
    博客园救园已经结束,虽然没有深度参与,但也算是见证了事情的发展,反思其经过,从中吸取一些经验和教训。首先,博客园这二十年来,始终坚持以用户为本,服务好用户群体,成为这次救园能够成功的一个关键因素,这是最值得我们学习的地方。其次,要不断跳出自己的舒适圈,不断改进,不断革新,要警惕温水......
  • Cisco Firepower 9300 Series FTD Software 7.6.0 & ASA Software 9.22.1
    CiscoFirepower9300SeriesFTDSoftware7.6.0&ASASoftware9.22.1FirepowerThreatDefense(FTD)Software-思科防火墙系统软件请访问原文链接:https://sysin.org/blog/cisco-firepower-9300/,查看最新版。原创作品,转载请保留出处。作者主页:sysin.orgCiscoSecure防......
  • 学期(如2024-2025-1) 20241304 《计算机基础与程序设计》第2周学习总结
    学期(如2024-2025-1)20241304《计算机基础与程序设计》第2周学习总结作业信息这个作业属于哪个课程<班级的链接>(2024-2025-1-计算机基础与程序设计)这个作业要求在哪里<作业要求的链接>(如2024-2025-1计算机基础与程序设计第二周作业)这个作业的目标<自学教材第一章......
  • 题解:P11008 『STA - R7』异或生成序列
    Solution序列\(p\)是\(1\)~\(n\)的排列,因此考虑搜索回溯。由\(\sumn\le2\times10^6\)得知\(O(n^2)\)会炸,深感遗憾但仍考虑剪枝。坚信深搜过百万的蒟蒻。。。原\(b\)序列为长度\(n-1\)的序列:{\(b_1,b_2,b_3\cdotsb_n-1\)}将其前面插入一个元素\(......
  • Cisco Firepower 4100 Series FTD Software 7.6.0 & ASA Software 9.22.1
    CiscoFirepower4100SeriesFTDSoftware7.6.0&ASASoftware9.22.1FirepowerThreatDefense(FTD)Software-思科防火墙系统软件请访问原文链接:https://sysin.org/blog/cisco-firepower-4100/,查看最新版。原创作品,转载请保留出处。作者主页:sysin.orgCiscoSecure防......
  • 从零开始学机器学习——网络应用
    首先给大家介绍一个很好用的学习地址:https://cloudstudio.net/columns今天,我们的主要任务是按照既定的流程再次运行模型,并将其成功加载到Web应用程序中,以便通过Web界面进行调用。最终生成的模型将能够基于UFO目击事件的数据和经纬度信息,推断出事件发生的城市地址。尽管经纬......
  • Cisco Firepower 1000 Series FTD Software 7.6.0 & ASA Software 9.22.1
    CiscoFirepower1000SeriesFTDSoftware7.6.0&ASASoftware9.22.1FirepowerThreatDefense(FTD)Software-思科防火墙系统软件请访问原文链接:https://sysin.org/blog/cisco-firepower-1000/,查看最新版。原创作品,转载请保留出处。作者主页:sysin.org面向小型办公室......
  • 分层图
    P4568分层图板子题(卡SPFA)include<bits/stdc++.h>definelllonglongusingnamespacestd;constintN=1e5+10;intn,m,k,s,t;llvis[N<<4],dis[N<<4];llans=INT_MAX;vector<pair<int,int>>g[N<<4];voiddijkstra(ints){prior......