STN 在 MNIST 上的实现
个人博客 - https://cxy-sky.github.io/
代码参考来源:PyTorch框架实战系列(3)——空间变换器网络STN_Daniel Yuz的博客-CSDN博客
理论参考:
- Pytorch中的仿射变换(affine_grid)_liangbaqiang的博客-CSDN博客
- 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn
图片显示使用的是 matplotlib(本机没有安装 opencv)。
CNN
import torch
from torch import nn, optim
class CNN(nn.Module):
    """Plain convolutional classifier for MNIST (the baseline without an STN).

    Feature extractor: two conv / ReLU / max-pool stages. For a 1x28x28 input
    the spatial size goes 28 -> 7 (MaxPool2d(4)) -> 2 (MaxPool2d(3)), so the
    flattened feature vector has 128 * 2 * 2 = 512 entries, which is why the
    classifier head uses ``nn.Linear(512, 10)``. Inputs of another size will
    not match that layer.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        # The head receives already-flattened (batch, 512) features, so plain
        # element-wise nn.Dropout is the right layer. nn.Dropout2d is meant for
        # (N, C, H, W) feature maps and is deprecated on 2-D input. Dropout has
        # no parameters, so state_dict keys ("linear.1.*") are unchanged.
        self.linear = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10).

        Args:
            x: image batch; the 512-feature head assumes (batch, 1, 28, 28).
        """
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        return self.linear(x)
if __name__ == '__main__':
    # Smoke test: build an untrained CNN, show its layers, and push one
    # random single-channel 28x28 image through it.
    net = CNN()
    sample = torch.rand(1, 1, 28, 28)
    print(net)
    logits = net(sample)
    print(logits)
STN
import torch
from torch import nn
class STN(nn.Module):
    """MNIST classifier preceded by a spatial transformer.

    ``stn`` predicts a 2x3 affine matrix ``theta`` per image with a small
    localisation network, resamples the input with it, and ``forward`` then
    classifies the warped image with the same conv/linear stack as ``CNN``.

    Shape notes for a 1x28x28 input:
      * localisation: 28 -> 22 (conv 7) -> 11 (pool) -> 7 (conv 5) -> 3 (pool),
        giving 10 * 3 * 3 = 90 features for ``localization_linear``;
      * classifier: 28 -> 7 -> 2, giving 128 * 2 * 2 = 512 features.
    """

    def __init__(self):
        super(STN, self).__init__()
        # Localisation network (attribute name kept for checkpoint compatibility).
        self.location_cov = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )
        # Regresses the 6 entries of the 2x3 affine matrix theta.
        self.localization_linear = nn.Sequential(
            nn.Linear(in_features=10 * 3 * 3, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=2 * 3),
        )
        # Start from the identity transform. Zero weights plus
        # bias = [1, 0, 0, 0, 1, 0] make theta == [[1, 0, 0], [0, 1, 0]]
        # for every input, so training begins from "no warp".
        nn.init.zeros_(self.localization_linear[2].weight)
        with torch.no_grad():
            self.localization_linear[2].bias.copy_(
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
            )
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        # Features are flattened to (batch, 512) before this head, so use
        # element-wise nn.Dropout (nn.Dropout2d is for 4-D feature maps and is
        # deprecated on 2-D input). State_dict keys ("linear.1.*") are unchanged.
        self.linear = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def stn(self, x):
        """Warp ``x`` with the affine transform predicted for each image.

        Returns a tensor of the same size as ``x`` (``affine_grid`` is given
        ``x.size()`` as the output size).
        """
        features = self.location_cov(x)
        features = features.view(features.size(0), -1)
        theta = self.localization_linear(features).view(-1, 2, 3)
        grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)
        return nn.functional.grid_sample(x, grid, align_corners=True)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the warped input."""
        x = self.stn(x)
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        return self.linear(x)
if __name__ == '__main__':
    # Smoke test: draw one random 1x28x28 image, then build an untrained STN
    # classifier, print its structure and the logits it produces.
    inp = torch.rand(1, 1, 28, 28)
    stn_model = STN()
    print(stn_model)
    logits = stn_model(inp)
    print(logits)
train
import numpy as np
import torch
from torchvision import transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from PIL import Image
from torch import nn, optim
from stn.CNN import CNN
from stn.STN import STN
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Preprocessing: rotate every digit by a random angle of up to 45 degrees so
# the inputs are misaligned, convert to a tensor, then normalise with
# mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
transform = transforms.Compose(
    [
        transforms.RandomRotation(45),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),
    ]
)

# Training split first, then the test split; both use the same pipeline.
train_data, test_data = (
    torchvision.datasets.MNIST('../data/mnist',
                               download=True,
                               train=is_train,
                               transform=transform)
    for is_train in (True, False)
)

train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    for dataset in (train_data, test_data)
)

# Preview one training batch as an 8-column grid, undoing the normalisation
# (x * 0.5 + 0.5) so matplotlib gets values back in [0, 1].
data_iter = iter(train_loader)
imgs = torchvision.utils.make_grid(next(data_iter)[0], 8).numpy().transpose(1, 2, 0)
imgs = imgs * 0.5 + 0.5
plt.imshow(imgs)
plt.show()

# Model under training; swap in CNN() to train the baseline without an STN.
model = STN().to(device)
loss_fun = nn.CrossEntropyLoss().to(device)
opt_fun = optim.Adam(params=model.parameters(), lr=0.001)

# Module-level bookkeeping; train() appends sampled batch losses to train_loss.
loss = 0
train_acc_count = []
test_acc_count = []
train_loss = []
test_loss = []
def train(epoch):
    """Train the module-level ``model`` on ``train_loader``.

    Uses the module-level ``loss_fun`` and ``opt_fun``. Every 100 batches it
    prints the current batch loss and appends it to the module-level
    ``train_loss`` list, which is later plotted.

    Args:
        epoch: number of full passes over the training set.
    """
    # Make sure dropout is active, even if an evaluation pass left the model
    # in eval mode.
    model.train()
    for i in range(epoch):
        for index, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels = labels.to(device)
            # The model lives on `device`, so its output already does too.
            outputs = model(imgs)
            loss = loss_fun(outputs, labels)
            opt_fun.zero_grad()
            loss.backward()
            opt_fun.step()
            if index % 100 == 0:
                print("第{}轮,第{}次,loss为:{}".format(i + 1, index, loss.item()))
                train_loss.append(loss.item())
def test():
    """Evaluate the module-level ``model`` on the whole ``test_loader``.

    The model is put in eval mode (dropout disabled) for the pass, and its
    previous train/eval mode is restored afterwards. Prints the accuracy
    over all test samples (not per batch) and returns it as a float; an
    empty loader yields 0.0.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, labels in test_loader:
                labels = labels.to(device)
                outputs = model(imgs.to(device))
                predictions = torch.max(outputs, dim=1)[1]
                # Accumulate across batches; the counts must not be reset per batch.
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    accuracy = correct / total if total else 0.0
    print("测试集准确率{}".format(accuracy))
    return accuracy
