首页 > 其他分享 >机器学习笔记:序列到序列学习[详细解释]

机器学习笔记:序列到序列学习[详细解释]

时间:2024-08-10 11:58:04浏览次数:12  
标签:编码器 解码器 self 笔记 学习 state num 序列 size

介绍

本节我们使用两个循环神经网络的编码器和解码器, 并将其应用于序列到序列(sequence to sequence,seq2seq)类的学习任务。遵循编码器-解码器架构的设计原则, 循环神经网络编码器使用长度可变的序列作为输入, 将其转换为固定形状的隐状态。 换言之,输入序列的信息被编码到循环神经网络编码器的隐状态中。

结构

首先,我们使用上一节提到的编码器-解码器结构,其中编码器使用一个双隐层的门控循环单元构成的循环神经网络(链接均为我之前发布的博客笔记,seq2seq是基于之前这几节的内容的)。而解码器使用一个双隐层的门控循环单元构成的循环神经网络,后接一个全连接层。

编码器作用

编码器通过循环神经网络,将每个时间步的输入X和上一时间步的隐藏状态进行处理生成下一时间步的隐状态,即H_t=f(X,H_{t-1})。之后再通过编码操作把每个时间步的隐状态转化为上下文变量c,即c=q(h_1,h_2,...,h_T)

解码器作用

解码器通过先前的输出序列和上下文变量c共同决定当前时间步输出,概率为P(y_t|y_1,y_2,...,y_{t-1},c),解码器隐状态的更新操作为s_t=g(s_{t-1},y_{t-1},c)

代码实现

引入库

import collections
import math
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()

编码器的实现

对编码器的代码进行解释:在对数据集进行处理后,形成的是一个三维的数据,size=(batch_size,num_steps,vocab_size),经过嵌入层的处理后,转换为size=(batch_size,num_steps,embed_size)。之后对第一维度和第二维度进行交换,使得size=(num_steps,batch_size,embed_size),在之前的RNN中我们已经知道第一维度应该是时间步数(这样RNN可以沿着时间步继续走下去),通过维度转换,使得第一维度成为时间步,需要注意的是此时不能使用X.T直接进行转置,因为我们只希望转换前两个维度,并不希望对整个矩阵进行变换。接下来的操作与RNN一致,获得RNN的输出与状态(需要注意的是,此时的RNN是一个GRU)。

