首页 > 其他分享 >CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。

CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。

时间:2023-08-04 21:36:58浏览次数:34  
标签:filter 图像识别 network conv 卷积 2d size

官方参数解释:

Convolution 2D

tflearn.layers.conv.conv_2d (incoming, nb_filter, filter_size, strides=1, padding='same', activation='linear', bias=True, weights_init='uniform_scaling', bias_init='zeros', regularizer=None, weight_decay=0.001, trainable=True, restore=True, reuse=False, scope=None, name='Conv2D')

Input

4-D Tensor [batch, height, width, in_channels].

Output

4-D Tensor [batch, new height, new width, nb_filter].

Arguments

  • incoming: Tensor. Incoming 4-D Tensor.
  • nb_filter: int. The number of convolutional filters.
  • filter_size: int or list of int. Size of filters.
  • strides: 'intor list ofint`. Strides of conv operation. Default: [1 1 1 1].
  • padding: str from "same", "valid". Padding algo to use. Default: 'same'.
  • activation: str (name) or function (returning a Tensor) or None. Activation applied to this layer (see tflearn.activations). Default: 'linear'.
  • bias: bool. If True, a bias is used.
  • weights_init: str (name) or Tensor. Weights initialization. (see tflearn.initializations) Default: 'truncated_normal'.
  • bias_init: str (name) or Tensor. Bias initialization. (see tflearn.initializations) Default: 'zeros'.
  • regularizer: str (name) or Tensor. Add a regularizer to this layer weights (see tflearn.regularizers). Default: None.
  • weight_decay: float. Regularizer decay parameter. Default: 0.001.
  • trainable: bool. If True, weights will be trainable.
  • restore: bool. If True, this layer weights will be restored when loading a model.
  • reuse: bool. If True and 'scope' is provided, this layer variables will be reused (shared).
  • scope: str. Define this layer scope (optional). A scope can be used to share variables between layers. Note that scope will override name.
  • name: A name for this layer (optional). Default: 'Conv2D'.

 

代码:

# 64 filters


net = tflearn.conv_2d(net, 64, 3, activation='relu')

CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。_机器学习

如果一个卷积层有4个feature map,那是不是就有4个卷积核?
是的。

CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。_卷积核_02

这4个卷积核如何定义?
通常是随机初始化再用BP算梯度做训练。如果数据少或者没有labeled data的话也可以考虑用K-means的K个中心点,逐层做初始化。
卷积核是学习的。卷积核是因为权重的作用方式跟卷积一样,所以叫卷积层,其实你还是可以把它看成是一个parameter layer,需要更新的。

这四个卷积核就属于网络的参数,然后通过BP进行训练

整个网络的训练,主要就是为了学那个卷积核啊。

先初始化一个,之后BP调整,你可以去看看caffe的源码。

--------------------------------------------------------------------------------------------------
下面内容摘自:
1. from __future__ import division, print_function, absolute_import  
2.   
3. import tflearn  
4. from tflearn.layers.core import input_data, dropout, fully_connected  
5. from tflearn.layers.conv import conv_2d, max_pool_2d  
6. from tflearn.layers.normalization import local_response_normalization  
7. from tflearn.layers.estimator import regression  
8. #加载大名顶顶的mnist数据集(http://yann.lecun.com/exdb/mnist/)  
9. import tflearn.datasets.mnist as mnist  
10. X, Y, testX, testY = mnist.load_data(one_hot=True)  
11. X = X.reshape([-1, 28, 28, 1])  
12. testX = testX.reshape([-1, 28, 28, 1])  
13.   
14. network = input_data(shape=[None, 28, 28, 1], name='input')  
15. # CNN中的卷积操作,下面会有详细解释  
16. network = conv_2d(network, 32, 3, activation='relu', regularizer="L2")  
17. # 最大池化操作  
18. network = max_pool_2d(network, 2)  
19. # 局部响应归一化操作  
20. network = local_response_normalization(network)  
21. network = conv_2d(network, 64, 3, activation='relu', regularizer="L2")  
22. network = max_pool_2d(network, 2)  
23. network = local_response_normalization(network)  
24. # 全连接操作  
25. network = fully_connected(network, 128, activation='tanh')  
26. # dropout操作  
27. network = dropout(network, 0.8)  
28. network = fully_connected(network, 256, activation='tanh')  
29. network = dropout(network, 0.8)  
30. network = fully_connected(network, 10, activation='softmax')  
31. # 回归操作  
32. network = regression(network, optimizer='adam', learning_rate=0.01,  
33. 'categorical_crossentropy', name='target')  
34.   
35. # Training  
36. # DNN操作,构建深度神经网络  
37. model = tflearn.DNN(network, tensorboard_verbose=0)  
38. model.fit({'input': X}, {'target': Y}, n_epoch=20,  
39. 'input': testX}, {'target': testY}),  
40. 100, show_metric=True, run_id='convnet_mnist')

