首页 > 其他分享 >深度学习深入浅出

深度学习深入浅出

时间:2023-04-09 10:11:06浏览次数:33  
标签:模型 深入浅出 学习 test train 深度 images

目录
深度学习是机器学习的一个分支,其核心思想是利用深层神经网络对数据进行建模和学习,从而实现识别、分类、预测等任务。在过去几年中,深度学习技术取得了许多突破性的成果,如在图像识别、语音识别、自然语言处理、游戏AI等领域中。

本文将简要介绍深度学习的基本原理,并使用Python中的TensorFlow库演示如何实现一个简单的神经网络模型。

一 基本原理

深度学习中最基本的模型是神经网络(Neural Network),它的结构模仿了人类的神经系统,包含多个层级(Layer)。

神经网络的基本组成单元是神经元(Neuron),每个神经元接收多个输入,经过加权和与偏置项相加后通过一个激活函数(Activation Function)输出。

多个神经元可以组成一个层级,不同层级之间的神经元可以进行连接,形成一个完整的神经网络。

深度学习中的深度(Depth)指的是神经网络的层数,一般来说,层数越多,网络的表达能力越强。

训练神经网络需要使用反向传播算法(Backpropagation),通过反向传播误差信号,更新神经网络中的参数(Weight)和偏置项(Bias),使得模型的输出更加接近于真实值。

深度学习中最常用的神经网络结构是多层感知机(Multilayer Perceptron,MLP),它是由多层神经元组成的网络,每层之间相互连接,其中输入层接收数据,输出层输出结果,中间的隐藏层则对输入数据进行非线性变换和特征提取。MLP的训练过程通常使用反向传播算法(Backpropagation,BP)进行参数优化。

二 深度学习的优点

  1. 可以自主地学习和提取特征

深度学习的一个最大优点是可以自主地学习和提取数据中的特征。相比于传统机器学习方法,需要人工提取特征,深度学习可以自动提取最相关的特征。这使得深度学习在许多领域取得了巨大的成功,如图像识别、自然语言处理等。

  1. 可以处理大规模数据

深度学习可以处理大规模数据,并且随着数据规模的增加,深度学习的表现也会变得更好。这使得深度学习在许多领域都具有非常广泛的应用,如语音识别、自然语言处理、图像识别等。

  1. 可以处理非线性关系

传统的机器学习算法通常只能处理线性关系,但深度学习可以处理非线性关系。这使得深度学习在许多领域都有很好的表现,如图像识别、语音识别等。

  1. 可以进行端到端的学习

深度学习可以进行端到端的学习,即从输入数据到输出结果的整个过程都可以通过深度学习来完成。这使得深度学习非常适合处理一些复杂的任务,如自然语言处理、语音识别等。

三 深度学习的缺点

  1. 数据要求高

深度学习的模型需要大量的数据进行训练,而且数据的质量也需要较高。如果数据的质量不高,比如包含较多的噪声或错误,那么深度学习的效果将会受到很大的影响。此外,深度学习对数据的标注要求也较高,标注不准确的数据可能会影响模型的学习效果。

  1. 计算资源要求高

深度学习的模型通常需要进行大量的计算,因此需要较高的计算资源。在传统的CPU上训练深度学习模型往往非常缓慢,因此需要使用GPU或者TPU等硬件加速器来加快训练速度。此外,训练深度学习模型所需要的存储资源也非常大,因此需要较高的存储容量。

  1. 模型过于复杂

深度学习的模型通常非常复杂,包含大量的参数和层数,因此很难理解其内部的工作原理。这使得深度学习模型的可解释性较低,难以分析和调试。此外,过于复杂的模型也容易过拟合,导致在新数据上的表现不佳。

  1. 对人类知识的依赖较低

深度学习可以自主地提取数据中的特征,从而免去了手动特征提取的繁琐过程。然而,这也使得深度学习模型对人类知识的依赖较低。这意味着深度学习可能会忽略一些重要的特征,因为这些特征在数据中并不明显。同时,深度学习也容易受到数据集本身的偏差影响,从而导致模型的预测结果不准确。

四 深度学习应用

深度学习可以应用于各种领域,比如图像识别、自然语言处理、语音识别等。在图像识别领域,深度学习可以用来识别图像中的物体,从而帮助计算机自主地理解图像内容。在自然语言处理领域,深度学习可以用来自动翻译、问答、文本生成等任务。在语音识别领域,深度学习可以用来识别人的语音指令,从而帮助人们更方便地与计算机进行交互。

手写数字识别

TensorFlow是由Google开发的一个开源机器学习库,可以用于各种机器学习任务,包括深度学习。它的核心是一个图(Graph)计算模型,用户可以使用TensorFlow构建图中的节点(Node)和边(Edge),并执行计算。

在TensorFlow中,神经网络模型是通过一系列的层级(Layer)组成的。每个层级包含多个神经元(Neuron),每个神经元的输出通过一个激活函数(Activation Function)进行变换。TensorFlow提供了多种常用的激活函数,如sigmoid、ReLU、tanh等。

手写数字识别是深度学习中的一个经典问题,它要求识别0-9十个数字的手写图像。在本文中,我们将使用MNIST数据集,它包含了一系列已经被标记过的手写数字图像,每个图像的大小为28x28像素。

首先,我们需要导入必要的库:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

然后,我们需要加载手写数字数据集MNIST,并对数据进行预处理:

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

train_images = train_images.reshape((-1, 784))
test_images = test_images.reshape((-1, 784))

接下来,我们可以定义我们的神经网络模型:

model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])

