import torch
from torch import nn
from d2l import torch as d2l

batch_size = 100
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

input_size = 784   # 28 * 28 flattened Fashion-MNIST pixels
hidden_size = 300  # width of the single hidden layer
output_size = 10   # one logit per clothing class
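Before building the model, a quick check (my addition) of what the loader yields — assuming each batch is an (images, labels) pair with images of shape (batch_size, 1, 28, 28):

X, y = next(iter(train_iter))
X.shape, y.shape  # expected: (torch.Size([100, 1, 28, 28]), torch.Size([100]))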
# Weights and biases start at small random values (scaled by 0.01).
# nn.Parameter marks each tensor as trainable, so requires_grad is implied.
W1 = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
b1 = nn.Parameter(torch.randn(1, hidden_size) * 0.01)
W2 = nn.Parameter(torch.randn(hidden_size, output_size) * 0.01)
b2 = nn.Parameter(torch.randn(1, output_size) * 0.01)
params = [W1, b1, W2, b2]
W1.shape, b1.shape, W2.shape, b2.shape
(torch.Size([784, 300]),
 torch.Size([1, 300]),
 torch.Size([300, 10]),
 torch.Size([1, 10]))
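The 0.01 factor keeps the initial weights small so the first logits stay near zero; a quick check of the realized scale (my addition):

W1.std().item()  # roughly 0.01, since randn has unit std before scaling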
def relu(X):
    # Elementwise max(x, 0), built from zeros_like + torch.max
    zeros = torch.zeros_like(X)
    return torch.max(X, zeros)
relu(torch.randn(1, 2))
tensor([[2.3051, 0.0000]])
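As a sanity check (not in the original), the hand-rolled relu should match PyTorch's built-in torch.relu exactly:

X = torch.randn(3, 4)
assert torch.equal(relu(X), torch.relu(X))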
def net(X):
    # Flatten each image to a 784-vector, then hidden layer -> output layer
    H = relu(X.reshape(X.shape[0], -1) @ W1 + b1)
    return H @ W2 + b2
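A shape check on a dummy batch (my addition) confirms the reshape-and-matmul pipeline: each image comes in as (1, 28, 28) and leaves as 10 logits.

X = torch.randn(2, 1, 28, 28)
net(X).shape  # torch.Size([2, 10])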
loss = nn.CrossEntropyLoss(reduction="none")  # per-example losses; the trainer averages them
lr = 0.1
trainer = torch.optim.SGD(params, lr=lr)
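Because reduction="none" returns one loss per example, the updater must average before backprop. A rough sketch of the per-batch work that d2l.train_ch3 performs (a simplification, not the library's exact code):

X, y = next(iter(train_iter))
l = loss(net(X), y)    # shape (batch_size,), one loss per example
trainer.zero_grad()
l.mean().backward()    # reduce to a scalar, then backpropagate
trainer.step()         # SGD update: p -= lr * p.grad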
help(d2l.train_ch3)
Help on function train_ch3 in module d2l.torch:

train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    Train a model (defined in Chapter 3).

    Defined in :numref:`sec_softmax_scratch`
num_epoch = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epoch, trainer)
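After training, d2l also ships predict_ch3 (defined alongside train_ch3 in the book) for eyeballing predictions; usage, assuming the book's signature:

d2l.predict_ch3(net, test_iter, n=6)  # plots 6 test images with true/predicted labels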
Key functions
- torch.zeros_like(x): creates a zero tensor with the same shape as x
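A minimal demo (my addition):

X = torch.arange(6.0).reshape(2, 3)
torch.zeros_like(X)
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])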
Concise version
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 100
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

net = nn.Sequential(nn.Flatten(),        # (batch, 1, 28, 28) -> (batch, 784)
                    nn.Linear(784, 500),
                    nn.ReLU(),
                    nn.Linear(500, 10))
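Here nn.Flatten plays the role of the manual reshape, and the hidden width is 500 rather than the 300 used above. The same dummy-batch check (my addition) applies:

X = torch.randn(2, 1, 28, 28)
net(X).shape  # torch.Size([2, 10])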
lr = 0.1
loss = nn.CrossEntropyLoss(reduction="none")
trainer = torch.optim.SGD(net.parameters(), lr=lr)
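One difference from the scratch version: no explicit initialization, so the Linear layers keep PyTorch's default init. The book's concise version instead reinitializes the weights to N(0, 0.01); a sketch of that, if you want the two versions to match:

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)  # in-place, so the optimizer still tracks the same tensors

net.apply(init_weights)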
num_epoch = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epoch, trainer)
From: https://www.cnblogs.com/cndccm/p/18264872