首页 > 其他分享 >深度学习(超分辨率)

深度学习(超分辨率)

时间:2024-12-21 21:53:09浏览次数:4  
标签:__ nn img 分辨率 学习 深度 size self out

             

简单训练了一个模型,可以实现超分辨率效果。模型在这里

模型用了一些卷积层,最后接一个PixelShuffle算子。

训练数据是原始图像resize后的亮度通道。

标签是原始图像的亮度通道。

损失函数设为MSE。

代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
from PIL import Image
from os import listdir
from os.path import join
import numpy as np

crop_size = 256
upscale_factor = 3
crop_size = crop_size - (crop_size % upscale_factor)

input_transformer= Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),
        ToTensor()])

target_transform =Compose([
        CenterCrop(crop_size),
        ToTensor()])

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, 5, 1, 2)
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv3 = nn.Conv2d(64, 32, 3, 1, 1)
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))       
        return x

class SRData(Dataset):
    def __init__(self, image_dir):
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir)]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        image = Image.open(self.image_filenames[index]).convert('YCbCr')
        y, _, _ = image.split()

        img = input_transformer(y)
        lab = target_transform(y)
        return img, lab

def train():
    num_epochs = 2

    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    train_dataset = SRData('./dataset')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:

            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    torch.save(model, 'super_res.pth')

def test():

    img = Image.open("test.jpg").convert('YCbCr')
    y, cb, cr = img.split()

    model = torch.load("super_res.pth")
    img_to_tensor = ToTensor()
    input = img_to_tensor(y).view(1, 1, y.size[1], y.size[0])

    model = model.cuda()
    input = input.cuda()

    out = model(input)
    out = out.cpu()
    out_img_y = out[0].detach().numpy()
    out_img_y *= 255.0
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

    out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

    out_img.save("out.jpg")

if __name__ == "__main__":
  #  train()
    test()

效果如下:

原图:

结果:

标签:__,nn,img,分辨率,学习,深度,size,self,out
From: https://www.cnblogs.com/tiandsp/p/18611110

相关文章

  • Golang学习笔记_13——数组
    Golang学习笔记_10——SwitchGolang学习笔记_11——指针Golang学习笔记_12——结构体文章目录数组1.定义2.访问和修改3.多维数组4.计算数组长度5.数组作为函数参数6.遍历7.数组的内存表示源码数组Go语言中的数组是一种具有固定长度、相同类型元素的集......
  • Golang学习笔记_14——切片
    Golang学习笔记_11——指针Golang学习笔记_12——结构体Golang学习笔记_13——数组文章目录切片1.定义2.创建3.基本操作4.动态性5.子切片6.数组和切片7.注意8.高级用法源码切片Go语言中的切片(slice)是一种非常强大且灵活的数据结构,它基于数组,但提供了......
  • Java学习笔记
    面向过程小知识点基本类型变量和引用类型变量局部变量和成员变量成员变量分为:静态成员变量和实例成员变量staticfinal修饰的成员变量称为常量(宏替换)多态使用父类类型的引用指向子类的对象该引用只能调用父类中定义的方法和变量如果子类中重写了父类中的一个方法,那么在调......
  • 《PyTorch深度学习实战》(一)
    1.张量张量(Tensor)是一个数学对象,可以看作是向量和矩阵的推广。在数学和物理学中,张量被用来描述多维空间中的量,这些量可能具有多个方向和大小。张量的定义和性质如下:阶数(Order):张量的阶数表示张量的维度。一个标量(Scalar)是0阶张量,一个向量(Vector)是1阶张量,一个矩阵(Matrix)是2阶张......
  • flask毕设学习交流平台的设计与实现(程序+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容选题背景随着信息技术的飞速发展,学习交流平台在教育领域的应用日益广泛。现有研究主要集中在在线教育平台的技术实现、用户行为分析以及教学模式创......
  • flask毕设学习资源分享系统的设计与实现(程序+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容选题背景随着信息技术的飞速发展和互联网的广泛普及,学习资源分享已成为教育领域的一大热点。现有研究主要集中在在线教育平台的设计、学习资源的管......
  • springboot毕设科技英语学习网站程序+论文+部署
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、研究背景随着全球化的迅猛发展以及科技领域国际交流合作的日益频繁,科技英语的重要性愈发凸显。在学术研究、科研成果发布、高新技术交流等众多方面,科技英......
  • 学习编程从游戏开始——多彩俄罗斯方块的设计构想
    0.前言我想通过编写一个完整的游戏程序方式引导读者体验程序设计的全过程。我将采用多种方式编写具有相同效果的应用程序,并通过不同方式形成的代码和实现方法的对比来理解程序开发更深层的知识。了解我编写教程的思路,请参阅体现我最初想法的那篇文章中的“1.编程计划”:学习编程......
  • 强化学习算法中的log_det_jacobian —— 概率分布的仿射变换(Bijector)(续)
    前文:强化学习算法中的log_det_jacobian——概率分布的仿射变换(Bijector)前文说到概率分布的仿射变换(Bijector)在贝叶斯、变分推断等领域有很重要的作用,但是在强化学习中呢,其实在强化学习中也会用到,但是最为普遍的应用场景其实只是做简单的tanh变换。在强化学习中一般用高斯分......
  • 深度学习——循环神经网络(八)
    序列模型训练生成数据序列importmatplotlib_inlineimporttorchimporttorch.nnasnnimportd2l.torchasd2limportmatplotlib.pyplotaspltimportnumpyasnpT=1000time=torch.arange(1,T+1,1,dtype=torch.float32)x=torch.sin(0.01*time)......