这个模型有两个隐藏层,第一个隐藏层有128个神经元,使用ReLU激活函数,第二个隐藏层使用Dropout来避免过拟合,输出层有10个神经元,使用softmax激活函数。

接下来,我们需要编译模型,并训练它:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(train_images, train_labels, epochs=10, batch_size=64,
                    validation_data=(test_images, test_labels))

我们使用Adam优化器,稀疏交叉熵损失函数和准确率作为评价指标进行模型编译。然后,我们使用fit方法来训练模型,将训练集和测试集传递给模型,并设置10个epochs和64个batch size。

最后,我们可以使用训练好的模型来对手写数字进行预测:

predictions = model.predict(test_images)

print(np.argmax(predictions[:10], axis=1))
print(test_labels[:10])

我们使用predict方法来对测试集进行预测,并使用argmax函数找到预测结果中最大值的索引,作为预测的类别。最后,我们打印前10个预测结果和它们对应的真实标签。

完整代码:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

train_images = train_images.reshape((-1, 784))
test_images = test_images.reshape((-1, 784))

model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(train_images, train_labels, epochs=10, batch_size=64,
                    validation_data=(test_images, test_labels))
                    
predictions = model.predict(test_images)

print(np.argmax(predictions[:10], axis=1))
print(test_labels[:10])

标签:模型,深入浅出,学习,test,train,深度,images
From: https://www.cnblogs.com/qi66/p/17299891.html

相关文章

  • Markdown学习
    Markdown学习标题+(空格)+标题名字字体Helloworld!(粗体)Helloworld!(斜体)Helloworld!(斜体加粗)Helloworld!(删除线)引用掌控自己,一定要坚持分割线图片超链接[点击跳转到枫的博客](枫亦穰穰-博客园(cnblogs.com))列表abcab表格......
  • 算法学习之选择排序【C语言】
    选择排序排序规则选择排序是一种简单直观的排序算法,其基本思想是每次从待排序的数据元素中选出最小(或最大)的一个元素,存放到序列的起始位置,直到全部元素排序完成。具体步骤如下:1.从第一个数开始,与其后的数一一比较,如后小前大,则交换,依次比较直至最后一组数。2.通过上述步骤,得到参加循......
  • Unity框架:JKFrame2.0学习笔记(二)——Singleton单例模式
    Singleton单例模式的基类,不用mono的类可以直接继承源码namespaceJKFrame{///<summary>///单例模式的基类///</summary>publicabstractclassSingleton<T>whereT:Singleton<T>,new(){privatestaticTinstance;public......
  • 学习率lr下降错误问题
    在更新学习率的部分有这样一个函数get_lr()有很大的bug:get_last_lr()才表示当前的学习率,使用get_lr()会衰减两次!!红色部分是get_last_lr()打印的;白色部分是get_lr()打印的:可以看到错误的写法确实会在节点处衰减两次  这是由于step()时会调一次get_lr(),为了得到lr又掉用一次g......
  • CS231N assignment 1 _ softmax 学习笔记 & 解析
    [注意:考虑到这个和SVM重复很多,所以会一笔带过/省略一些]softmax和SVM只是线性分类器分类结果的评判不同,完全依靠打分最大来评判结果,误差就是希望结果尽可能接近正确分类值远大于其他值.我们将打分结果按照指数权重正则化为和为1的向量:而这个值希望尽可能接近1,也就是-l......
  • 【MySQL】MySQL基础07— SQL学习 — DQL — 分组查询(转载请注明出处)
    SQL学习—DQL—分组查询5.分组查询背景:在分组函数的内容中,我们提及和分组函数一起查询的字段会有限制,产生错误。因为分组函数是将所以的参数统计成一个结果,而查询的字段是返回符合条件的个数,那么就会出错。所以引入了分组查询,将表中的相同的内容切分成数块,然后分别进行统......
  • 随机森林算法深入浅出
    目录一随机森林算法的基本原理二随机森林算法的优点1.随机森林算法具有很高的准确性和鲁棒性2.随机森林算法可以有效地避免过拟合问题3.随机森林算法可以处理高维度数据4.随机森林算法可以评估特征的重要性三随机森林算法的缺点1.随机森林算法对于少量数据集表现不佳2.随......
  • go语言学习-冒泡排序
    冒泡排序冒泡排序属于交换类的排序算法,比如有一段乱序的数,591681464925463第一轮迭代:从第一个数开始,依次比较相邻的两个数,如果后面的一个数比前面的一个数大,那么交换位置,直接到处理最后一个数,最后这个数是最大的第二轮迭代,因为最后一个数已经是最大的了,重复第一轮操作,......
  • SpringCloud源码学习笔记3——Nacos服务注册源码分析
    系列文章目录和关于我一丶基本概念&Nacos架构1.为什么需要注册中心实现服务治理、服务动态扩容,以及调用时能有负载均衡的效果。如果我们将服务提供方的ip地址配置在服务消费方的配置文件中,当服务提供方实例上线下线,消费方都需要重启服务,导致二者耦合度过高。注册中心就是在......
  • 【MySQL】MySQL基础05 — SQL学习 — DQL — 常见函数 — 分组函数(转载请注明出处)
    SQL学习—DQL—常见函数—分组函数4.常见函数(附加)/*概念:类似于java的方法,将一组逻辑语句封装在方法体中,对外暴露方法名。好处:1.隐藏了实现细节2.提高代码的重用性调用语法:select函数名(实参列表)【from表】;特点: 1.叫什么(函数名) 2.干什么(函数功能)分类: 1.单......