# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder
# Define a simple convolutional block (Conv-BatchNorm-ReLU)
class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and a ReLU activation.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding applied by the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # The three layers stay in one Sequential named ``conv`` so that
        # parameter names (``conv.0.*``, ``conv.1.*``) are unchanged.
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        return self.conv(x)
# Define a simple upscaling block using sub-pixel convolution
class UpscaleBlock(nn.Module):
    """Upscale a feature map with a 3x3 convolution and ``nn.PixelShuffle``.

    The convolution expands the channel count by ``scale_factor ** 2`` and
    the pixel shuffle rearranges those extra channels into spatial
    resolution, so the output has ``in_channels`` channels again.

    Args:
        in_channels: Number of channels in the incoming feature map.
        scale_factor: Upscaling factor passed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * (scale_factor ** 2)
        self.conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` after conv -> pixel shuffle -> ReLU."""
        return self.relu(self.pixel_shuffle(self.conv(x)))
# Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks)
class SuperResolutionModel(nn.Module):
    """Small super-resolution network built from ConvBlock and UpscaleBlock.

    Layer order: a 9x9 ConvBlock (3 -> 64 channels), a 1x1 ConvBlock
    (64 -> 32 channels), an UpscaleBlock on 32 channels, and a final plain
    9x9 convolution mapping 32 channels back to 3.

    Args:
        upscale_factor: Scale factor handed to the UpscaleBlock.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        # Final layer is a bare convolution: no batch norm or activation.
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Apply conv1, conv2, upscale and conv3 in that order."""
        for stage in (self.conv1, self.conv2, self.upscale, self.conv3):
            x = stage(x)
        return x
# Create a custom dataset for image super-resolution
class SuperResolutionDataset(torch.utils.data.Dataset):
    """Dataset of (low-resolution input, high-resolution target) image pairs.

    Images are read through ``ImageFolder``; its class label is discarded.

    Args:
        image_folder: Root directory passed to ``ImageFolder``.
        input_transform: Callable that turns a loaded image into the
            low-resolution network input.
        target_transform: Callable that turns a loaded image into the
            high-resolution training target.
    """

    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(input, target)`` for the image at ``index``."""
        img, _ = self.dataset[index]
        # Both transforms are applied to the loaded image itself. Feeding the
        # already-transformed target into ``input_transform`` breaks when both
        # pipelines end in ``ToTensor()``, which rejects tensor input.
        target = self.target_transform(img)
        low_res = self.input_transform(img)
        return low_res, target

    def __len__(self):
        """Return the number of images in the underlying folder dataset."""
        return len(self.dataset)
# Instantiate the model, loss function, and optimizer
# Run on the GPU when one is available, otherwise fall back to the CPU.
# ``device`` is also used by the loops further down when moving batches.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Factor by which the network enlarges its input.
upscale_factor = 2

model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()  # mean squared error between output and target
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Define input and target transformations for data preprocessing
# Side length of the square high-resolution target, and of the
# low-resolution input derived from it by the upscale factor.
hr_size = 256
lr_size = hr_size // upscale_factor

# Low-resolution network input: bicubic resize, then conversion to a tensor.
input_transform = transforms.Compose(
    [
        transforms.Resize(
            (lr_size, lr_size),
            interpolation=TF.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
    ]
)

# High-resolution training target: bicubic resize, then conversion to a tensor.
target_transform = transforms.Compose(
    [
        transforms.Resize(
            (hr_size, hr_size),
            interpolation=TF.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
    ]
)
# Create DataLoader for training and validation data
# Training split: shuffled on every pass.
train_dataset = SuperResolutionDataset(
    "path/to/train_data",
    input_transform,
    target_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

# Validation split: fixed order.
val_dataset = SuperResolutionDataset(
    "path/to/val_data",
    input_transform,
    target_transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)
# Training loop
# Number of full passes over the training data; adjust as needed.
num_epochs = 10

for epoch in range(num_epochs):
    model.train()  # training mode for the passes below
    train_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    train_loss /= len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.4f}")
# Validation loop
# Put the model in evaluation mode and score it on the validation set.
model.eval()
running_total = 0.0
with torch.no_grad():  # no gradients are needed while evaluating
    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        running_total += loss.item()

# Mean of the per-batch losses.
val_loss = running_total / len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
标签:__,loss,nn,val,分辨率,transform,Pytorch,模块,self
From: https://www.cnblogs.com/maluyelang/p/17253690.html