首页 > 其他分享 > PyTorch构建超分辨率模型——常用模块


时间：2023-03-24 23:47:23　浏览次数：42
标签:__ loss nn val 分辨率 transform Pytorch 模块 self

Import required libraries:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder

Define a simple convolutional block (Conv-BatchNorm-ReLU)

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding added to each spatial border of the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # Kept as one Sequential named `conv` so state_dict keys are conv.0/1/2.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU to ``x``."""
        return self.conv(x)

Define a simple upscaling block using sub-pixel convolution

class UpscaleBlock(nn.Module):
    """Sub-pixel upsampling block: 3x3 conv, PixelShuffle, then ReLU.

    The convolution expands the channels to ``in_channels * scale_factor**2``.
    PixelShuffle rearranges those channels into space, so the output keeps
    ``in_channels`` channels while height and width grow by ``scale_factor``.

    Args:
        in_channels: Channel count of both the input and the output.
        scale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        self.conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale ``x`` spatially by ``scale_factor``."""
        return self.relu(self.pixel_shuffle(self.conv(x)))

Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks)

class SuperResolutionModel(nn.Module):
    """Compact super-resolution network.

    The stages are applied in order:

    1. 9x9 ``ConvBlock`` mapping 3 -> 64 channels.
    2. 1x1 ``ConvBlock`` mapping 64 -> 32 channels.
    3. ``UpscaleBlock`` enlarging height/width by ``upscale_factor``.
    4. 9x9 ``Conv2d`` mapping 32 -> 3 channels.

    Args:
        upscale_factor: Integer spatial upscaling factor passed to the
            ``UpscaleBlock``.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        # Submodules are created in this order so parameter initialization
        # and state_dict layout stay stable.
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Run ``x`` through every stage and return the upscaled result."""
        for stage in (self.conv1, self.conv2, self.upscale, self.conv3):
            x = stage(x)
        return x

Create a custom dataset for image super-resolution

class SuperResolutionDataset(torch.utils.data.Dataset):
    """(input, target) pairs for super-resolution built from one image folder.

    Each loaded image goes through ``target_transform`` to produce the target.
    The *target* (not the raw image) then goes through ``input_transform`` to
    produce the network input. ``input_transform`` must therefore accept
    whatever ``target_transform`` returns.

    Args:
        image_folder: Root directory in the layout expected by
            ``torchvision.datasets.ImageFolder`` (images inside class
            subdirectories). Class labels are ignored.
        input_transform: Callable mapping a target to the model input.
        target_transform: Callable mapping a loaded image to the target.
    """

    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(input, target)`` for the image at ``index``."""
        img, _ = self.dataset[index]  # second element (class index) is unused
        target = self.target_transform(img)
        # Derive the input from the target so both describe the same pixels.
        lr_input = self.input_transform(target)
        return lr_input, target

    def __len__(self):
        """Number of images in the underlying folder dataset."""
        return len(self.dataset)

Instantiate the model, loss function, and optimizer

# `device` is used by the training and validation loops; it was never defined.
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Define input and target transformations for data preprocessing

# SuperResolutionDataset applies target_transform to the loaded image first,
# then passes the resulting target to input_transform.
#
# target_transform ends with ToTensor() so the DataLoader receives tensors.
#
# input_transform therefore receives a tensor. It converts it back to an image
# with ToPILImage() so the downscale takes the same PIL resize path as the
# target, then re-tensorizes (ToTensor() does not accept tensor input).
input_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(
        (256 // upscale_factor, 256 // upscale_factor),
        interpolation=TF.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
])

target_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

Create DataLoader for training and validation data

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataset = SuperResolutionDataset(
    image_folder="path/to/train_data",
    input_transform=input_transform,
    target_transform=target_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

val_dataset = SuperResolutionDataset(
    image_folder="path/to/val_data",
    input_transform=input_transform,
    target_transform=target_transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

Training loop

num_epochs = 100  # tune for the dataset size / time budget

for epoch in range(num_epochs):
    # Training mode: BatchNorm uses batch statistics and updates running stats.
    model.train()
    train_loss = 0.0
    num_samples = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short last batch
        # does not skew the epoch average.
        train_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)

    # Guard against an empty loader (avoids ZeroDivisionError).
    train_loss /= max(num_samples, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}")

Validation loop

# Evaluation mode: BatchNorm uses its running statistics.
# torch.no_grad() alone does not switch this.
model.eval()
val_loss = 0.0
num_samples = 0

with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # criterion averages over the batch (nn.MSELoss default reduction).
        # Weight by batch size so a short final batch does not skew the mean.
        val_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)

# Guard against an empty loader (avoids ZeroDivisionError).
val_loss /= max(num_samples, 1)
print(f"Validation Loss: {val_loss:.4f}")

From: https://www.cnblogs.com/maluyelang/p/17253690.html
