
Deep Learning (FCN)

Posted: 2024-09-15

FCN (Fully Convolutional Network) is an architecture for image semantic segmentation. The idea is to replace the final fully connected layers of an ordinary convolutional network with upsampling or transposed-convolution layers, so the network classifies every pixel of the image and thereby produces a segmentation map.
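The core idea can be sketched in a few lines. This is illustrative only, not code from this post: the shapes assume a hypothetical backbone that downsamples a 256×256 input to 8×8 feature maps with 512 channels.

```python
import torch
import torch.nn as nn

num_classes = 21
features = torch.randn(1, 512, 8, 8)         # hypothetical backbone output for a 256x256 input
classifier = nn.Conv2d(512, num_classes, kernel_size=1)  # 1x1 conv replaces the fully connected layer
upsample = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False)

logits = upsample(classifier(features))      # per-pixel class scores at full resolution
print(logits.shape)                          # torch.Size([1, 21, 256, 256])
```

Because the 1x1 convolution is applied at every spatial position, the network accepts inputs of any size and keeps spatial structure end to end.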

The network structure is as follows (the figure is not reproduced here):

This implementation does not follow the original architecture exactly; instead it combines Upsample and ConvTranspose2d layers to see how that performs.
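The difference between the two upsampling styles is that nn.Upsample interpolates with no learnable parameters, while ConvTranspose2d learns its upsampling filter. A minimal shape comparison (illustrative, not from the code below):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 21, 8, 8)
up = nn.Upsample(scale_factor=2)                  # fixed interpolation, no parameters
deconv = nn.ConvTranspose2d(21, 21, kernel_size=4, stride=2, padding=1)  # learned upsampling

print(up(x).shape)      # torch.Size([1, 21, 16, 16])
print(deconv(x).shape)  # torch.Size([1, 21, 16, 16])
```

Both double the spatial resolution, so they can be mixed freely in the decoder, as the networks below do.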

The code below performs semantic segmentation on the VOC dataset (about 2,000 labeled images, 21 classes) and achieves reasonable results.
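VOC stores its labels as color images, so every RGB color has to be mapped to a class index before training. The VOCData class below does this with a lookup table indexed by a packed RGB value; here is the trick in isolation, a sketch using only 3 of the 21 VOC colors:

```python
import numpy as np

colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0]]   # first 3 of the 21 VOC colors
cm2lbl = np.zeros(256 ** 3, dtype=np.int64)        # one slot per possible packed RGB value
for i, cm in enumerate(colormap):
    cm2lbl[cm[0] * 65536 + cm[1] * 256 + cm[2]] = i

r, g, b = 128, 0, 0                                # the "aeroplane" color
print(cm2lbl[r * 65536 + g * 256 + b])             # 1
```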

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

class VOCData(Dataset):

    def __init__(self, root):
        super(VOCData, self).__init__()

        self.lab_path = root + 'VOC2012/SegmentationClass/'
        self.img_path = root + 'VOC2012/JPEGImages/'

        self.lab_names = self.get_file_names(self.lab_path)

        self.img_names=[]
        for file in self.lab_names:
            self.img_names.append(file.replace('.png', '.jpg'))

        # lookup table: packed RGB value (r*65536 + g*256 + b) -> class index
        self.cm2lbl = np.zeros(256**3)
        for i, cm in enumerate(colormap):
            self.cm2lbl[cm[0]*256*256 + cm[1]*256 + cm[2]] = i

        self.image = []
        self.label = []
        for i in range(len(self.lab_names)):
            image = Image.open(self.img_path+self.img_names[i]).convert('RGB')
            image = transform(image)
        
            label = Image.open(self.lab_path+self.lab_names[i]).convert('RGB').resize((256,256), Image.NEAREST)  # nearest-neighbor keeps label colors exact
            label = torch.from_numpy(self.image2label(label))

            self.image.append(image)
            self.label.append(label)

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        return self.image[idx], self.label[idx]

    def get_file_names(self,directory):
        file_names = []
        for file_name in os.listdir(directory):
            if os.path.isfile(os.path.join(directory, file_name)):
                file_names.append(file_name)
        return file_names

    def image2label(self,im):
        data = np.array(im, dtype='int32')
        idx = data[:, :, 0] * 256 * 256 + data[:, :, 1] * 256 + data[:, :, 2]
        return np.array(self.cm2lbl[idx], dtype='int64')

class convblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(convblock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)        
        x = self.maxpool(x)
        return x


class Fcn32s(nn.Module):
    def __init__(self, num_classes):
        super(Fcn32s, self).__init__()
        self.conv_block1 = convblock(3,64)
        self.conv_block2 = convblock(64,128)
        self.conv_block3 = convblock(128,256)
        self.conv_block4 = convblock(256,512)
        self.conv_block5 = convblock(512,512)
        self.conv = nn.Conv2d(512,4096,kernel_size=1)
        self.up16x = nn.Upsample(scale_factor=16)
        self.convTrans2x = nn.ConvTranspose2d(4096, num_classes, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        x = self.conv(x)
        x = self.up16x(x)
        x = self.convTrans2x(x)
        return x
    

class Fcn16s(nn.Module):
    def __init__(self, num_classes):
        super(Fcn16s, self).__init__()
        self.conv_block1 = convblock(3,64)
        self.conv_block2 = convblock(64,128)
        self.conv_block3 = convblock(128,256)
        self.conv_block4 = convblock(256,512)
        self.conv_block5 = convblock(512,512)
        self.conv1 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.conv2 = nn.Conv2d(512,4096,kernel_size=1)
        self.convTrans2x = nn.ConvTranspose2d(4096, num_classes, kernel_size=4, stride=2, padding=1)
        self.up8x = nn.Upsample(scale_factor=8)
        self.convTrans2x2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x1 = self.conv_block4(x)
        x2 = self.conv_block5(x1)
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x2 = self.convTrans2x(x2)
        x = x1+x2
        x = self.up8x(x)
        x = self.convTrans2x2(x)
        return x
    

class Fcn8s(nn.Module):
    def __init__(self, num_classes):
        super(Fcn8s, self).__init__()
        self.conv_block1 = convblock(3,64)
        self.conv_block2 = convblock(64,128)
        self.conv_block3 = convblock(128,256)
        self.conv_block4 = convblock(256,512)
        self.conv_block5 = convblock(512,512)
        self.conv1 = nn.Conv2d(256, num_classes, kernel_size=1)
        self.conv2 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.conv3 = nn.Conv2d(512,4096,kernel_size=1)
        self.upsample2x1 = nn.ConvTranspose2d(4096, num_classes, kernel_size=4, stride=2, padding=1)
        self.upsample2x2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1)
        self.up = nn.Upsample(scale_factor=4)
        self.upsample2x3 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x1 = self.conv_block3(x)
        x2 = self.conv_block4(x1)
        x3 = self.conv_block5(x2)
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x3 = self.conv3(x3)
        x3 = self.upsample2x1(x3)
        x3 = x2 + x3
        x3 = self.upsample2x2(x3)
        x3 = x1 + x3
        x3 = self.up(x3)
        x = self.upsample2x3(x3)
        return x

def train():
    train_dataset = VOCData(root='./VOCdevkit/')
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    #net = Fcn32s(21)
    #net = Fcn16s(21)
    net = Fcn8s(21)

    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    net.to(device)
    net.train()

    num_epochs = 100
    for epoch in range(num_epochs):

        loss_sum = 0
        img_sum = 0
        for inputs, labels in train_loader:
            
            inputs =  inputs.to(device)
            labels =  labels.to(device)

            outputs = net(inputs)
            loss = criterion(outputs, labels)   

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            img_sum += inputs.shape[0]

        print('epoch:', epoch, 'loss:', loss_sum / img_sum)
    torch.save(net.state_dict(), 'my_fcn.pth')


def val():
    net = Fcn8s(21)
    net.load_state_dict(torch.load('my_fcn.pth'))

    net.to(device)
    net.eval()

    image = Image.open('./VOCdevkit/VOC2012/JPEGImages/2007_009794.jpg').convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():  # inference only, no gradients needed
        out = net(image).squeeze(0)
    ToPIL= transforms.ToPILImage()
    maxind = torch.argmax(out, dim=0).cpu()
    outimg = torch.zeros([3, 256, 256], dtype=torch.uint8)  # uint8 so ToPILImage keeps the raw 0-255 colors
    for y in range(256):
        for x in range(256):
            outimg[:, x, y] = torch.tensor(colormap[maxind[x, y]], dtype=torch.uint8)

    re = ToPIL(outimg)
    re.show()

if __name__ == "__main__":
    train()
    val()

From: https://www.cnblogs.com/tiandsp/p/18415257