关于conv_2d函数,在源码里是可以看到总共有14个参数,分别如下:

1.incoming: 输入的张量,形式是[batch, height, width, in_channels]
2.nb_filter: filter的个数
3.filter_size: filter的尺寸,是int类型
4.strides: 卷积操作的步长,默认是[1,1,1,1]
5.padding: padding操作时标志位,"same"或者"valid",默认是“same”
6.activation: 激活函数(ps:这里需要了解的知识很多,会单独讲)
7.bias: bool量,如果True,就是使用bias
8.weights_init: 权重的初始化
9.bias_init: bias的初始化,默认是0,比如众所周知的线性函数y=wx+b,其中的w就相当于weights,b就是bias
10.regularizer: 正则项(这里需要讲解的东西非常多,会单独讲)
11.weight_decay: 权重下降的学习率
12.trainable: bool量,是否可以被训练
13.restore: bool量,训练的模型是否被保存
14.name: 卷积层的名称,默认是"Conv2D"
关于max_pool_2d函数,在源码里有5个参数,分别如下:1.incoming ,类似于conv_2d里的incoming2.kernel_size:池化时核的大小,相当于conv_2d时的filter的尺寸3.strides:类似于conv_2d里的strides4.padding:同上5.name:同上看了这么多参数,好像有些迷糊,我先用一张图解释下每个参数的意义。其中的filter就是[1 0 10 1 01 0 1],size=3,由于每次移动filter都是一个格子,所以strides=1.关于最大池化可以看看下面这张图,这里面 strides=1,kernel_size =2(就是每个颜色块的大小),图中示意的最大池化(可以提取出显著信息,比如在进行文本分析时可以提取一句话里的关键字,以及图像处理中显著颜色,纹理等),关于池化这里多说一句,有时需要平均池化,有时需要最小池化。下面说说其中的padding操作,做图像处理的人对于这个操作应该不会陌生,说白了,就是填充。比如你对图像做卷积操作,比如你用的3×3的卷积核,在进行边上操作时,会发现卷积核已经超过原图像,这时需要把原图像进行扩大,扩大出来的就是填充,基本都填充0。




Convolution Demo. Below is a running demo of a CONV layer. Since 3D volumes are hard to visualize, all the volumes (the input volume (in blue), the weight volumes (in red), the output volume (in green)) are visualized with each depth slice stacked in rows. The input volume is of size W1=5,H1=5,D1=3">W1=5,H1=5,D1=3">W1=5,H1=5,D1=3, and the CONV layer parameters are K=2,F=3,S=2,P=1">K=2,F=3,S=2,P=1. That is, we have two filters of size 3×3">3×3, and they are applied with a stride of 2. Therefore, the output volume size has spatial size (5 - 3 + 2)/2 + 1 = 3. Moreover, notice that a padding of P=1">P=1

CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。_2d_03

 

General pooling. In addition to max pooling, the pooling units can also perform other functions, such as average pooling or even L2-norm pooling. Average pooling was often used historically but has recently fallen out of favor compared to the max pooling operation, which has been shown to work better in practice.

CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。_ide_04

CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。_tensorflow_05

Pooling layer downsamples the volume spatially, independently in each depth slice of the input volume. Left: In this example, the input volume of size [224x224x64] is pooled with filter size 2, stride 2 into output volume of size [112x112x64]. Notice that the volume depth is preserved. Right: The most common downsampling operation is max, giving rise to max pooling, here shown with a stride of 2. That is, each max is taken over 4 numbers (little 2x2 square).

标签:filter,图像识别,network,conv,卷积,2d,size
From: https://blog.51cto.com/u_11908275/6968481

