首页 > 其他分享 >笨方法实现unet

笨方法实现unet

时间:2023-07-18 18:45:19浏览次数:36  
标签:nn 实现 256 self channels unet ._ 128 方法

import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s:%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
import torch
import torch.nn as nn
import torch.nn.functional as F


class iUnet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # 第一次卷积 - encode   N 3 512 512 -> N 64 256 256
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.relu1_1 = nn.ReLU(inplace=1)

        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(inplace=1)

        self.pool1 = nn.MaxPool2d(2)  # N 64 512 512 -> N 64 256 256

        # 第二次卷积 - encode  N 64 256 256 -> N 128 128 128
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(128)
        self.relu2_1 = nn.ReLU(inplace=1)

        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.relu2_2 = nn.ReLU(inplace=1)

        self.pool2 = nn.MaxPool2d(2)  # N 128 256 256 -> N 128 128 128

        # 第三次卷积 - encode  N 128 128 128 -> N 256 64 64
        self.conv3_1 =  nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.relu3_1 = nn.ReLU(inplace=1)

        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256)
        self.relu3_2 = nn.ReLU(inplace=1)

        self.pool3 = nn.MaxPool2d(2)  # N 256 128 128 -> N 256 64 64

        # 第四次卷积 - encode N 256 64 64 -> N 512 32 32
        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512)
        self.relu4_1 = nn.ReLU(inplace=1)

        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512)
        self.relu4_2 = nn.ReLU(inplace=1)

        self.pool4 = nn.MaxPool2d(2)  # N 512 64 64 -> N 512 32 32

        # 第五次卷积 - encode  N 512 32 32 -> N 1024 32 32
        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.bn5_1 = nn.BatchNorm2d(1024)
        self.relu5_1 = nn.ReLU(inplace=1)

        self.conv5_2 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1)
        self.bn5_2 = nn.BatchNorm2d(1024)
        self.relu5_2 = nn.ReLU(inplace=1)

        # 第1次解码 - decode
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)

        self._conv1_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)
        self._bn1_1 = nn.BatchNorm2d(512)
        self._relu1_1 = nn.ReLU(inplace=1)

        self._conv1_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self._bn1_2 = nn.BatchNorm2d(512)
        self._relu1_2 = nn.ReLU(inplace=1)

        # 第2次解码 - decode
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)

        self._conv2_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self._bn2_1 = nn.BatchNorm2d(256)
        self._relu2_1 = nn.ReLU(inplace=1)

        self._conv2_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self._bn2_2 = nn.BatchNorm2d(256)
        self._relu2_2 = nn.ReLU(inplace=1)

        # 第3次解码 - decode
        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv3 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)

        self._conv3_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self._bn3_1 = nn.BatchNorm2d(128)
        self._relu3_1 = nn.ReLU(inplace=1)

        self._conv3_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self._bn3_2 = nn.BatchNorm2d(128)
        self._relu3_2 = nn.ReLU(inplace=1)

        # 第4次解码 - decode
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)

        self._conv4_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self._bn4_1 = nn.BatchNorm2d(64)
        self._relu4_1 = nn.ReLU(inplace=1)

        self._conv4_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self._bn4_2 = nn.BatchNorm2d(64)
        self._relu4_2 = nn.ReLU(inplace=1)

        # 输出类别信息
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        # 编码
        x = self.relu1_1(self.bn1_1(self.conv1_1(x)))
        x1 = self.relu1_2(self.bn1_2(self.conv1_2(x)))
        # logging.info(f'x1:{x1.shape}')

        x = self.pool1(x1)
        x = self.relu2_1(self.bn2_1(self.conv2_1(x)))
        x2 = self.relu2_2(self.bn2_2(self.conv2_2(x)))
        # logging.info(f'x2:{x2.shape}')

        x = self.pool2(x2)
        x = self.relu3_1(self.bn3_1(self.conv3_1(x)))
        x3 = self.relu3_2(self.bn3_2(self.conv3_2(x)))
        # logging.info(f'x3:{x3.shape}')

        x = self.pool3(x3)
        x = self.relu4_1(self.bn4_1(self.conv4_1(x)))
        x4 = self.relu4_2(self.bn4_2(self.conv4_2(x)))
        # logging.info(f'x4:{x4.shape}')

        x = self.pool4(x4)
        x = self.relu5_1(self.bn5_1(self.conv5_1(x)))
        x = self.relu5_2(self.bn5_2(self.conv5_2(x)))
        # logging.info(f'x5:{x.shape}')

        # 解码
        x = self.upsample1(x)
        x = self._conv1(x)
        x = torch.cat([x, x4], dim=1)
        x = self._relu1_1(self._bn1_1(self._conv1_1(x)))
        x = self._relu1_2(self._bn1_2(self._conv1_2(x)))
        # logging.info(f'dx1:{x.shape}')

        x = self.upsample2(x)
        x = self._conv2(x)
        x = torch.cat([x, x3], dim=1)
        x = self._relu2_1(self._bn2_1(self._conv2_1(x)))
        x = self._relu2_2(self._bn2_2(self._conv2_2(x)))
        # logging.info(f'dx2:{x.shape}')

        x = self.upsample3(x)
        x = self._conv3(x)
        x = torch.cat([x, x2], dim=1)
        x = self._relu3_1(self._bn3_1(self._conv3_1(x)))
        x = self._relu3_2(self._bn3_2(self._conv3_2(x)))
        # logging.info(f'dx3:{x.shape}')

        x = self.upsample4(x)
        x = self._conv4(x)
        x = torch.cat([x, x1], dim=1)
        x = self._relu4_1(self._bn4_1(self._conv4_1(x)))
        x = self._relu4_2(self._bn4_2(self._conv4_2(x)))
        # logging.info(f'dx4:{x.shape}')

        x = self.out(x)
        return x


