今天进行了CIFAR10的实战任务
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
#%%
# ToTensor() converts PIL images to float tensors in [0, 1]; Normalize with
# mean=0.5 and std=0.5 per channel then maps every pixel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Raw string: this Windows path contains backslash sequences ("\p", "\C",
# "\d") that are invalid escapes in a normal string literal and emit a
# SyntaxWarning on recent Python versions. The resulting value is identical.
DATA_ROOT = r'E:\pytest\CIFAR_classficaion\data'

train_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                         download=True, transform=transform)
# NOTE(review): num_workers > 0 starts worker processes; if this file is run
# as a plain script on Windows it needs an `if __name__ == "__main__":` guard.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,
                                           shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                        download=True, transform=transform)
# Test order is kept fixed (shuffle=False) so evaluation is reproducible.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=4,
                                          shuffle=False, num_workers=2)

# Human-readable names for the 10 integer labels, indexed by label value.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#%%
def image_show(img):
    """Display a channel-first image tensor with matplotlib.

    Expects pixel values normalized like the file's `transform`
    (Normalize(mean=0.5, std=0.5)), i.e. in [-1, 1]. imshow interprets float
    images as [0, 1] and clips anything outside that range, which distorts
    the colours, so the normalization is undone before plotting.

    Args:
        img: a (C, H, W) CPU tensor, e.g. the output of
            torchvision.utils.make_grid on a batch from the loaders.
    """
    img = img / 2 + 0.5  # invert Normalize(0.5, 0.5): x * std + mean
    np_img = img.numpy()
    # Tensors are channel-first (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()
# Preview one shuffled training batch as a single image grid and print the
# class name of each image in it (left to right).
data_iter = iter(train_loader)
image, label = next(data_iter)
image_show(torchvision.utils.make_grid(image))
print(' '.join(classes[j] for j in label.tolist()))
#%%
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class logits."""

    def __init__(self):
        # Zero-argument super() does not look the class up by a global name,
        # so it keeps working after the module-level name `net` is bound to
        # an instance below (super(net, self) would then fail).
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # (3, 32, 32) -> (6, 28, 28)
        self.conv2 = nn.Conv2d(6, 16, 5)   # (6, 14, 14) -> (16, 10, 10)
        self.pool = nn.MaxPool2d(2, 2)     # halves height and width
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw logits of shape (N, 10) for a batch x of shape (N, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))  # -> (N, 6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # -> (N, 16, 5, 5)
        # Flatten all but the batch dimension. Unlike view(-1, 16 * 5 * 5),
        # this never silently reshapes features into extra batch rows when
        # the input has the wrong spatial size; fc1 raises a shape error.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Module-level model instance used by the optimizer, training and evaluation cells.
net = Net()
#%%
import torch.optim

# Cross-entropy objective over the 10 logits produced by `net`.
l = nn.CrossEntropyLoss()

# SGD with momentum over every parameter of `net`. The name `optim` ends up
# bound to this optimizer instance because later cells call
# `optim.zero_grad()` / `optim.step()`; the module itself is reached through
# `torch.optim` instead.
optim = torch.optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)
#%%
# Train for two passes over the training set. One loss history spans all
# epochs so the final plot shows the whole run, not just the last epoch.
net.train()
losses = []
for epoch in range(2):
    for i, (inputs, labels) in enumerate(train_loader):
        optim.zero_grad()  # gradients accumulate across backward() calls
        outputs = net(inputs)
        loss = l(outputs, labels)
        loss.backward()
        optim.step()
        # Sample the current mini-batch loss every 100 steps for the curve.
        if i % 100 == 0:
            loss_value = loss.item()
            losses.append(loss_value)
            print(f'[epoch {epoch + 1}, batch {i}] loss: {loss_value:.4f}')

plt.plot(losses)
plt.xlabel('sample (one per 100 mini-batches)')
plt.ylabel('training loss')
# show() must be called; referencing plt.show without () displays nothing.
plt.show()
print("over")
#%%
# Measure top-1 accuracy of the trained model over the whole test loader.
net.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only: no autograd graph is needed
    for images, labels in test_loader:
        outputs = net(images)
        # The index of the largest logit is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f} %')
标签:总结,12,15,nn,self,torch,import,net,data
From: https://blog.51cto.com/u_16196891/8845356