
U-Net Network Analysis

Date: 2024-07-24 16:29:12
Tags: features nn self torch network channels unet analysis out

Study notes on the U-Net network architecture
Imports

#!/usr/bin/python
# coding=utf-8
import numpy as np
import torch
# torch.nn gives access to the neural-network building blocks defined in PyTorch:
# layers (fully connected, convolution, pooling, ...), loss functions
# (cross-entropy, mean squared error, ...) and activations (ReLU, Sigmoid, ...).
import torch.nn as nn
# torch.nn.functional holds many commonly used functions. They are usually
# stateless (they keep no learnable weights or biases) and are called directly
# inside a model's forward method for activation, pooling, normalization, etc.
import torch.nn.functional as F
# torch.optim.lr_scheduler contains the schedulers that adjust the learning rate.
from torch.optim import lr_scheduler, optimizer
import torchvision
import os, sys
import cv2 as cv
from torch.utils.data import DataLoader, sampler


# Data loading
class SegmentationDataset(object):
    def __init__(self, image_dir, mask_dir):
        self.images = []
        self.masks = []
        # os.listdir lists every file and directory under the given path, in
        # arbitrary order. Sorting keeps the index-based pairing below stable
        # (this assumes matching image and mask names sort the same way).
        files = sorted(os.listdir(image_dir))
        sfiles = sorted(os.listdir(mask_dir))
        for i in range(len(sfiles)):
            # os.path.join merges several path components into one complete path.
            img_file = os.path.join(image_dir, files[i])
            mask_file = os.path.join(mask_dir, sfiles[i])
            # print(img_file, mask_file)
            # self is the current instance; self.images and self.masks are its
            # attributes, lists holding the image paths and the mask paths.
            # Appending in the same order keeps each image paired with its mask.
            self.images.append(img_file)
            self.masks.append(mask_file)

    def __len__(self):
        # Lets the object be used wherever its length or size is needed,
        # for example len(dataset).
        return len(self.images)

    def num_of_samples(self):
        # The original notes ask why len(self.images) is returned twice: this
        # helper duplicates __len__ and is only used by the training loop to
        # average the loss; len(dataset) would work just as well.
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()  # convert a tensor index to a plain Python value
        image_path = self.images[idx]
        mask_path = self.masks[idx]
        img = cv.imread(image_path, cv.IMREAD_GRAYSCALE)  # single-channel grayscale
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        # Input image: convert to float32 and normalize to [0, 1].
        img = np.float32(img) / 255.0
        img = np.expand_dims(img, 0)  # add a leading channel dimension: (1, H, W)
        # Target labels: binarize the mask to 0 or 1.
        mask[mask <= 128] = 0
        mask[mask > 128] = 1
        mask = np.expand_dims(mask, 0)
        sample = {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask),}
        return sample


class UNetModel(torch.nn.Module):
    # init_features sets the number of output channels of the first encoder
    # block; every deeper level doubles it.
    def __init__(self, in_features=1, out_features=2, init_features=32):
        super(UNetModel, self).__init__()
        features = init_features
        self.encode_layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_features, out_channels=features, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features),  # batch normalization layer
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features),
            torch.nn.ReLU()
        )
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.encode_layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features, out_channels=features*2, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features*2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*2, out_channels=features*2, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 2),
            torch.nn.ReLU()
        )
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.encode_layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*2, out_channels=features*4, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*4, out_channels=features*4, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 4),
            torch.nn.ReLU()
        )
        self.pool3 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.encode_layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*4, out_channels=features*8, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 8),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*8, out_channels=features*8, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 8),
            torch.nn.ReLU(),
        )
        self.pool4 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck between the encoder and the decoder.
        self.encode_decode_layer = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*8, out_channels=features*16, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 16),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*16, out_channels=features*16, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 16),
            torch.nn.ReLU()
        )
        self.upconv4 = torch.nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decode_layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*16, out_channels=features*8, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features*8),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*8, out_channels=features*8, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 8),
            torch.nn.ReLU(),
        )
        self.upconv3 = torch.nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decode_layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*8, out_channels=features*4, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*4, out_channels=features*4, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 4),
            torch.nn.ReLU()
        )
        self.upconv2 = torch.nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decode_layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*4, out_channels=features*2, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features*2, out_channels=features*2, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features * 2),
            torch.nn.ReLU()
        )
        self.upconv1 = torch.nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decode_layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features*2, out_channels=features, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, stride=1),
            torch.nn.BatchNorm2d(num_features=features),
            torch.nn.ReLU()
        )
        self.out_layer = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=features, out_channels=out_features, kernel_size=1, padding=0, stride=1),
        )

    def forward(self, x):
        enc1 = self.encode_layer1(x)
        enc2 = self.encode_layer2(self.pool1(enc1))
        enc3 = self.encode_layer3(self.pool2(enc2))
        enc4 = self.encode_layer4(self.pool3(enc3))
        bottleneck = self.encode_decode_layer(self.pool4(enc4))
        dec4 = self.upconv4(bottleneck)
        # torch.cat joins tensors along the given dimension; dim=1 is the channel
        # dimension, so the skip connection stacks encoder and decoder features.
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decode_layer4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decode_layer3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decode_layer2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decode_layer1(dec1)
        out = self.out_layer(dec1)
        return out


# Training
if __name__ == '__main__':
    index = 0
    num_epochs = 50
    train_on_gpu = True
    unet = UNetModel().cuda()
    # model_dict = unet.load_state_dict(torch.load('unet_road_model-100.pt'))
    image_dir = r'D:\daima\CrackForest-dataset-master\CrackForest-dataset-master\train'
    mask_dir = r'D:\daima\CrackForest-dataset-master\CrackForest-dataset-master\png'
    dataloader = SegmentationDataset(image_dir, mask_dir)
    optimizer = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.9)
    train_loader = DataLoader(dataloader, batch_size=1, shuffle=False)
    for epoch in range(num_epochs):
        train_loss = 0.0
        for i_batch, sample_batched in enumerate(train_loader):
            images_batch, target_labels = \
                sample_batched['image'], sample_batched['mask']
            print(target_labels.min())
            print(target_labels.max())
            if train_on_gpu:
                images_batch, target_labels = images_batch.cuda(), target_labels.cuda()
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            m_label_out_ = unet(images_batch)
            # print(m_label_out_)
            # calculate the batch loss: flatten the labels to one value per pixel
            # and the outputs to one row of 2 class scores per pixel
            target_labels = target_labels.contiguous().view(-1)
            m_label_out_ = m_label_out_.transpose(1, 3).transpose(1, 2).contiguous().view(-1, 2)
            target_labels = target_labels.long()
            loss = torch.nn.functional.cross_entropy(m_label_out_, target_labels)
            print(loss)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update training loss
            train_loss += loss.item()
            if index % 100 == 0:
                print('step: {} \tcurrent Loss: {:.6f} '.format(index, loss.item()))
            index += 1
            # test(unet)
        # compute the average loss over the epoch
        train_loss = train_loss / dataloader.num_of_samples()
        # print the loss (only the training set is evaluated here)
        print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
        # test(unet)
    # save model
    unet.eval()
    torch.save(unet.state_dict(), 'unet_road_model.pt')


# Testing
def test(unet):
    model_dict = unet.load_state_dict(torch.load('unet_road_model.pt'))
    unet.eval()  # make BatchNorm use its running statistics at inference time
    root_dir = r'D:\daima\CrackForest-dataset-master\CrackForest-dataset-master\test'
    fileNames = os.listdir(root_dir)
    for f in fileNames:
        image = cv.imread(os.path.join(root_dir, f), cv.IMREAD_GRAYSCALE)
        h, w = image.shape
        img = np.float32(image) / 255.0
        img = np.expand_dims(img, 0)
        x_input = torch.from_numpy(img).view(1, 1, h, w)
        probs = unet(x_input.cuda())  # raw class scores, shape (1, 2, h, w)
        m_label_out_ = probs.transpose(1, 3).transpose(1, 2).contiguous().view(-1, 2)
        # max over the class dimension returns (values, indices);
        # the indices are the predicted class of each pixel
        _, output = m_label_out_.data.max(dim=1)
        output[output > 0] = 255
        predic_ = output.view(h, w).cpu().detach().numpy()
        # print(predic_)
        # print(predic_.max())
        # print(predic_.min())
        # print(predic_)
        # print(predic_.shape)
        # cv.imshow("input", image)
        result = cv.resize(np.uint8(predic_), (w, h))
        cv.imshow("unet-segmentation-demo", result)
        cv.waitKey(0)
        cv.destroyAllWindows()
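The channel counts passed to each decode block (features*16, features*8, ...) follow from the skip connections: each ConvTranspose2d halves the channels and doubles the height and width, and torch.cat then adds the encoder channels back. The sketch below traces those shapes in plain Python, without PyTorch. The function name trace_unet_shapes is hypothetical and not part of the original code; it assumes the same layout as UNetModel above (3x3 convolutions with padding 1, 2x2 max pooling, 2x2 stride-2 transposed convolutions).

```python
def trace_unet_shapes(height, width, init_features=32, depth=4):
    """Trace (channels, height, width) through a U-Net laid out like UNetModel.

    Returns (encoder outputs, bottleneck output, decoder inputs after concat).
    Raises ValueError if a skip connection could not be concatenated.
    """
    encoder = []
    ch, h, w = init_features, height, width
    for _ in range(depth):
        encoder.append((ch, h, w))       # output of one encode block
        # 2x2 max pooling halves H and W (rounding down);
        # the next block doubles the channel count.
        ch, h, w = ch * 2, h // 2, w // 2
    bottleneck = (ch, h, w)

    decoder_inputs = []
    for skip_ch, skip_h, skip_w in reversed(encoder):
        # Transposed conv with kernel 2, stride 2: half the channels, double H and W.
        ch, h, w = ch // 2, h * 2, w * 2
        if (h, w) != (skip_h, skip_w):
            raise ValueError(
                f"skip mismatch: upsampled {h}x{w} vs encoder {skip_h}x{skip_w}"
            )
        # Concatenation along the channel dimension; the decode block that
        # follows reduces the result back to `ch` channels.
        decoder_inputs.append((ch + skip_ch, h, w))
    return encoder, bottleneck, decoder_inputs
```

With four pooling stages, an input whose height or width is not a multiple of 16 can lose a pixel to rounding on the way down, and the concatenation then fails; trace_unet_shapes(100, 100) raises for that reason, while trace_unet_shapes(320, 480) passes.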
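The training loop reshapes the network output from (N, C, H, W) to one row of C scores per pixel before calling cross_entropy, and flattens the mask in the same pixel order. The NumPy sketch below reproduces that bookkeeping so the alignment can be checked without a GPU. pixelwise_cross_entropy is a hypothetical helper, not part of the original code; it is meant to mirror a mean-reduced cross-entropy on raw scores and is not a replacement for the PyTorch function.

```python
import numpy as np


def pixelwise_cross_entropy(logits, labels):
    """Mean cross-entropy over all pixels.

    logits: float array of shape (N, C, H, W) holding raw class scores.
    labels: integer array of shape (N, 1, H, W) holding class ids in [0, C).
    """
    n, c, h, w = logits.shape
    # NCHW -> NHWC, then flatten to one row of C scores per pixel.
    flat_logits = logits.transpose(0, 2, 3, 1).reshape(-1, c)
    # The mask has a single channel, so flattening keeps the same N, H, W order.
    flat_labels = labels.reshape(-1)
    # Numerically stable log-softmax over the class axis.
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for every pixel.
    picked = log_probs[np.arange(flat_labels.size), flat_labels]
    return float(-picked.mean())
```

If the two flattening orders disagreed, scores that strongly favor the correct class at every pixel would still give a large loss; with matching orders the loss is close to zero.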
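The preprocessing inside __getitem__ (scale the image to [0, 1], add a channel dimension, threshold the mask at 128) can be tried on small arrays without reading any files. The sketch below is a NumPy-only version under that assumption; prepare_sample is a hypothetical name, and unlike the in-place assignments in the original it leaves the input mask untouched.

```python
import numpy as np


def prepare_sample(img, mask, threshold=128):
    """Turn a grayscale image and its mask into network-ready arrays.

    img, mask: uint8 arrays of shape (H, W).
    Returns (image, labels): float32 (1, H, W) in [0, 1] and uint8 (1, H, W) in {0, 1}.
    """
    # Normalize to [0, 1] and add the leading channel dimension.
    image = np.expand_dims(img.astype(np.float32) / 255.0, 0)
    # Values above the threshold become class 1, everything else class 0.
    labels = np.expand_dims((mask > threshold).astype(np.uint8), 0)
    return image, labels
```

The single comparison gives the same labels as the two assignments mask[mask <= 128] = 0 and mask[mask > 128] = 1, since every pixel falls on exactly one side of the threshold.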

 

From: https://www.cnblogs.com/candice1/p/18321187
