首页 > 其他分享 >机器学习中的神经网络重难点!纯干货(下篇)

机器学习中的神经网络重难点!纯干货(下篇)

时间:2024-06-17 09:30:42浏览次数:14  
标签:下篇 模型 生成器 batch char 重难点 干货 np model

上篇文章地址:机器学习中的神经网络重难点!纯干货(上篇)-CSDN博客

目录

长短时记忆网络

基本原理

一个示例

自注意力模型

基本原理

自注意力机制

具体步骤

一个案例

生成对抗网络

基本原理

一个案例

长短时记忆网络

LSTM就像一个有记忆的人,可以记住重要的信息并且忘记不重要的。

特别擅长处理长序列数据,因为它可以在很长的序列中捕捉和保持关键信息,而不会被无关信息淹没。

基本原理

LSTM的核心思想是细胞状态(cell state)和门(gates)的概念。细胞状态就像LSTM的记忆,它可以传递信息并在必要时保留或删除。

门用于控制信息的流动,包括遗忘不必要的信息和记住重要的信息。

LSTM的工作过程分为三个主要步骤:遗忘、存储和更新。

1、遗忘:细胞状态决定哪些信息应该被遗忘,哪些信息应该保留。门控制着遗忘的过程。

2、存储:新的信息被添加到细胞状态中,以更新记忆。门还可以控制信息的存储。

3、更新:基于当前的输入和细胞状态,LSTM生成新的输出和细胞状态,这将成为下一个时间步的输入。

一个示例

使用Python和TensorFlow来构建一个LSTM模型,并将其应用于文本生成。


import tensorflow as tf
import numpy as np

# 创建示例数据
text = "这是一个示例文本。LSTM将学会预测下一个字符。"
chars = sorted(list(set(text)))
char_to_index = {char: index for index, char in enumerate(chars)}
index_to_char = {index: char for index, char in enumerate(chars)}

# 准备数据
max_sequence_length = 100
sequences = []
next_chars = []
for i in range(0, len(text) - max_sequence_length, 1):
    sequences.append(text[i:i + max_sequence_length])
    next_chars.append(text[i + max_sequence_length])

X = np.zeros((len(sequences), max_sequence_length, len(chars)), dtype=np.bool)
y = np.zeros((len(sequences), len(chars)), dtype=np.bool)
for i, sequence in enumerate(sequences):
    for t, char in enumerate(sequence):
        X[i, t, char_to_index[char]] = 1
    y[i, char_to_index[next_chars[i]]] = 1

# 构建LSTM模型
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(128, input_shape=(max_sequence_length, len(chars))),
    tf.keras.layers.Dense(len(chars), activation='softmax')
])

# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy')

# 训练模型
model.fit(X, y, epochs=100)

# 生成文本
seed_text = "这是一个示例文本。LSTM将学会预测下一个字符。"
generated_text = seed_text
for i in range(100):
    x = np.zeros((1, max_sequence_length, len(chars)))
    for t, char in enumerate(generated_text[-max_sequence_length:]):
        x[0, t, char_to_index[char]] = 1
    predicted_char = index_to_char[np.argmax(model.predict(x, verbose=0)]
    generated_text += predicted_char

代码中实现的是构建了一个LSTM模型,用于文本生成。模型学会了根据前面的文本生成后续文本,展示了LSTM如何捕捉文本中的序列信息。

自注意力模型

Transformer是一种能够理解文本和序列数据的神经网络模型。它的独特之处在于使用了自注意力机制,这意味着它能够同时关注输入数据中的不同部分,而不像传统的循环神经网络(RNN)或卷积神经网络(CNN)那样依赖于固定窗口大小或序列顺序。

Transformer的核心思想是将输入数据分为不同的“词嵌入”(word embeddings),然后使用自注意力机制来决定这些词嵌入之间的关联程度。这种方式使得模型可以处理长文本并捕捉到不同单词之间的复杂关系。

基本原理

自注意力机制

自注意力机制是Transformer的核心。它的思想是计算输入序列中每个位置对其他位置的重要性。这个重要性是通过计算一个权重值的方式来实现的,而这个权重值是根据输入的相似性来决定的。重要的是,这种计算是基于输入数据本身完成的,因此不受序列长度的限制。

具体步骤

1、嵌入层(Embedding Layer):将输入的文本序列转化为向量形式,每个词对应一个向量。这些向量被训练成具有语义信息的表示。

2、自注意力计算:对于每个词,计算它与所有其他词的相似性得分,然后将这些分数作为权重来加权其他词的嵌入向量。这个过程允许模型更关注与当前词相关的词。

3、多头自注意力:Transformer可以通过多个自注意力头来捕捉不同层次的关系,每个头都会生成一组权重,最后合并它们。

4、残差连接与层归一化:将多头自注意力的输出与输入相加,并应用层归一化,以防止梯度消失或爆炸。

5、前馈神经网络(Feed-Forward Network):对每个位置的向量进行非线性变换,以增强模型的表示能力。

6、编码器和解码器:在机器翻译等任务中,Transformer通常由编码器和解码器组成,它们分别用于处理输入和生成输出。

一个案例

首先,假设我们有一个文本分类任务,需要将文本句子分为正面和负面情感。

代码中,包括创建Transformer模型、输入数据和运行模型的步骤:


import torch
import torch.nn as nn

# 创建一个简化的Transformer模型类
class TransformerModel(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers):
        super(TransformerModel, self).__init__()
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers)
    
    def forward(self, src, tgt):
        output = self.transformer(src, tgt)
        return output

