首页 > 其他分享 >VGG使用块的网络——pytorch版

VGG使用块的网络——pytorch版

时间:2023-08-06 14:55:43浏览次数:52  
标签:nn conv VGG 网络 channels pytorch num arch out

import torch
from torch import nn
from d2l import torch as d2l

def vgg_block(num_convs,in_channels,out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(
            in_channels,out_channels,kernel_size=3, padding=1
        ))
        layers.append(nn.ReLU())
        # 每个输出保证都是一样的
        in_channels = out_channels
    layers.append(nn.MaxPool2d(
        kernel_size=2,stride=2
    ))
    return nn.Sequential(*layers)

conv_arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

def vgg(conv_arch):
    conv_blks=[]
    in_channels=1
    for (num_convs,out_channels) in conv_arch:
        conv_blks.append(vgg_block(
            num_convs,in_channels,out_channels
        ))
        in_channels=out_channels
    return nn.Sequential(
        *conv_blks,nn.Flatten(),
        nn.Linear(out_channels*7*7,4096),nn.ReLU(),
        nn.Dropout(0.5),nn.Linear(4096,4096),nn.ReLU(),
        nn.Dropout(0.5),nn.Linear(4096,10)
    )

net = vgg(conv_arch)

x = torch.randn(size=(1,1,224,224))
for blk in net:
    x = blk(x)
    print(blk.__class__.__name__,'output shape:\t',x.shape)

ratio = 4
small_conv_arch=[(pair[0],pair[1]//ratio) for pair in conv_arch]
net = vgg(small_conv_arch)

lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

 

标签:nn,conv,VGG,网络,channels,pytorch,num,arch,out
From: https://www.cnblogs.com/jinbb/p/17609411.html

相关文章

  • NiN网络——pytorch版
    importtorchfromtorchimportnnfromd2limporttorchasd2ldefnin_block(in_channels,out_channels,kernel_size,strides,padding):returnnn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Co......
  • GoogLeNet网络——pytorch版
    importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFfromd2limporttorchasd2lclassInception(nn.Module):#c1-c4是每条路径的输出通道数def__init__(self,in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(......
  • 步幅与填充——pytorch
    importtorchfromtorchimportnndefcomp_conv2d(conv2d,x):#在维度前面加上通道数和批量大小数1x=x.reshape((1,1)+x.shape)#得到4维y=conv2d(x)#把前面两维去掉returny.reshape(y.shape[2:])#padding填充为1,左右conv2d=nn.Conv2d......
  • 多输入多输出通道——pytorch版
    importtorchfromd2limporttorchasd2lfromtorchimportnn#多输入通道互相关运算defcorr2d_multi_in(x,k):#zip对每个通道配对,返回一个可迭代对象,其中每个元素是一个(x,k)元组,表示一个输入通道和一个卷积核#再做互相关运算returnsum(d2l.corr2d......
  • 池化层——pytorch版
    importtorchfromtorchimportnnfromd2limporttorchasd2l#实现池化层的正向传播defpool2d(x,pool_size,mode='max'):#获取窗口大小p_h,p_w=pool_size#获取偏移量y=torch.zeros((x.shape[0]-p_h+1,x.shape[1]-p_w+1))foriinrange(y.sh......
  • LeNet卷积神经网络——pytorch版
    importtorchfromtorchimportnnfromd2limporttorchasd2lclassReshape(torch.nn.Module):defforward(self,x):#批量大小默认,输出通道为1returnx.view(-1,1,28,28)net=torch.nn.Sequential(#28+4-5+1=28输出通道为6Reshape()......
  • 前端黑魔法 —— 隐藏网络请求的调用栈
    前言浏览器网络控制台会记录每个请求的调用栈(Initiator/启动器),可协助调试者定位到发起请求的代码位置。为了不让破解者轻易分析程序,能否隐藏请求的调用栈?事件回调事实上,使用之前《如何让JS代码不可断点》文中的方案,通过「内置回调」到「原生函数」,即可隐藏请求的调用栈:......
  • 网络流学习笔记
    目录网络流介绍1.1一些概念1.2网络流整体思路EK算法dinic算法当前弧优化求二分图最大匹配费用流1.网络流介绍1.1一些概念网络流可以抽象为:你有一个自来水厂和很多输水管,和一个目标点,每一个输水管都有一个流量的限制。现在要将水从自来水厂运到输水......
  • 回顾网络基础
    OSI和TCP/IP是很基础但又非常重要的知识,很多知识点都是以它们为基础去串联的,作为底层,掌握得越透彻,理解上层时会越顺畅。今天这篇网络基础科普,就是根据OSI层级去逐一展开的。01计算机网络基础01 计算机网络的分类按照网络的作用范围:广域网(WAN)、城域网(MAN)、局域网(LAN);按照网络使用者:......
  • 002-深度学习数学基础(神经网络、梯度下降、损失函数)
    0.前言人工智能可以归结于一句话:针对特定的任务,找出合适的数学表达式,然后一直优化表达式,直到这个表达式可以用来预测未来。针对特定的任务:首先我们需要知道的是,人工智能其实就是为了让计算机看起来像人一样智能,为什么这么说呢?举一个人工智能的例子:我们人看到一个动物的图片,就......