首页 > 其他分享 >机器翻译之seq2seq训练、预测、评估代码

机器翻译之seq2seq训练、预测、评估代码

时间:2024-09-22 14:55:58浏览次数:3  
标签:vocab pred 代码 seq2seq len label num 机器翻译 device

目录

1.seq2seq训练代码

2.预测代码 

 3.评估代码

 4.知识点个人理解


 

1.seq2seq训练代码

seq2seq的训练代码:pytorch中训练代码一般都相同类似

#将无效的序列数据都变成0(屏蔽无效内容的部分)
def sequence_mask(X, valid_len, value=0):
    """
    valid_len:有效序列的长度
    """
    #找到最大序列长度
    maxlen = X.size(1)
    #判断掩码区域
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None] < valid_len[:, None]
    #[~mask]表示取相反的数据(取原本为False的数据)
    X[~mask] = value
    return X

#重写交叉熵损失, 添加屏蔽无效内容的部分

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    #重写forward
    #预测值pred的形状:(batch_size, num_steps, vocab_size)
    #真实值label的形状:(batch_size, num_steps)
    #valid_len的形状:(batch_size)
    def forward(self, pred, label, valid_len):
        #创建一个像label形状的全是1的tensor,赋值给初始权重
        weights = torch.ones_like(label)
        #使用掩码,将无效的序列内容屏蔽(其权重变为0),重新赋值
        weights = sequence_mask(weights, valid_len)
        #设置不聚合维度
        self.reduction = 'none'
        #调用原始的forward()来计算未屏蔽无效内容前的交叉熵损失
        #pred的shape使用permute()转换,将num_steps换到最后
        unweighted_loss = super().forward(pred.permute(0, 2 ,1), label)
        #用unweighted_loss * 屏蔽后的weights 求平均:每一批数据的交叉熵损失
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss



#测试重写的交叉熵损失
loss = MaskedSoftmaxCELoss()
loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long), torch.tensor([4, 2, 0]))
tensor([2.3026, 1.1513, 0.0000])
#训练代码


def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    #初始化  RNN网络的xavier初始化的代码都一样
    def xavier_init_weights(m):
        #判断模型是线性模型时
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        #若模型为GRU(2层循环层)模型时
        if type(m) == nn.GRU:
            #遍历每一层的权重参数名称
            for param in m._flat_weights_names:
                #若权重在参数中
                if 'weight' in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    #网络应用xavier初始化的权重
    net.apply(xavier_init_weights)
    #网络转到device上
    net.to(device)
    #用网络初始化的参数与学习率设置优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    #创建损失函数的实例对象
    loss = MaskedSoftmaxCELoss()
    #设置实时更新的画图可视化dltools.Animator()
    animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs])
    
    for epoch in range(num_epochs):
        timer = dltools.Timer()   #训练数据的计时
        metric = dltools.Accumulator(2) #累加统计两种数值:训练的总损失, 词元数量
        
        for batch in data_iter:  #遍历数据迭代器的批次
            #梯度清零(只要在反向传播之前就行)
            optimizer.zero_grad()
            #取数据
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1,1)
            # 开头加上了bos, 那么Y就要去掉最后一列, 保证序列的长度不变. 
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  #给每一行都加上bos
            #获取预测值,state不接收
            Y_pred, _ = net(X, dec_input, X_valid_len)
            #计算损失
            l = loss(Y_pred, Y, Y_valid_len)  #Y_valid_len属于*args其他位置参数传入的
            #反向传播
            l.sum().backward()
            #梯度裁剪
            dltools.grad_clipping(net, theta=1)
            num_tokens = Y_valid_len.sum()
            #更新梯度
            optimizer.step()
            with torch.no_grad():  #不求导
                metric.add(l.sum(), num_tokens)
            
        if (epoch+1) % 10 ==0:  #若每训练循环10次
            animator.add(epoch+1, (metric[0]/ metric[1]))
    print(f'loss {metric[0]/ metric[1]:.3f}, {metric[1] / timer.stop():.1f}', f'tokens/sec on {str(device)}')
