首页 > 其他分享 >22-lenet网络

22-lenet网络

时间:2024-08-26 14:47:50浏览次数:10  
标签:22 nn metric 网络 iter train lenet device net

import torch
import torch.nn as nn
from d2l import torch as d2l

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=(5, 5), padding=2),
                    nn.Sigmoid(),
                    nn.AvgPool2d(kernel_size=(2, 2), stride=2),
                    nn.Conv2d(6, 16, kernel_size=(5, 5)),
                    nn.Sigmoid(),
                    nn.AvgPool2d(kernel_size=(2, 2), stride=2),
                    nn.Flatten(),
                    nn.Linear(16*5*5, 120),
                    nn.Sigmoid(),
                    nn.Linear(120, 84),
                    nn.Sigmoid(),
                    nn.Linear(84, 10))

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, '------', X.shape)

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 估计模型的准确率
def evaluate_accuracy_gpu(net, data_iter, device = None):
    if isinstance(net, nn.Module):
        net.eval() # 停止dropout和梯度计算
        if device is None:
            device = next(iter(net.parameters())).device
    metric = d2l.Accumulator(2) # 0:正确预测的数量 1:总预测数量
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel()) # numel获取一共多少元素
    return metric[0] / metric[1]


def train(net, train_iter, test_iter, num_epochs, lr, device):
    # 参数初始化
    def init_weights(m):
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    net.to(device)

    # 定义损失函数和优化器
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()

    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)

    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train() # 开启训练模式
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = net(X)
            l = loss(y_pred, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_pred, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')


lr, num_epochs = 0.1, 10
train(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

标签:22,nn,metric,网络,iter,train,lenet,device,net
From: https://www.cnblogs.com/morehair/p/18381025

相关文章

  • 24. 网络编程
    osi七层协议 每层常见物理设备 OSI(OpenSystemsInterconnection)七层协议是由国际标准化组织(ISO)制定的一个网络模型,它将网络系统划分为七个层次。以下是各层的详细解释: ‌物理层(PhysicalLayer)‌:这是OSI模型的最低层或第一层。它负责在物理媒体上传输原始的比特流......
  • Steam遭网络攻击 奇安信:很难不让人联想是针对《黑神话:悟空》
    8月24日晚,众多网友反映Steam无法登录,进不去游戏。随后,Steam中国区代理——完美世界竞技平台表示,此次Steam崩溃是由于受到大规模DDoS攻击导致。日前,奇安信XLab实验室发文披露并还原了本次攻击事件的幕后细节。据介绍,本次有近60个僵尸网络主控发起了此次DDoS攻击,攻击指令一......
  • “计算机专业 一定要优先报网络安全它是未来国家发展的大方向”
    前言“计算机专业一定要优先报网络安全它是未来国家发展的大方向”为什么推荐学网络安全?“没有网络安全就没有国家安全。”当前,网络安全已被提升到国家战略的高度,成为影响国家安全、社会稳定至关重要的因素之一。01高需求和就业前景:随着数字化进程的加速和网络攻......
  • 代码随想录day41 || 121 买卖股票最佳时机,122 买卖股票最佳时机||,123 买卖股票最佳时
    121买卖股票最佳时机funcmaxProfit(prices[]int)int{ //dp五部曲 //1dp数组以及下标含义dp[i][0]表示第i天持有股票dp[i][1]表示第i天不持有 //2递推公式,dp[i][0]=max(dp[i-1][0],0-price[i]) //dp[i][1]=max(dp[i-1][1],dp[i-1][0]+price[......
  • 世邦通信SPON IP网络对讲广播系统任意文件上传漏洞
    0x00漏洞编号暂无0x01危险等级高危0x02漏洞概述世邦通信SPONIP网络对讲广播系统采用领先的IPAudio™技术,将音频信号以数据包形式在局域网和广域网上进行传送,是一套纯数字传输系统。0x03漏洞详情漏洞类型:任意文件上传影响:上传恶意脚本简述:世邦通信SPONIP网络对讲广播......
  • 数字IP网络广播系统的特点和功能
    随着互联网数字化转型的发展大趋势,广播系统的主流方式也由传统模拟广播系统过渡到数字IP网络广播系统。数字IP网络公共广播,是将模拟音频信号数字编码,通过网络传输后,再由终端解码成模拟音频信号。可多路、单向或双向传输,局域网内延迟时间不超过100ms,并具有自动流量调整、声音补......
  • 计算机网络面试真题总结(四)
    文章收录在网站:http://hardyfish.top/文章收录在网站:http://hardyfish.top/文章收录在网站:http://hardyfish.top/文章收录在网站:http://hardyfish.top/什么是滑动窗口TCP滑动窗口是TCP协议中实现流量控制和可靠传输的关键机制。滑动窗口不仅可以防止发送端数据传输......
  • 2-网络攻击原理与常用方法
    2.1网络攻击概述1)概念:指损害网络系统安全属性的危害行为。危害行为导致网络系统的机密性、完整性、可用性、可控性、真实性、抗抵赖性等受到不同程度的破坏。常见的危害行为有四个基本类型:信息泄漏攻击完整性破坏攻击拒绝服务攻击非法使用攻击自治主体:攻击者初始化......
  • 【计算机网络】计算机网络的概念
    什么是计算机网络?计算机网络(Computernetworking)是一个将众多分散的、自治的计算机系统,通过通信设备与线路连接起来,由功能完善的软件实现资源共享和信息传递的系统。计算机网络、互连网、互联网的区别计算机网络(computernetworking)互连网(internet)互联网(因特网,Internet)计......
  • 2024年云南省职业院校技能大赛中职组“网络搭建与应用”赛项竞赛样卷
    2024年云南省职业院校技能大赛中职组“网络搭建与应用”赛项竞赛样卷文章目录2024年云南省职业院校技能大赛中职组“网络搭建与应用”赛项竞赛样卷第一部分:网络理论测试(100分)第二部分:网络建设与调试(400分)第三部分:服务搭建与运维(500分)竞赛说明一、竞赛内容分布......