class Seq2SeqEncoder(d2l.Encoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        # 输出'X'的形状:(batch_size,num_steps,embed_size)
        X = self.embedding(X)
        # 在循环神经网络模型中,第一个轴对应于时间步
        X = X.swapaxes(0, 1)
        state = self.rnn.begin_state(batch_size=X.shape[1], ctx=X.ctx)
        output, state = self.rnn(X, state)
        # output的形状:(num_steps,batch_size,num_hiddens)
        # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state

解码器的实现

上下文变量c与输入y_t进行拼接(concatenate)操作,使得每一个时间步读取对应的上下文变量和输入。解码器使用一个全连接层进行Softmax运算产生输出。

需要注意的是,在编码器的返回变量中,output和state共同构成一个元组。在init_state()中,通过对编码器的输出索引第二个元素得到所需要的状态state。在前向计算中,获得最后一个层的状态,作为上下文变量,将上下文变量context与输入进行连接,一起输入到门控循环单元GRU中。

class Seq2SeqDecoder(d2l.Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Dense(vocab_size, flatten=False)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def forward(self, X, state):
        # 输出'X'的形状:(batch_size,num_steps,embed_size)
        X = self.embedding(X).swapaxes(0, 1)
        # context的形状:(batch_size,num_hiddens)
        context = state[0][-1]
        # 广播context,使其具有与X相同的num_steps
        context = np.broadcast_to(context, (
            X.shape[0], context.shape[0], context.shape[1]))
        X_and_context = np.concatenate((X, context), 2)
        output, state = self.rnn(X_and_context, state)
        output = self.dense(output).swapaxes(0, 1)
        # output的形状:(batch_size,num_steps,vocab_size)
        # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state

实例化解码器测试输出大小

decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                         num_layers=2)
decoder.initialize()
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
output.shape, len(state), state[0].shape
((4, 7, 10), 1, (2, 4, 16))

疑问

其实我这里有一个地方感到不解,在解码器解码时,之前编码器的state应为一个三维数组,怎么会索引[0][-1]之后出现一个二维数组?难道哪里把它封装成元组了吗?

应该的确是这样,因为根据测试输出大小时输出的state长度为1,说明解码器对state直接进行了包装,使得state[0]才是真正的状态,但我没有想到这是哪个操作进行的。

标签:编码器,解码器,self,笔记,学习,state,num,序列,size
From: https://blog.csdn.net/2301_79335566/article/details/141065641

相关文章

  • 第六周学习报告
    又经过了一周的学习,今天对本周学习进行总结本周学习了Java面向对象进阶内容抽象类和抽象方法抽象类使用abstract关键字声明的类被称为抽象类。抽象类不能被实例化。抽象方法使用abstract关键字声明的方法被称为抽象方法。抽象方法没有方法体,即大括号{}内没有代码实现。抽象......
  • 8.10第四周周六学习总结
    1vj团队12补题不错的一个题解https://blog.fishze.com/archives/3011)字符串变化(模拟+找规律)题目:给定一个字符串,给定一个特定操作方式:该字符串前半段+该字符串自己+该字符串后半段求next(每一个字符向后移动一个),组成一个新字符串,求经过10^100次这样的操作后,......
  • 2-SAT 学习笔记
    2-SAT用于求解布尔方程组,其中每个方程最多含有两个变量,方程的形式为\((a∨b)=1\),即式子\(a\)为真或式子\(b\)为真。求解的方法是根据逻辑关系式建图,然后求强联通子图,每一个强联通子图的答案都是一样的。建图:这里以模版题为例:题意:给定若干个需要满足的条件,其形式为\(a,1......
  • 大一暑假学习记录6
    这一周我基本完成了刘立嘉老师布置的暑假作业,其中通讯录的录入与显示,整数分解为若干项之和是我认为最难做的题目,前者的难点是sample有查询越界、最大N,反复查询同一记录等等。后者则是样例等价,多行输出难以解决。于是我又重新学习了结构体部分的内容,定义了Contact结构体来存储......
  • 学生Java学习路程-6
    ok,到了一周一次的总结时刻,我大致会有下面几个方面的论述:1.这周学习了Java的那些东西2.这周遇到了什么苦难3.未来是否需要改进方法等几个方面阐述我的学习路程。复习面向对象数组数组的三种初始化方法:默认,静态,动态引用类型Man放入数组中的测试代码数组的拷贝使用规范使......
  • Lazysysadmin靶机笔记
    Lazysysadmin靶机笔记概述lazysysadmin是一台Vulnhub靶机,整体比较简单,要对一些存在服务弱口令比较敏感。靶机地址:https://pan.baidu.com/s/19nBjhMpGkdBDBFSnMEDfOg?pwd=heyj提取码:heyj一、nmap扫描1、主机发现#-sn只做ping扫描,不做端口扫描sudonmap-sn192.168.247.1......
  • 多元时间序列分析统计学基础:基本概念、VMA、VAR和VARMA
    多元时间序列是一个在大学课程中经常未被提及的话题。但是现实世界的数据通常具有多个维度,所以需要多元时间序列分析技术。在这文章我们将通过可视化和Python实现来学习多元时间序列概念。这里假设读者已经了解单变量时间序列分析。1、什么是多元时间序列?顾名思义,多元时间序列是......
  • 爆火下28万次!MIT最新-理解深度学习
        最近疯传的-理解深度学习-高达28万次,被认为可能。涵盖了深度学习从基础到高级各个方面的内容,包括基本概念、监督学习、强化学习、线性回归、神经网络、扩散模型等等。全面系统地机器学习的基础概念和深度学习的多种模型,还包括最新的Transformer和图神经网络。 免......
  • 谷粒商城实战笔记-145-性能压测-性能监控-jvisualvm使用-解决插件不能安装
    文章目录jvisualvm的作用安装查看gc相关信息的插件解决jvisualvm不能正常安装插件的问题1,查看java版本2,打开网址3,修改jvisualvm的设置jvisualvm的作用JVisualVM是一个集成在JavaDevelopmentKit(JDK)中的多功能工具,它提供了一种可视化的方式来监控和分析Java应用......
  • 【编程笔记】解决移动硬盘无法访问文件或目录损坏且无法读取
    解决移动硬盘无法访问文件或目录损坏且无法读取只解决:移动硬盘无法访问文件或目录损坏且无法读取问题由于频繁下载数据,多次安装虚拟机导致磁盘无法被系统识别。磁盘本身是好的,只是不能被识别,如果将磁盘格式化,就可以正常使用,这样磁盘内数据就丢失了。怎样才能即保留数据......