首页 > 其他分享 >机器翻译之创建Seq2Seq的编码器、解码器

机器翻译之创建Seq2Seq的编码器、解码器

时间:2024-09-18 22:23:48浏览次数:16  
标签:编码器 num outputs self shape 机器翻译 state 解码器 size

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn


#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Module
    def __init__(self, **kwargs):   #**kwargs:不定常的关键字参数
        super().__init__(**kwargs)
        
    def forward(self, X, *args):  #*args:不定常的位置参数
        #若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错
        raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    #初始化state
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError
    
    #前向传播,解码器比编码器多传入一个state
    def forward(self, X, state):
        raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, enc_X, dec_X, *args):
        """
        enc_X:编码器需传入的数据
        dec_X:解码器需传入的数据
        """
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoder
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        """
        vocab_size:词汇表大小
        embed_size:嵌入层大小
        num_hiddens:隐藏层的神经元数量
        num_layers:隐藏层的层数
        dropout=0 : 默认所有的神经元参与计算
        """
        #初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        #初始化神经网络层
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)
        
    def forward(self, X, *args):
        #在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X) 
        #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)  
        #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)
        
        #此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入state
        outputs, state = self.rnn(X)
        #outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量
        #state的shape=(num_layers, batch_size, num_hiddens)
        return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)

print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        #初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        #初始化神经网络层
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        #初始化输出层
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    #定义函数:获取状态state
    def init_state(self, enc_outputs, *args):
        #编码器输出的结果有两个,第二个为state
        return enc_outputs[1]
    
    #前向传播
    def forward(self, X, state):
        #X的原始shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)
       
        # 把X和state拼接到一起. 方便计算. 
        # X现在的形状(num_steps, batch_size, embed_size) , 
        # state的形状(batch_size, num_hiddens)
        # 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)
        context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变
        X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并
        #此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)
        #神经网络层
        outputs, state = self.rnn(X_and_context, state)
        #输出层
        outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来
        #outputs的shape=(batch_size, num_steps, vocab_size)
        #state的shape=(num_layers, batch_size, num_hiddens)
        return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

标签:编码器,num,outputs,self,shape,机器翻译,state,解码器,size
From: https://blog.csdn.net/Hiweir/article/details/142345210

相关文章

  • CMS32L051使用旋转编码器
    文章目录概要代码小结概要CMS32L051使用外部中断的方式识别旋转编码器的方向。选取其中一个信号A进行外部中断触发,由于信号A空闲时处于高电平,因此初始化时外部中断使用下降沿触发;触发第一个下降沿后,判断当前是否已经触发了上升沿,如果已经触发上升沿,则需要判断当前......
  • Arduino ESP32 oled显示,增量式编码器测距程序
      ESP3214引脚接编码器A,13引脚接编码器B,21、22为I2C默认引脚,程序根据编码器A触发ESP32的22脚中断,然后判断编码器B在ESP32的21脚状态是高电平还是低电平,来决定编码器是正转还是反转,也就是数值应该加还是减。   程序设计为编码器转一圈为1000个脉冲也就是编码器分辨率......
  • Spring Cloud全解析:服务调用之Feign的编解码器
    Feign的编解码器编码器在向服务发送请求时,有些情况需要对请求内容进行处理publicclassFeignSpringFormEncoderimplementsEncoder{@Overridepublicvoidencode(Objectobject,TypebodyType,RequestTemplatetemplate)throwsEncodeException解码器可以......
  • 看demo学算法之 自编码器
    大家好,这里是小琳AI课堂!今天我们来聊聊自编码器。......
  • STM32 TIM编码器接口测速(最详细的编码器接口笔记)
    编码器接口简单介绍方波的频率其实就代表了速度编码器接口测速原理TIM编码器测速本质上就是测频法,在指定时间内,对高电平信号进行计次编码器接口的设计逻辑就是,首先把A相和B项的所有边沿作为计数器的计数时钟,出现边沿信号的时候,就自增或者自减,如何判断自增还是自减?当出现......
  • 西门子电机编码器参数设置
    SimotionPLC解释1FK70221FK70331FK7(AM20)1FK7(AM24)1FK7(AS20)1FK7(AS24)encoderMode模式PROFIDRIVEPROFIDRIVEPROFIDRIVEPROFIDRIVEPROFIDRIVEPROFIDRIVEABSResolutionIncrements单圈线数PROFIDRIVEPROFIDRIVEPROFIDRIVEPRO......
  • 【花雕学编程】Arduino FOC 之步进电机正反转驱动、AS5600编码器信息读取及速度检测
    Arduino是一个开放源码的电子原型平台,它可以让你用简单的硬件和软件来创建各种互动的项目。Arduino的核心是一个微控制器板,它可以通过一系列的引脚来连接各种传感器、执行器、显示器等外部设备。Arduino的编程是基于C/C++语言的,你可以使用ArduinoIDE(集成开发环境)来编写、......
  • 关于Plex转码失败,下载解码器失败的问题
    plex服务器版本:1.40.5.8854系统:qnap最近我在外地访问我的plex时,部分视频总是出现转码失败的问题,然后我看了下plex的日志Sep02,202411:48:50.338[140385089600312]ERROR-[NSB]Errorinbrowserhandleread:125(Operationcanceled)socket=-1Sep02,202411:48:50......