if __name__ == '__main__':
    data = torch.randn(4, 3, 384, 384)
    net = iUnet()
    pred = net(data)    

标签:nn,实现,256,self,channels,unet,._,128,方法
From: https://www.cnblogs.com/ddzhen/p/17563842.html

相关文章

  • 智安云重磅上线,"数智一体"赋能智安云生态价值实现!
    智安网络作为互联网行业的先行者和持续创新者,一直秉承着为客户创造更多数字化价值的理念,在这一信念下,致力于为用户打造一个开放、安全、透明和便捷的云计算平台。2023年7月19日,智安云平台V1.0.2版本正式上线,开始面向广大用户提供底层定制化服务、大数据、综合云计算服务以及企业级......
  • 如何使用CSS3 @font-face 实现个性化字体
    如何使用CSS3@font-face实现个性化字体。 在网页中,我们可以使用CSS的font-family属性来定义字体。但是,定义的字体能否在用户的电脑上正确显示,取决于用户的电脑上是否安装了该字体。我们经常看到国外的一些个人网站使用了非常漂亮的字体,而这些字体通常用户的电脑上并没有安装,所......
  • 动态加载页面的爬虫方法
    首先,可以直接手动拉到网页最下面,然后把F12里面的网页节点元素复制成文本,去获取目标进行下载,代码如下,用到的库BeautifulSoup:importosimporturllib.requestimportrefrombs4importBeautifulSoupasbsimportrandomasrdimporttimedefget_imgs(text):soup=bs......
  • Asp.Net Core 实现异步操作锁 (SemaphoreSlim)
    /设置同时访问线程最大数量staticSemaphoreSlim_semaphore=newSemaphoreSlim(4);staticvoidAccessDatabase(stringname,intseconds){Console.WriteLine($"{name}waitstoaccessadatabase");_semaphore.Wait();Console.WriteLine($"{name}wa......
  • 【解决方法】通过二层互联实现AP 与AC 的互联与AP 的上线
    环境:工具:锐捷EVE模拟器,VMwareWorkstationPro远程工具:SecureCRT系统版本:Windows10问题描述:描述:搭建一个瘦AP网络环境,使用2层互联,用户与其他网段的通信都通过核心上的网关进行。提示:若按照教程还是无法完成操作,可以进入右侧的企鹅,找我看看。解决方法-视频与文......
  • Python中的方法重写与名称修饰
    Python中的方法重写与名称修饰今天写python系统的时候,发现父类怎么老是调用不了子类重写的方法,整了好久才发现,python的名称修饰机制,好久没写代码了,这一块知识点忘干净了,下面进行总结。在Python中,方法重写是面向对象编程中的重要概念,它允许子类对父类的方法进行重新定义以满足子......
  • iOS测试包的安装方法
    iOS测试包根据要安装的机器类型可以分为2种:.app模拟器测试包.ipa真机测试包.app模拟器测试包的安装方式方式一:Xcode生成安装包1.Xcode运行项目,生成app包2.将APP包拖到模拟器中方式二:IPA包下载得到安装包1.将ipa包的后缀改成.zip,然后解压2.取出Payload目录下的.app文件......
  • vue使用hiprint实现打印(vue-plugin-hiprint)
    1、安装插件:npminstallvue-plugin-hiprint或yarnaddvue-plugin-hiprint2、普通使用:<template><divclass="box"><divclass="box-tool"><el-button-group><el-buttontype="primary......
  • Python获取文件夹下文件夹的名字,并存excel为一列(方法一)
    大家好,我是皮皮。一、前言这个事情还得从前几天在Python最强王者群【东哥】问了一个Python自动化办公处理的问题,需求倒是不难,一起来看看吧。二、实现过程这里【wangning】又给了一个答案,他自己之前整理的文章,不过需要自己稍微修改下才行。后来【魏哥】看到了,并且给出了如下......
  • Python pandas.DataFrame.iat函数方法的使用
    DataFrame.iat按整数位置访问行/列对的单个值。与iloc类似,两者都提供基于整数的查找。如果只需要在DataFrame或Series中获取或设置一个值,则使用iat。Raises:当整数位置超出界限时抛出IndexError例子:>>>df=pd.DataFrame([[0,2,3],[0,4,1],[10,20,30]],.......