#验证封装的训练代码
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 500, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)

decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)

net = EncoderDecoder(encoder, decoder)

train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

2.预测代码 

def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps, device):
    """
    src_sentence:传入的需要翻译的句子
    src_vocab:需要翻译的词汇表
    tgt_vocab:目标真实值词汇表
    num_steps:子序列长度
    device:GPU或CPU设备
    """
    #预测的时候需要把net设置为评估模式
    net.eval()
    #获取处理后的文本词元索引(输出是一个索引列表),在文本的结尾加上'<eos>'
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    #获取编码器输入内容src_tokens的有效长度,转化为tensor(用列表创建tensor)
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    #处理src_tokens太长/太短的问题:截断或者补充pad ,  num_steps表示隔多长截断一次, 覆盖赋值
    src_tokens = dltools.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    
    #给src_tokens增加一个维度来表示批次,  获取enc_X 编码器的输入数据
    enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    #向网络的编码器中传入enc_X, enc_valid_len,获取编码器的输出结果
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    #将编码器的输出结果enc_output和有效长度enc_valid_len传入解码器中,获取解码器的输出结果初始化状态dec_state
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    
    #给预测结果也提前添加一个维度,   tgt_vocab预测词汇表的第一个词应该是文本开头的bos
    dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    
    output_seq = []
    for _ in range(num_steps):  #循环子序列长度次数
        #将dec_state, dec_X输入网络的解码器中
        Y, dec_state = net.decoder(dec_X, dec_state)
        #将Y重新赋值给dec_X,实现循环输入
        dec_X = Y.argmax(dim=2)  #将Y的vocab_size对应的索引2维度聚合找最大值(预测的值)
        
        #获取预测值:将dec_X去掉一个batc_size维度(此时batc_size=1,就一批数据,可以不要这个维度)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        #判断结束的条件
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    #返回值:按照索引返回对应词表中的词
    return ' '.join(tgt_vocab.to_tokens(output_seq))

 3.评估代码

seq2seq的评估指标: BLEU: bilingual evaluation understudy 双语互译质量评估辅助工具

def bleu(pred_seq, label_seq, k):
    """
    pred_seq:预测序列
    label_seq:真实序列
    k: 设定几元连续
    """
    #pred_seq, label_seq预测与目标序列的空格分隔处理(分词)
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    #获取预测词与目标词的长度
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    #计算bleu的左边部分_比较最小值
    score = math.exp(min(0, 1 - (len_label / len_pred)))
    for n in range(1, k + 1):  #range左闭右开   #分几元连续的情况
        #赋值 ,  #num_matches:预测值与目标值匹配的数量, 
        #collections.defaultdict(int)创建了一个默认值为int的字典  label_subs
        num_matches, label_subs = 0, collections.defaultdict(int)
        #循环连续词元的数量
        for i in range(len_label - n + 1):
            #若预测的词能与目标值匹配上
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        
        for i in range(len_pred - n + 1):
            #若能匹配上
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1  #匹配数+1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *=  math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n)) 
    return score
   
# 开始预测
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')
go . => va !, bleu 1.000
i lost . => j'ai perdu perdu ., bleu 0.783
he's calm . => il <unk> gagné suis perdu ., bleu 0.000
i'm home . => je suis chez nous <unk> !, bleu 0.562

 4.知识点个人理解

标签:vocab,pred,代码,seq2seq,len,label,num,机器翻译,device
From: https://blog.csdn.net/Hiweir/article/details/142354850