# 创建示例数据
# 假设我们有一个输入序列(src)和一个目标序列(tgt)
# 这里我们使用随机生成的数据,实际应用中应该替换为真实数据
src = torch.rand((10, 3, 512))  # 输入序列,形状:(sequence_length, batch_size, embedding_dim)
tgt = torch.rand((20, 3, 512))  # 目标序列,形状:(sequence_length, batch_size, embedding_dim)

# 初始化模型
model = TransformerModel(512, 8, 6)

# 运行模型
output = model(src, tgt)
print(output.shape)

在上述代码中:

1、我们创建了一个简化的Transformer模型(TransformerModel),它接受输入序列(src)和目标序列(tgt)。

2、我们使用torch.rand函数生成了随机的示例数据,其中src代表输入序列,tgt代表目标序列。这些数据应该由任务提供,实际应用中会更有意义。

3、我们初始化了模型并将输入数据传递给它,最后打印了输出的形状。

要注意的是,这个示例的数据和任务仅用于演示Transformer模型的使用方式。在实际情况下,需要准备适当的数据和任务设置来训练和使用Transformer模型。

生成对抗网络

生成对抗网络核心思想是模拟人类创造事物的方式。

GANs由两个主要部分组成:生成器(Generator)和判别器(Discriminator)。这两个部分之间进行博弈,使生成器逐渐学会创建逼真的数据,而判别器则逐渐变得更擅长区分真假数据。

基本原理

1、生成器(Generator):生成器的任务是接收一个随机噪声向量,然后将其转化为逼真的数据,例如图像。生成器是一个神经网络,通过不断调整其参数,使其生成的数据与真实数据尽可能相似。

2、判别器(Discriminator):判别器的任务是区分生成器生成的数据和真实数据。判别器也是一个神经网络,它会对输入的数据进行评估,输出一个0到1之间的概率值,表示数据的真实程度。

3、对抗训练:生成器和判别器交替进行训练。生成器努力生成更逼真的数据,而判别器努力变得更善于区分真伪。这种博弈过程推动了生成器不断提高生成的数据质量。

1、生成器的损失函数:生成器的目标是最小化,其中是判别器对生成器生成数据的评价。

2、判别器的损失函数:判别器的目标是最小化,其中是判别器对真实数据的评价,是对生成器生成数据的评价。

一个案例

我们以图像生成为例,使用Python和TensorFlow库来演示一个简单的GANs模型。

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.models import Sequential

# 创建生成器和判别器网络
def build_generator():
    model = Sequential()
    model.add(Dense(128, input_dim=100, activation='relu'))
    model.add(Dense(784, activation='sigmoid'))
    model.add(Reshape((28, 28, 1)))
    return model

def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model

# 定义损失函数和优化器,并编译生成器和判别器
generator = build_generator()
discriminator = build_discriminator()

discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.trainable = False

gan_input = tf.keras.Input(shape=(100,))
x = generator(gan_input)
gan_output = discriminator(x)

gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')


