首页 > 其他分享 >深度学习(UNet)

深度学习(UNet)

时间:2024-10-01 15:11:40浏览次数:1  
标签:__ conv self torch 学习 UNet 深度 128 256

       

和FCN类似,UNet是另一个做语义分割的网络,网络从输入到输出中间呈一个U型而得名。

相比于FCN,UNet增加了更多的中间连接,能够更好处理不同尺度上的特征。

网络结构如下:

下面代码是用UNet对VOC数据集做的语义分割。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

class VOCData(Dataset):

    def __init__(self, root):
        super(VOCData, self).__init__()

        self.lab_path = root + 'VOC2012/SegmentationClass/'
        self.img_path = root + 'VOC2012/JPEGImages/'

        self.lab_names = self.get_file_names(self.lab_path)

        self.img_names=[]
        for file in self.lab_names:
            self.img_names.append(file.replace('.png', '.jpg'))

        self.cm2lbl = np.zeros(256**3) 
        for i,cm in enumerate(colormap): 
            self.cm2lbl[cm[0]*256*256+cm[1]*256+cm[2]] = i

        self.image = []
        self.label = []
        for i in range(len(self.lab_names)):
            image = Image.open(self.img_path+self.img_names[i]).convert('RGB')
            image = transform(image)
        
            label = Image.open(self.lab_path+self.lab_names[i]).convert('RGB').resize((256,256))
            label = torch.from_numpy(self.image2label(label))

            self.image.append(image)
            self.label.append(label)

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        return self.image[idx], self.label[idx]

    def get_file_names(self,directory):
        file_names = []
        for file_name in os.listdir(directory):
            if os.path.isfile(os.path.join(directory, file_name)):
                file_names.append(file_name)
        return file_names

    def image2label(self,im):
        data = np.array(im, dtype='int32')
        idx = data[:, :, 0] * 256 * 256 + data[:, :, 1] * 256 + data[:, :, 2]
        return np.array(self.cm2lbl[idx], dtype='int64')

class convblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(convblock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)        
        return x

class Unet(nn.Module):
    def __init__(self, num_classes):
        super(Unet, self).__init__()
        self.conv_block1 = convblock(3,64)
        self.conv_block2 = convblock(64,128)
        self.conv_block3 = convblock(128,256)
        self.conv_block4 = convblock(256,512)
        self.conv_block5 = convblock(512,1024)

        self.upsample1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0)
        self.upsample2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0)
        self.upsample3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)
        self.upsample4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0)

        self.conv_block6 = convblock(1024,512)
        self.conv_block7 = convblock(512,256)
        self.conv_block8 = convblock(256,128)
        self.conv_block9 = convblock(128,64)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_out = convblock(64,num_classes)

    def forward(self, x):
        x1 = self.conv_block1(x)  
        x = self.maxpool(x1) 

        x2 = self.conv_block2(x) 
        x = self.maxpool(x2)

        x3 = self.conv_block3(x)
        x = self.maxpool(x3)

        x4 = self.conv_block4(x)
        x = self.maxpool(x4)

        x = self.conv_block5(x)

        x = self.upsample1(x)
        x = torch.cat([x4,x],dim=1)
        x = self.conv_block6(x)

        x = self.upsample2(x)
        x = torch.cat([x3,x],dim=1)
        x = self.conv_block7(x)

        x = self.upsample3(x)
        x = torch.cat([x2,x],dim=1)
        x = self.conv_block8(x)

        x = self.upsample4(x)
        x = torch.cat([x1,x],dim=1)
        x = self.conv_block9(x)

        x = self.conv_out(x)

        return x

def train():
    train_dataset = VOCData(root='./VOCdevkit/')
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    net = Unet(21)

    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    net.to(device)
    net.train()

    num_epochs = 100
    for epoch in range(num_epochs):

        loss_sum = 0
        img_sum = 0
        for inputs, labels in train_loader:
            
            inputs =  inputs.to(device)
            labels =  labels.to(device)

            outputs = net(inputs)
            loss = criterion(outputs, labels)   

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            img_sum += inputs.shape[0]

        print('epochs:',epoch,loss_sum / img_sum )
    torch.save(net.state_dict(), 'unet.pth')