相关文章

  • 机器翻译之数据处理
    目录1.导包  2.读取本地数据3.定义函数:数据预处理  4.定义函数:词元化 5.统计每句话的长度的分布情况6.获取词汇表7. 截断或者填充文本序列 8.将机器翻译的文本序列转换成小批量tensor 9.加载数据10.知识点个人理解1.导包 #导包importosimporttorch......
  • 听说ChatGPT o1推理模型即将问世,传统问答系统是否还有存在的必要?毕业设计:基于知识图谱
     OpenAI隆重推出全新一代的o1模型,该模型在多个领域展现出了非凡的能力,标志着人工智能技术的又一次飞跃。该模型专门解决比此前的科学、代码和数学模型能做到的更难的问题,实现复杂推理。那来看看并体验以下我们传统的问答系统的设计流程和具体面貌吧!!!1.1系统架构设计1.1.1......
  • 推荐一个很酷的脚本工具,几行代码,就能编写有用的 shell 脚本,月猛增 7.4 K Star太牛逼了
     今天给大家介绍的是gum,它是一个很酷的脚本工具。项目介绍gum是一个很棒的脚本工具,提供了高度可配置,随时可用的实用程序,只需几行代码,就能编写有用的shell脚本。让我们构建一个简单的脚本来创建提交。由下面的代码开始:#!/bin/sh询问gumchoose的提交类型:gum ch......
  • 【代码随想录Day24】回溯算法Part03
    93.复原IP地址题目链接/文章讲解:代码随想录视频讲解:回溯算法如何分割字符串并判断是合法IP?|LeetCode:93.复原IP地址_哔哩哔哩_bilibiliclassSolution{List<String>result=newArrayList<>();LinkedList<String>path=newLinkedList<>();publicL......
  • STM32流水灯程序代码及解析:三种实现方式
    STM32流水灯程序代码及解析:三种实现方式在这篇文章中,我们将介绍三种方式来实现STM32的流水灯程序,包括使用HAL库、标准库和直接操作寄存器的方法。通过这三种不同的方式。1.硬件准备STM32开发板(如STM32F4或STM32F1系列)若干LED灯(通常是4个)适当的电阻连接线2.接线图将L......
  • 毕业设计|springboot产业园区智慧公寓管理系统-|免费|代码讲解
    收藏点赞不迷路 关注作者有好处编号:springboot547springboot产业园区智慧公寓管理系统-开发语言:Java数据库:MySQL技术:Spring+SpringMVC+MyBatis工具:IDEA/Ecilpse、Navicat、Maven1.万字文档展示(部分)2.系统图片展示第5章系统详细设计这个环节需要使用前面的设......
  • 毕业设计|springboot人事管理系统论文-|免费|代码讲解
    收藏点赞不迷路 关注作者有好处编号:springboot350springboot人事管理系统论文-开发语言:Java数据库:MySQL技术:Spring+SpringMVC+MyBatis工具:IDEA/Ecilpse、Navicat、Maven1.万字文档展示(部分)2.系统图片展示......
  • JAVA毕业设计|(免费)springboot农产品智慧物流系统包含文档代码讲解
    收藏点赞不迷路 关注作者有好处编号:springboot537springboot农产品智慧物流系统开发语言:Java数据库:MySQL技术:Spring+SpringMVC+MyBatis工具:IDEA/Ecilpse、Navicat、Maven1.万字文档展示(部分)2.系统图片展示第5章系统详细设计......
  • JAVA毕业设计|(免费)springbootJAVA流浪动物救助平台-包含文档代码讲解
    收藏点赞不迷路 关注作者有好处编号:springboot530springbootJAVA流浪动物救助平台-开发语言:Java数据库:MySQL技术:Spring+SpringMVC+MyBatis工具:IDEA/Ecilpse、Navicat、Maven1.万字文档展示(部分)2.系统图片展示第5章系统详细设计系统实现部分就是将系统分析,系......
  • JAVA毕业设计|(免费)Springboot和BS架构宠物健康咨询系统包含文档代码讲解
    收藏点赞不迷路 关注作者有好处编号:springboot509Springboot和BS架构宠物健康咨询系统开发语言:Java数据库:MySQL技术:Spring+SpringMVC+MyBatis工具:IDEA/Ecilpse、Navicat、Maven1.万字文档展示(部分)2.系统图片展示第5章系统详细设计5.1管理员功能模块的实现5......