# 训练GANs模型
for epoch in range(epochs):
    for _ in range(batch_size):
        noise = np.random.normal(0, 1, [batch_size, 100])
        generated_images = generator.predict(noise)
        image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(image_batch, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        noise = np.random.normal(0, 1, [batch_size, 100])
        discriminator.trainable = False
        g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

# 生成并可视化手写数字图像
num_samples = 16
noise = np.random.normal(0, 1, [num_samples, 100])
generated_images = generator.predict(noise)

for i in range(num_samples):
    plt.subplot(4, 4, i + 1)
    plt.imshow(generated_images[i, :, :, 0] * 0.5 + 0.5, cmap='gray')
    plt.axis('off')

plt.show()

代码展示了一个简单的GANs模型,用于生成手写数字图像。

训练之后,生成器将生成逼真的手写数字图像,判别器将变得更难以区分真假图像。

生成对抗络是用于各种创意领域。它基于生成器和判别器的博弈,通过不断优化生成器来创造逼真的数据。

标签:下篇,模型,生成器,batch,char,重难点,干货,np,model
From: https://blog.csdn.net/CKissjy/article/details/139732105

相关文章

  • 机器学习中的神经网络重难点!纯干货(上篇)
     目录前馈神经网络基本原理公式解释一个示例卷积神经网络基本原理公式解释一个示例循环神经网络基本原理公式解释一个案例长短时记忆网络基本原理公式解释一个示例自注意力模型基本原理自注意力机制具体步骤公式解释一个案例生成对抗网络基本原理公......
  • 好用的库函数,qsort函数大详解(干货满满!)(初阶)
    前言;  我一直在思考今天要写什么类型的文章,看到之前写的冒泡排序的写法,不过冒牌排序的算法只能针对于整型,我们如果想要排序浮点型,字符型的数据呢?这个时候我突然想到了比冒泡排序还好用的一个库函数,就是我们今天的主角——qsort函数,下面不多废话,直接进入正文: 目录:1.qsor......
  • 史上最详细的轨迹优化教程-机器人避障及轨迹平滑实现(干货满满)
    有一些朋友问我到底如何用优化方法实现轨迹优化(避障+轨迹平滑等),今天就出一个干货满满的教程,绝对是面向很多工业化场景的讲解,为了便于理解,我选用二维平面并给出详细代码实现,三维空间原理相似。本教程禁止转载,主要是有问题可以联系我探讨,我的邮箱[email protected]下面......
  • 干货分享,数字化校园整体解决方案
     数字化校园囊括了校园事务的各个方面,同时, 数字化校园又是所有相关子系统的数据输出与枢纽。可以看出, 数字化校园是一个大而全的系统。鉴于此, 数字化校园的模块众多,本文将 数字化校园的所有模块做出大致梳理,以完整支持 数字化校园平台建设。基础平台 基于在信息化......
  • PMP考前集训干货总结
    一、口诀1、谋定而后动:发现问题——>分析问题——>解决问题2、遇到问题,先记录——>讨论分析——>找解决方案3、获资源,优先谈判——>找领导——>招募4、遇采购索赔,优先谈判——>ADR(调解、仲裁)——>法院5、人的问题找沟通:与干系人见面、接触、开会、讨论、达成一致6、凡......
  • 干货分享!渗透测试成功的8个关键
     01知道为什么要测试执行渗透测试的目的是什么?是满足审计要求?是你需要知道某个新应用在现实世界中表现如何?你最近换了安全基础设施中某个重要组件而需要知道它是否有效?或者渗透测试根本就是作为你定期检查防御健康的一项例行公事?当你清楚做测试的原因时,你也就知晓自己想......
  • 【无量化,无管理】指标体系建设方案(36页PPT),干货满满
    引言:现代管理学之父彼得·德鲁克曾经说过:“无量化,无管理”、以及“先量化,后决策”,指明了量化管理在企业经营及决策中的意义;其中量化管理的依据就是经营管理指标。在实际中指标很多,如财务指标、经营指标、绩效指标、人力指标……据统计,一个小型企业有上百个指标,而中、大型企业......
  • [干货!必看文章]学会如何用L4级AI软件开发平台免费制作应用程序
    前言:  自从ChatGPT问世以来,就掀起了全球AI大模型的浪潮。国外有Claude,Llama,Grok,Suno,国内有kimi,有智谱AI,有通义千问,还有文心一言...国内大模型市场规模已经达到了216亿,在2028年预估将达到1179亿。显而易见AI已然是当前最火爆的行业。因此,懂AI,会用AI已经成为了一项必备的技......
  • 干货:TikTok限流、0播放问题怎么解决?
    针对TikTok运营中常见的流量问题,如疑似限流、零播放等,本文将为大家提供详尽的解答和策略。Q1:为何视频播放量始终难以突破几百?A1:视频播放量低,首要考虑的并非限流。内容质量是关键,建议检查视频内容是否吸引人且符合TikTok的规范。同时,网络环境稳定性也很重要,非常建议使用原生......
  • 【旅行使身体和灵魂都在路上】智慧旅游解决方案集合推荐,干货满满!
    引言:2024年端午节假期,全国文化和旅游市场总体平稳有序。据测算,全国国内旅游出游合计1.1亿人次,同比增长6.3%;国内游客出游总花费403.5亿元,同比增长8.1%。假期中,群众赛龙舟、吃粽子、唱山歌、赏古曲,传统节日文化内涵与旅游发展深度融合。广东、湖南、浙江、贵州、云南等地......