
Implementing STN on the MNIST Dataset



Personal blog: https://cxy-sky.github.io/

Code reference: PyTorch框架实战系列(3)——空间变换器网络STN (Daniel Yuz, CSDN blog)

Theory: Pytorch中的仿射变换(affine_grid) (liangbaqiang, CSDN blog)

Theory: 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了 (黄小猿, CSDN blog)

Images are displayed with matplotlib; I haven't installed OpenCV.
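
The core of the STN is the torch.nn.functional.affine_grid / grid_sample pair: affine_grid turns a 2x3 affine matrix theta into a sampling grid in normalized [-1, 1] coordinates, and grid_sample bilinearly resamples the input at those grid locations. Below is a minimal sketch of the pair on its own (the 30-degree rotation and the random image are just example values, not from the original post):

import math
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 28, 28)  # dummy single-channel image

# 2x3 affine matrix for a 30-degree rotation around the image center
a = math.radians(30)
theta = torch.tensor([[[math.cos(a), -math.sin(a), 0.0],
                       [math.sin(a),  math.cos(a), 0.0]]])

grid = F.affine_grid(theta, x.size(), align_corners=True)   # (1, 28, 28, 2) sampling grid
y = F.grid_sample(x, grid, align_corners=True)              # resampled image, same shape as x
print(y.shape)  # torch.Size([1, 1, 28, 28])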

CNN

import torch
from torch import nn


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),    # 28x28 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),                   # 28x28 -> 7x7
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 7x7 -> 7x7
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),                   # 7x7 -> 2x2
        )
        self.linear = nn.Sequential(
            nn.Dropout(0.5),       # input here is already flattened, so plain Dropout rather than Dropout2d
            nn.Linear(512, 10)     # 128 channels * 2 * 2 = 512 features
        )

    def forward(self, x):
        x = self.cnn(x)              # (N, 128, 2, 2)
        x = x.view(x.size(0), -1)    # flatten to (N, 512)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    model = CNN()
    x = torch.rand(1, 1, 28, 28)
    print(model)
    y = model(x)
    print(y)
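
The 512 passed to nn.Linear comes from the feature map shape after the two pooling layers: 28x28 -> 7x7 -> 2x2, with 128 channels, so 128 * 2 * 2 = 512. A throwaway check of that arithmetic (assuming the CNN class defined above):

model = CNN()
features = model.cnn(torch.rand(1, 1, 28, 28))
print(features.shape)                              # torch.Size([1, 128, 2, 2])
print(features.view(features.size(0), -1).shape)   # torch.Size([1, 512])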

STN

import torch
from torch import nn


class STN(nn.Module):
    def __init__(self):
        super(STN, self).__init__()
        # localization network: a small conv net that looks at the input image
        self.localization_conv = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),    # 28x28 -> 22x22
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),         # 22x22 -> 11x11
            nn.Conv2d(8, 10, kernel_size=5),   # 11x11 -> 7x7
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),         # 7x7 -> 3x3
        )

        # regression head that predicts the 2x3 affine matrix theta
        self.localization_linear = nn.Sequential(
            nn.Linear(in_features=10 * 3 * 3, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=2 * 3)
        )

        # initialize the last layer so the predicted transform starts as the identity
        self.localization_linear[2].weight.data.zero_()
        self.localization_linear[2].bias.data.copy_(torch.tensor([1, 0, 0,
                                                                  0, 1, 0], dtype=torch.float))

        # same classification backbone as the plain CNN above
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),    # 28x28 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),                   # 28x28 -> 7x7
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 7x7 -> 7x7
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),                   # 7x7 -> 2x2
        )
        self.linear = nn.Sequential(
            nn.Dropout(0.5),       # input is already flattened, so plain Dropout
            nn.Linear(512, 10)     # 128 * 2 * 2 = 512
        )

    def stn(self, x):
        # predict a per-sample 2x3 affine matrix theta from the input
        x2 = self.localization_conv(x)
        x2 = x2.view(x2.size(0), -1)
        x2 = self.localization_linear(x2)
        theta = x2.view(x2.size(0), 2, 3)
        # build the sampling grid from theta and resample the input with it
        grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)
        x = nn.functional.grid_sample(x, grid, align_corners=True)
        return x

    def forward(self, x):
        x = self.stn(x)              # spatially transform the input first
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    x = torch.rand(1, 1, 28, 28)
    model = STN()
    print(model)
    print(model(x))
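
Because the last localization layer starts with zero weights and an identity bias, a freshly constructed STN predicts the identity transform for every input, so stn() should return its input essentially unchanged before any training. A quick sanity check of that initialization (not part of the original post):

model = STN()
x = torch.rand(1, 1, 28, 28)
with torch.no_grad():
    out = model.stn(x)
print(torch.allclose(out, x, atol=1e-5))  # True: identity transform at initialization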

train

import numpy as np
import torch
import torch.utils.data
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch import nn, optim

from stn.CNN import CNN
from stn.STN import STN

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 数据处理
transform = transforms.Compose([
    transforms.RandomRotation(45),   # random rotation so the STN has something to correct
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform
                                        )

test_data = torchvision.datasets.MNIST('../data/mnist',
                                       download=True,
                                       train=False,
                                       transform=transform, )

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=64,
                                          shuffle=True)

# preview a grid of augmented training images
data_iter = iter(train_loader)
imgs = torchvision.utils.make_grid(next(data_iter)[0], 8)
imgs = imgs.numpy().transpose(1, 2, 0)
imgs = imgs * 0.5 + 0.5   # undo the normalization for display
plt.imshow(imgs)
plt.show()


# model = CNN()
model = STN()
model = model.to(device)
loss_fun = nn.CrossEntropyLoss().to(device)
opt_fun = optim.Adam(params=model.parameters(), lr=0.001)

train_loss = []


def train(epoch):
    model.train()
    for i in range(epoch):
        for index, data in enumerate(train_loader):
            imgs = data[0].to(device)
            labels = data[1].to(device)
            outputs = model(imgs)
            loss = loss_fun(outputs, labels)
            opt_fun.zero_grad()
            loss.backward()
            opt_fun.step()
            if index % 100 == 0:
                print("epoch {}, step {}, loss: {}".format(i + 1, index, loss.item()))
                train_loss.append(loss.item())


def test():
    # accumulate correct predictions over the whole test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            outputs = model(imgs.to(device))
            correct += (torch.max(outputs, dim=1)[1] == labels.to(device)).sum().item()
            total += labels.size(0)
    print("test accuracy: {}".format(correct / total))


if __name__ == '__main__':
    # fix the random seeds so runs are reproducible
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    train(10)
    test()
    save_path = '../model/mnistStn.pth'
    torch.save(model.state_dict(), save_path)
    plt.plot(train_loss)
    plt.show()


showImage

from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as plt

from stn.STN import STN

transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform
                                        )

train_loader = DataLoader(train_data,
                          batch_size=64,
                          shuffle=True)

# top row: the original (randomly rotated) input images
data_iter = iter(train_loader)
imgs, labels = next(data_iter)
pre = torchvision.utils.make_grid(imgs, 8)
pre = pre.numpy().transpose(1, 2, 0)
pre = pre * 0.5 + 0.5
plt.subplot(2, 1, 1)
plt.imshow(pre)
plt.title('pre')

# bottom row: the same images after the trained STN's spatial transform
model = STN()
model.load_state_dict(torch.load('../model/mnistStn.pth'))
model.eval()
now = model.stn(imgs).detach()
now = torchvision.utils.make_grid(now, 8)
now = now.numpy().transpose(1, 2, 0)
now = now * 0.5 + 0.5
plt.subplot(2, 1, 2)
plt.imshow(now)
plt.title('now')

plt.show()
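
To see numerically what transform the trained STN predicts for a batch, you can also read theta straight from the localization branch. This continues from the model and imgs variables in the showImage script above; it is just a quick inspection, not part of the original post:

with torch.no_grad():
    feat = model.localization_conv(imgs).view(imgs.size(0), -1)
    theta = model.localization_linear(feat).view(-1, 2, 3)
print(theta[0])  # the 2x3 affine matrix applied to the first image in the batch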

Training loss, epoch = 10:

[figure: training loss curve over 10 epochs]

The images after the transform (top row: the rotated inputs, bottom row: after the STN); it still feels quite magical.

[figure: 'pre' vs. 'now' image grids from showImage]

From: https://www.cnblogs.com/cxy-sky/p/sky02.html