if __name__ == '__main__':
    import os

    # Seed every RNG in use so that repeated runs give the same results.
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # `model` was built at import time, before the seeds above were set, so
    # its initial weights still differ between runs. Re-initialise it in
    # place from a freshly seeded instance of the same class. load_state_dict
    # copies into the existing tensors, so opt_fun keeps valid references.
    model.load_state_dict(type(model)().state_dict())
    train(10)
    test()
    save_path = '../model/mnistStn.pth'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    plt.plot(train_loss)
    plt.show()
showImage
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as plt
from stn.STN import STN
# Same preprocessing as training: random rotation, tensor conversion and
# normalisation with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5)),
])
train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)
data_iter = iter(train_loader)
imgs, labels = next(data_iter)

# Top panel: the rotated input batch, de-normalised back to [0, 1].
pre = torchvision.utils.make_grid(imgs, 8)
pre = pre.numpy().transpose(1, 2, 0)
pre = pre * 0.5 + 0.5
plt.subplot(2, 1, 1)
plt.imshow(pre)
plt.title('pre')

# Bottom panel: the same batch after the trained spatial transformer.
model = STN()
# The checkpoint may have been saved from a CUDA model (train.py uses cuda:0
# when available), while this script runs on the CPU, so remap storages.
model.load_state_dict(torch.load('../model/mnistStn.pth', map_location='cpu'))
model.eval()
with torch.no_grad():
    now = model.stn(imgs)
now = torchvision.utils.make_grid(now, 8)
now = now.numpy().transpose(1, 2, 0)
now = now * 0.5 + 0.5
plt.subplot(2, 1, 2)
plt.imshow(now)
plt.title('now')
plt.show()
train,epoch=10
标签:nn,STN,torch,import,集上,data,self,mnist,size From: https://www.cnblogs.com/cxy-sky/p/sky02.html
展示经过 transform(空间变换)后的图片,还是感觉很神奇。