交叉验证
在实际情况中,数据集通常分为训练集和测试集。测试集一般由数据持有方保留、不对外公开,以防止有人在测试模型时作弊——故意挑选能让模型表现更好的数据进行测试,结果模型一遇到新的数据效果就很差。
于是我们通常将训练集再进行分割:一部分用于训练,另一部分用于测试;这里的“测试”实际上称为验证(validation)。
由于数据集中有一部分被留作测试而无法用于训练,这部分数据的缺失可能对训练结果造成影响。为了减小这种影响,我们要充分地利用可用的数据:轮流将训练集中的不同部分作为验证集、其余部分用于训练,即进行交叉验证,以减小验证集选取的偶然性对结果造成的影响。
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# Hyperparameters shared by the data loaders, the optimizer and the training loop.
batch_size = 200
learning_rate = 0.01
epochs = 10

# Both MNIST splits use the same preprocessing, so the pipeline is built once:
# convert the image to a tensor, then normalize with the constants below
# (presumably the MNIST training-set mean and std -- confirm if the data changes).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (60k samples according to the original notes).
train_db = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=mnist_transform,
)

# Held-out test split (10k samples according to the original notes).
test_db = datasets.MNIST(
    '../data',
    train=False,
    download=True,
    transform=mnist_transform,
)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_loader = torch.utils.data.DataLoader(
    test_db,
    batch_size=batch_size,
    shuffle=False,
)

# Report the sizes of the full training split and the test split.
print('train:', len(train_db), 'test:', len(test_db))

# Carve a validation subset out of the training split. The training size is
# derived from the dataset length so the two lengths always sum correctly.
val_size = 10000
train_db, val_db = torch.utils.data.random_split(
    train_db,
    [len(train_db) - val_size, val_size],
)

# Report the sizes of the reduced training subset and the validation subset.
print('train_db:', len(train_db), 'val_db:', len(val_db))

# The training loader is created only once, after the split, so it iterates
# over the reduced training subset rather than the full training split.
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size,
    shuffle=True,
)
# Validation metrics are order-independent, so no shuffling is needed here.
val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size=batch_size,
    shuffle=False,
)
class MLP(torch.nn.Module):
    """Fully connected classifier: three Linear layers, each followed by LeakyReLU.

    Args:
        in_features: Size of each flattened input sample (default 784).
        hidden: Width of the two hidden layers (default 200).
        num_classes: Number of output scores per sample (default 10).

    The defaults reproduce the original fixed 784 -> 200 -> 200 -> 10 layout.
    """

    def __init__(self, in_features: int = 784, hidden: int = 200,
                 num_classes: int = 10) -> None:
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(hidden, hidden),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(hidden, num_classes),
            # NOTE(review): this activation is applied to the final scores.
            # It is kept to preserve the original behaviour; consider removing
            # it if the outputs are meant to be raw logits for a loss function.
            torch.nn.LeakyReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` (last dimension ``in_features``) through the layer stack."""
        return self.model(x)
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    # ---- training pass over the training subset ----
    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-element vector.
        data = data.view(-1, 28 * 28).to(device)
        target = target.to(device)

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log progress every 100 batches.
        if batch_idx % 100 == 0:
            print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()
            ))

    # ---- validation pass after every epoch ----
    net.eval()
    test_loss = 0
    correct = 0
    # No gradients are needed for evaluation; skipping them avoids building
    # autograd graphs for every validation batch.
    with torch.no_grad():
        for data, target in val_loader:
            data = data.view(-1, 28 * 28).to(device)
            target = target.to(device)
            logits = net(data)
            # CrossEntropyLoss returns the batch mean by default, so weight it
            # by the batch size; dividing the total by the number of samples
            # below then yields a true per-sample average.
            test_loss += criteon(logits, target).item() * data.size(0)
            # The predicted class is the index of the largest score.
            pred = logits.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    test_loss /= len(val_loader.dataset)
    # These metrics come from the validation subset, so the line is labelled
    # accordingly rather than as "Test set".
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss,
        correct,
        len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)
    ))
标签:交叉,torch,db,batch,loader,train,验证,data
From: https://www.cnblogs.com/dxmstudy/p/17447405.html