首页 > 其他分享 >pytorch的FashionMNIST

pytorch的FashionMNIST

时间:2022-12-31 11:32:17浏览次数:44  
标签:loss torch dataloader print pytorch FashionMNIST test model


目录

pytorch的FashionMNIST项目从加载数据到训练模型评估到模型保存模型加载及预测

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(root = "data", train = True, download = True, transform = ToTensor(),)
test_data = datasets.FashionMNIST(root = "data", train = False, download = True, transform = ToTensor(),)

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size = batch_size)

for X,y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512,10)
)

def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits

model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)

def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X,y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)

pred = model(X)
loss = loss_fn(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

if batch % 100==0:
loss, current = loss.item(), batch*len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct =0,0
with torch.no_grad():
for X, y in dataloader:
X,y = X.to(device),y.to(device)
pred = model(X)
test_loss +=loss_fn(pred,y).item()
correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")

epochs = 5
for t in range(epochs):
print(f"Epoch {t+1}\n -------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print("Done!")

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))
classes = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
pred = model(x)
predicted, actual = classes[pred[0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')


标签:loss,torch,dataloader,print,pytorch,FashionMNIST,test,model
From: https://blog.51cto.com/u_15255081/5982066

相关文章

  • Pytorch优化过程展示:tensorboard
      训练模型过程中,经常需要追踪一些性能指标的变化情况,以便了解模型的实时动态,例如:回归任务中的MSE、分类任务中的Accuracy、生成对抗网络中的图片、网络模......
  • PyTorch学习笔记 7.TextCNN文本分类
    PyTorch学习笔记7.TextCNN文本分类​​一、模型结构​​​​二、文本分词与编码​​​​1.分词与编码器​​​​2.数据加载器​​​​二、模型定义​​​​1.卷积层​​......
  • pytorch模型onnx部署(python版本,c++版本)
    转载:实践演练BERTPytorch模型转ONNX模型及预测-知乎(zhihu.com)使用bRPC和ONNXRuntime把BERT模型服务化-知乎(zhihu.com)1.安装anaconda一般有图形界面的个人电......
  • ubuntu pytorch install
    nvidia驱动安装https://www.cnblogs.com/lif323/p/17014199.htmlconda安装下载.sh到该网站下载需要的.sh文件wgethttps://repo.anaconda.com/archive/Anaconda3-20......
  • pytorch:二分类时的loss选择
    PyTorch二分类时BCELoss,CrossEntropyLoss,Sigmoid等的选择和使用这里就总结一下使用PyTorch做二分类时的几种情况:总体上来讲,有三种实现形式:最后分类层降至一维,使用sigmo......
  • PyTorch模型保存与加载
    保存与加载整个模型保存整个模型,包括网络结构和权重参数,保存后的文件用torch.load()加载后的类型是定义的网络结构类,如classCNN:torch.save(model,"model.pkl")加载整......
  • PyTorch的Dataset 和TorchData API的比较
    深度神经网络需要很长时间来训练。训练速度受模型的复杂性、批大小、GPU、训练数据集的大小等因素的影响。在PyTorch中,torch.utils.data.Dataset和torch.utils.data.DataL......
  • Pytorch 动态图, Autograd, grad_fn详解
    Pytorch动态图Autogradgrad_fn详解Autogradrequire_grad具有传递性,会将其结果也引入计算图中requires_grad iscontagious.Itmeansthatwhena Tensor iscre......
  • 4个例子帮你梳理PyTorch的nn module
    本文延续前一篇文章的例子。只是例子一样,代码实现是逐步优化的,但是知识点没什么必然关联。几个例子帮你梳理PyTorch知识点(张量、autograd)nn计算图和autograd是定义复杂......
  • 神经网络--Tensflow vs pytorch框架比较
       a     ......