首页 > 编程语言 >李沐《动手学深度学习》多层感知机python代码实现

李沐《动手学深度学习》多层感知机python代码实现

时间:2024-11-06 21:17:32浏览次数:3  
标签:nn python torch iter 感知机 num lr d2l 李沐

一、多层感知机手动实现

# 多层感知机的手动实现
%matplotlib inline
import torch 
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

num_inputs, num_outputs, num_first_hiddens = 784, 10, 256

W1 = nn.Parameter(
    torch.randn(num_inputs, num_first_hiddens, requires_grad=True)*0.01)

b1 = nn.Parameter(torch.zeros(num_first_hiddens, requires_grad=True))

W2 = nn.Parameter(
    torch.randn(num_first_hiddens, num_outputs, requires_grad=True)*0.01)

b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

# relu函数,输入是隐藏层W1*X + b1
def relu(X):
    zero_x = torch.zeros(X.shape)
    return torch.max(X, zero_x)

# 模型函数,输入是数据X
def net(X):
    X = X.reshape((-1, num_inputs))
    H = relu(X@W1 + b1)
    return (H@W2 + b2)

loss = nn.CrossEntropyLoss(reduction='none')

num_epochs = 10
lr = 0.1
updater = torch.optim.SGD(params, lr=lr)

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

二、利用torch框架实现多层感知机

# 多层感知机的简洁实现
%matplotlib inline
import torch
from d2l import torch as d2l
from torch import nn

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
        
net = nn.Sequential(nn.Flatten(), # 先把输入数据(1,28,28)展开为(1,784)
                    nn.Linear(784, 256),  # 784输入->256输出的隐藏层
                    nn.ReLU(),  # 对隐藏层的输出再做一个ReLu函数
                    nn.Linear(256, 10)) # 输出层
net.apply(init_weights)

lr = 0.1
batch_size = 256
num_epochs = 10
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

updater = torch.optim.SGD(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

标签:nn,python,torch,iter,感知机,num,lr,d2l,李沐
From: https://blog.csdn.net/yuzixuan233/article/details/143575182

相关文章

  • 李沐《动手学深度学习》线性回归python代码实现
    一、手动实现线性回归#线性回归的手动实现%matplotlibinlineimporttorchimportrandomfromd2limporttorchasd2l#随机按照参数w和b外加一些噪音来创造训练数据集data和labelsdefsynthetic_data(w,b,num_examples):X=torch.normal(0,1,(num_example......
  • Python学习笔记-生成器的应用与原理
    生成器是Python中一种特殊的迭代工具,通过延迟计算的方式来逐步生成序列中的元素。这种特性使得生成器在处理大数据、无限序列或需要惰性求值的场景中十分有效。生成器的核心思想是通过yield语句逐步返回值,暂停并保留当前状态,直到下次调用继续执行,从而节省内存并优化性能......
  • Python学习笔记-断点操作结合异常处理
    在编程中,调试和错误处理是提升代码质量和开发效率的关键环节。调试能帮助识别并修复问题,异常处理则使得程序能在出现错误时有效地管理而不至于崩溃。断点与异常处理的结合应用是高级编程中不可或缺的技巧,能够帮助更高效地定位问题,提高程序的鲁棒性。本文将通过详细的断点和......
  • Python——数据结构与算法-时间复杂度&空间复杂度-链表&树状结构
    1.数据结构和算法简介程序可以理解为:程序=数据结构+算法概述/目的:都可以提高程序的效率(性能)数据结构指的是存储,组织数据的方式.算法指的是为了解决实际业务问题而思考思路和方法,就叫:算法.2.算法的5大特性介绍概述:为了解决实际业务问题,......
  • python面向对象(一)
    前言Python是一种功能强大的编程语言,因其简洁的语法和强大的库而备受开发者喜爱。而在Python中,面向对象编程(OOP)是一种核心的编程范式,它通过模拟现实世界中的对象和交互来帮助我们设计清晰、易维护的代码。在本篇博客中,我们将深入探讨Python的面向对象编程的基本概念,了解如......
  • 【PAT_Python解】1114 全素日
    原题链接:PTA|程序设计类实验辅助教学平台Tips:以下Python代码仅个人理解,非最优算法,仅供参考!多学习其他大佬的AC代码!defis_prime(n):ifn<=3:returnn>=2ifn%6notin(5,1):returnFalseforiinrange(5,int(n**0.5)+1,6):......
  • 【PAT_Python解】1110 区块反转
    原题链接:PTA|程序设计类实验辅助教学平台Tips:以下Python代码仅个人理解,非最优算法,仅供参考!多学习其他大佬的AC代码!importsys#读取输入head,n,k=map(int,sys.stdin.readline().split())#初始化数据,装入字典,最终取值data={}next={}for_inrange(n):......
  • 【PAT_Python解】1113 钱串子的加法
    原题链接:PTA|程序设计类实验辅助教学平台Tips:以下Python代码仅个人理解,非最优算法,仅供参考!多学习其他大佬的AC代码!defadd_base30(num1,num2):max_length=max(len(num1),len(num2))#在前面补零,使两个字符串长度相同num1=num1.zfill(max_lengt......
  • 0基础学Python——类的单例模式、反射函数、记录类的创建个数、迭代器、生成器及生成
    0基础学Python——类的单例模式、反射函数、记录类的创建个数、迭代器、生成器及生成器练习类的单例模式定义代码演示反射函数代码演示记录类的创建个数迭代器定义特点生成器定义特点写法生成器练习生成器生成1-无穷的数字生成器生成无穷个素数类的单例模式定义......
  • 0基础学Python——面向对象-可迭代、面向对象-迭代器、call方法、call方法实现装饰器
    0基础学Python——面向对象-可迭代、面向对象-迭代器、call方法、call方法实现装饰器、计算函数运行时间面向对象--可迭代实现方法面向对象--迭代器实现方法call方法作用call方法实现装饰器代码演示计算函数运行时间代码演示面向对象–可迭代把对象看做容器,存储......