相关文章

  • ICCV论文速读:SOTA!越简单,越强大!ByteTrackV2-通用2D、3D跟踪算法(开源)
    前言 本文提出了一个分层的数据关联策略来寻找低分检测框中的真实目标,这缓解了目标丢失和轨迹不连续的问题。这个简单通用的数据关联策略在2D和3D设置下都表现良好。另外,由于在3D场景中预测对象在世界坐标系中的速度比较容易,本文提出了一种辅助的运动预测策略,将检测到的速度与卡......
  • 使用 Spring 3 MVC HttpMessageConverter 功能构建 RESTful web 服务(转)
    Spring,构建Java™平台和EnterpriseEdition(JavaEE)应用程序的著名框架,现在在其模型-视图-控制器(Model-View-Controller,MVC)层支持具象状态传输(REST)。RESTfulweb服务根据客户端请求生成多个具象(representations)很重要。在本篇文章中,学习使用HttpMessageConverter 生成......
  • Spring源码分析(五) MappingJackson2HttpMessageConverter
    大家用过springmvc的肯定都用过@RequestBody和@ResponseBody注解吧,你了解这个的原理吗?这篇文章我们就来说下它是怎么实现json转换的。首先来看一个类RequestResponseBodyMethodProcessor,这个类继承了AbstractMessageConverterMethodProcessor,我们来看看这个类的构造方法protec......
  • go语言基础-strings和strconv包
    作为一种基本数据结构,每种语言都有一些对于字符串的预定义处理函数。Go中使用 strings 包来完成对字符串的主要操作。前缀和后缀HasPrefix() 判断字符串 s 是否以 prefix 开头:strings.HasPrefix(s,prefixstring)boolHasSuffix() 判断字符串 s 是否以 suffix......
  • Python 优化第一步: 性能分析实践 使用cporfile+gprof2dot可视化
    拿来主义:python-mcProfile-oprofile.pstatsto_profile.pygprof2dot-fpstatsprofile.pstats|dot-Tpng-oclick.png然后顺着浅色线条优化就OK了。 windows下:google下graphviz-2.38.msi,然后安装。dot命令需要。gitclone https://github.com/jrfonseca/gprof2dot.git......
  • 防干扰/抗电压波动双按键/2路触摸触控芯片VK36N2D SOP8 适用于厨房秤/温控器/加湿器等
    概述:VK36N2DSOP8具有2个触摸按键,可用来检测外部触摸按键上人手的触摸动作。该芯片具有较高的集成度,仅需极少的外部组件便可实现触摸按键的检测。提供了2个1对1输出脚,可通过IO脚选择上电输出电平,有直接输出和锁存输出2个型号可选。芯片内部采用特殊的集成电路,具有高电源电压抑制比......
  • 图像识别技术在医疗领域的应用与前景展望
    导言:图像识别技术在医疗领域中正发挥着越来越重要的作用。随着计算机视觉和深度学习技术的发展,图像识别已成为医学影像分析、疾病诊断和治疗方案制定的有力工具。本文将介绍图像识别技术在医疗领域的应用,并展望其未来在医学健康领域的发展前景。一、图像识别技术在医疗领域的应......
  • MappingJackson2HttpMessageConverter数据处理
    主键用的雪花算法,值域超过了js的范围……后端返回的日期字段总不是我想要的格式……空值的字段就不要返回了,省点流量吧……试试换成自己的MappingJackson2HttpMessageConverter呗Talkischeap,showyouthecode!importcom.fasterxml.jackson.annotation.JsonInclude;importco......
  • IBM ThinkPad T400 windows Vista sp1 官方恢复光盘(1CD+2DV
    http://www.nbbbs.com.cn/bbs/thread-12226-1-1.html IBMThinkPadT400windowsVistasp1官方恢复光盘(1CD+2DVD)下载1CD+2DVD版VISTA官方的恢复碟的纳米盘下载地址T400VISTABOOTCD下载地扯:T400vistaboot.nrgT400VISTA1DVDT400vista1.nrgT400VISTA2DVDT400vista2......
  • 浅谈-HttpMessageConverter接口
    HttpMessageConverter接口是SpringFramework中的一个接口,用于处理HTTP请求和响应体的消息转换。解释如下:在SpringWeb应用中,控制器(Controller)处理HTTP请求时,通常会返回响应结果给客户端。这些响应结果可以是Java对象、字符串、JSON数据、XML数据等。HttpMess......