def val():
    net = Unet(21)
    net.load_state_dict(torch.load('unet.pth'))

    net.to(device)
    net.eval()

    image = Image.open('./VOCdevkit/VOC2012/JPEGImages/2012_001064.jpg').convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    out = net(image).squeeze(0)
    ToPIL= transforms.ToPILImage()
    maxind = torch.argmax(out,dim=0)
    outimg = torch.zeros([3,256,256])
    for y in range(256):
        for x in range(256):
            outimg[:,x,y] = torch.from_numpy(np.array(colormap[maxind[x,y]]))

    re = ToPIL(outimg)
    re.show()


if __name__ == "__main__":
    train()
    #val()

标签:__,conv,self,torch,学习,UNet,深度,128,256
From: https://www.cnblogs.com/tiandsp/p/17922362.html

相关文章

  • 论文总结1--基于深度强化学习的四足机器人步态分析--2024.10.01
    四足机器人的运动控制方法研究1.传统运动控制-基于模型的控制方法  目前,在四足机器人研究领域内应用最广泛的控制方法就是基于模型的控制方法,其中主要包括基于虚拟模型控制(VirtualModelControl,VMC)方法、基于零力矩点(ZeroMomentPoint,ZMP)的控制方法、弹簧负载倒立摆算法......
  • 《数据结构(刘大有)》学习(6)
    系列文章目录一、绪论二、顺序表、链表三、堆栈、队列四、数组五、字符串六、树目录树的基本概念树的定义树的特点树的相关术语度层数高度路径二叉树定义特点定理满二叉树定义特点完全二叉树定义特点二叉树的存储结构顺序存储结点结构优点缺点 链式存储 结点结构三......
  • 基于nodejs+vue学生网课学习数据分析与展示系统[开题+源码+程序+论文]计算机毕业设计
    本系统(程序+源码+数据库+调试部署+开发环境)带文档lw万字以上,文末可获取源码系统程序文件列表开题报告内容研究背景随着互联网技术的飞速发展和全球疫情的持续影响,在线教育已成为教育领域的重要组成部分。各大教育平台纷纷推出网课服务,以满足广大学生在家学习的需求。然而,......
  • 【机器学习-无监督学习】降维与主成分分析
    【作者主页】FrancekChen【专栏介绍】⌈⌈⌈Python机器学习⌋......
  • Python从0到100(六十一):机器学习实战-实现客户细分
    一、导入数据在此项目中,我们使用UCI机器学习代码库中的数据集。该数据集包含关于来自多种产品类别的各种客户年度消费额(货币单位计价)的数据。该项目的目标之一是准确地描述与批发商进行交易的不同类型的客户之间的差别。这样可以使分销商清晰地了解如何安排送货服务,以便......
  • [rCore学习笔记 028] Rust 中的动态内存分配
    引言想起我们之前在学习C的时候,总是提到malloc,总是提起,使用malloc现场申请的内存是属于堆,而直接定义的变量内存属于栈.还记得当初学习STM32的时候CubeIDE要设置stack和heap的大小.但是我们要记得,这么好用的功能,实际上是操作系统在负重前行.那么为了实现动态内存分配功......
  • 深度学习(计算数据集均值标准差)
      深度学习中有些数据集可能不符合imagenet计算出的均值和标准差,需要根据自己的数据集单独计算。下面这个脚本能够计算当前数据集均值和标准差。 importtorchimportosfromPILimportImagefromtorchvisionimporttransforms#trans=transforms.Compose([#......
  • 【学术】过来人对研究生阶段的学习建议
    【学术】过来人对研究生阶段的学习建议不同阶段的学习模式高中—>本科—>研究生是不同的学习阶段,面对不同的生存学习环境,需要做出调整。诚然,一些科研能力强的研究生会产生一种如鱼得水的感觉——不仅适应良好,且成长很快。两大法则:硕博生和本科生最大的区别在于独立处理......
  • 有监督学习&无监督学习
    有监督学习&无监督学习有监督学习(SupervisedLearning)和无监督学习(UnsupervisedLearning)是机器学习中的两种主要方法,它们在目标、数据使用和应用场景上有显著的区别:目标不同:有监督学习:目的是通过训练数据集(包含输入特征和对应的标签)来学习一个模型,以便对新的、未见......
  • Elasticsearch学习笔记(3)
    RestAPIElasticsearch(ES)官方提供了多种语言的客户端库,用于与Elasticsearch进行交互。这些客户端库的主要功能是帮助开发者更方便地构建和发送DSL(DomainSpecificLanguage)查询语句,并通过HTTP请求与Elasticsearch集群进行通信。官方文档地址:https://www.elastic.co/guide/en/......