首页 > 其他分享 >datawhale第二期夏令营基于术语词典干预的机器翻译挑战赛——baseline【笔记】

datawhale第二期夏令营基于术语词典干预的机器翻译挑战赛——baseline【笔记】

时间:2024-07-12 10:55:28浏览次数:16  
标签:dim en zh baseline self batch datawhale 机器翻译 sentences

嗨!大家好,这一期我们来看一下夏令营所提供的baseline。

首先,baseline是什么

对于很多第一次参加datawhale夏令营的小伙伴,看到手册里发布的baseline,都会有这样的疑问。

baseline是跑通比赛的第一个代码,里面用到的算法不会很复杂,更基础一些;本次baseline是构建和训练一个基于PyTorch的序列到序列(Seq2Seq)机器翻译模型。

一、数据处理

1.1 TranslationDataset类

这个类就像是一个聪明的助手,帮我们准备翻译所需的数据。在baseline中,他完成了以下的工作:

  1. 读取数据:从文件中读取英语和中文句子对。

  2. 制作词典:收集所有英语词和中文字,并给它们编号。

  3. 特殊词处理:确保专业术语(terminology)被包含在词典中。

  4. 数字化准备:创建从单词到数字的映射(word2idx)。

class TranslationDataset(Dataset):
    def __init__(self, filename, terminology):
        self.data = []
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                en, zh = line.strip().split('\t')
                self.data.append((en, zh))
        
        self.terminology = terminology
        
        # 创建词汇表
        self.en_tokenizer = get_tokenizer('basic_english')
        self.zh_tokenizer = list  # 使用字符级分词
        
        en_vocab = Counter(self.terminology.keys())
        zh_vocab = Counter()
        
        for en, zh in self.data:
            en_vocab.update(self.en_tokenizer(en))
            zh_vocab.update(self.zh_tokenizer(zh))
        
        # 添加特殊标记和常用词到词汇表
        self.en_vocab = ['<pad>', '<sos>', '<eos>'] + list(self.terminology.keys()) + [word for word, _ in en_vocab.most_common(10000)]
        self.zh_vocab = ['<pad>', '<sos>', '<eos>'] + [word for word, _ in zh_vocab.most_common(10000)]
        
        self.en_word2idx = {word: idx for idx, word in enumerate(self.en_vocab)}
        self.zh_word2idx = {word: idx for idx, word in enumerate(self.zh_vocab)}

1.2 collate_fn函数

这个函数帮助我们把不同长度的句子整理成一批。这个函数做了两件主要的事:

  1. 收集一批数据中的英语和中文句子。

  2. 把它们填充到同样的长度(用0填充),计算机能更好地处理它们。

def collate_fn(batch):
    en_batch, zh_batch = [], []
    for en_item, zh_item in batch:
        en_batch.append(en_item)
        zh_batch.append(zh_item)
    
    en_batch = nn.utils.rnn.pad_sequence(en_batch, padding_value=0, batch_first=True)
    zh_batch = nn.utils.rnn.pad_sequence(zh_batch, padding_value=0, batch_first=True)
    
    return en_batch, zh_batch

二. 模型架构

2.1 编码器(Encoder)

编码器负责理解输入的英语句子,embedding把每个英语单词变成一串数字,rnn(GRU)用来理解整个句子的含义,dropout帮助模型不要"死记硬背",而是真正理解句子。

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
​
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

2.2 解码器(Decoder)

解码器负责生成中文翻译。在解码器中它也使用embedding来处理中文字,rnn(GRU)帮助它记住之前翻译的内容,fc_out用来预测下一个中文字。

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
​
    def forward(self, input, hidden):
        input = input.unsqueeze(1)
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc_out(output.squeeze(1))
        return prediction, hidden

2.3 Seq2Seq模型

这是把编码器和解码器组合在一起的完整翻译模型。这个模型用编码器理解英语句子,用解码器一个字一个字地生成中文翻译,有时候(由teacher_forcing_ratio控制)会使用正确的中文来指导翻译,这叫"教师强制"。

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
​
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
​
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        
        _, hidden = self.encoder(src)
        
        input = trg[:, 0]
        
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden)
            outputs[:, t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[:, t] if teacher_force else top1
​
        return outputs

三、BLEU评分函数

BLEU(Bilingual Evaluation Understudy)是一种广泛使用的机器翻译评估方法。它通过比较机器翻译的结果与人工翻译的参考文本来评估翻译质量。

from sacrebleu.metrics import BLEU
​
def evaluate_bleu(model, dataset, src_file, ref_file, terminology, device):
    model.eval()
    src_sentences = load_sentences(src_file)
    ref_sentences = load_sentences(ref_file)
    
    translated_sentences = []
    for src in src_sentences:
        translated = translate_sentence(src, model, dataset, terminology, device)
        translated_sentences.append(translated)
    
    bleu = BLEU()
    score = bleu.corpus_score(translated_sentences, [ref_sentences])
    
    return score

BLEU怎么评估呢

  1. 准备工作

    model.eval()
    src_sentences = load_sentences(src_file)
    ref_sentences = load_sentences(ref_file)
    • model.eval() 告诉模型现在是评估时间,不是训练时间。

    • 我们从文件中加载源语言(英语)句子和参考翻译(中文)句子。

  2. 翻译过程

    translated_sentences = []
    for src in src_sentences:
        translated = translate_sentence(src, model, dataset, terminology, device)
        translated_sentences.append(translated)
    • 我们用我们的模型翻译每一个英语句子。

    • 这就像是给模型一次"考试",看它能把每个句子翻译得多好。

  3. 计算BLEU分数

    bleu = BLEU()
    score = bleu.corpus_score(translated_sentences, [ref_sentences])
    • 我们使用BLEU评分系统来给模型的翻译打分。

    • 它会比较模型的翻译(translated_sentences)和人工翻译(ref_sentences)。

BLEU是如何工作的?

BLEU就像一个严格但公平的老师,它这样给翻译打分:

  1. 精确度检查:它会看模型翻译中的词(或短语)有多少出现在了人工翻译中。

  2. 完整性检查:它也会确保模型的翻译不会太短。不能只翻对了一个词就完事儿了。

  3. 长度惩罚:如果机器翻译比人工翻译短太多或长太多,分数会降低。

  4. N-gram匹配:它不仅看单个词,还会看词组。比如"人工智能"这四个字在一起出现,比单独出现"人"、"工"、"智"、"能"要好。

最后,BLEU会给出一个0到100之间的分数。分数越高,说明机器翻译越接近人工翻译,质量越好。

通过这种评分方式,我们可以客观地评估我们的翻译模型性能,并且可以用它来比较不同模型或者跟踪同一个模型在训练过程中的进步。

标签:dim,en,zh,baseline,self,batch,datawhale,机器翻译,sentences
From: https://blog.csdn.net/han_qikemeng/article/details/140373424

相关文章

  • DataWhale夏令营(机器学习方向)——分子性质AI预测挑战赛
     #AI夏令营#Datawhale#夏令营该笔记是在博主Mr.chenlex跑分后的基础上加以改进,原文连接:DatawhaleAI夏令营-机器学习:分子性质AI预测挑战赛#ai夏令营datawhale#夏令营-CSDN博客Baseline改进前后代码介绍Baseline改进前后跑分结果直接套用原博主的Baseline(需另进行库的......
  • 【2024datawhale 分子AI预测赛笔记】数据挖掘速通Baseline -分类/回归
    赛题概述精准预测分子性质有助于高效筛选出具有优异性能的候选药物。以PROTACs为例,它是一种三元复合物由目标蛋白配体、linker、E3连接酶配体组成,靶向降解目标蛋白质。(研究PROTACs技术在靶向降解目标蛋白质方面的潜力。)提醒:需要python和机器学习基础。赛事任务根据提......
  • 【基于星火大模型的群聊对话分角色要素提取BaseLine学习笔记】
    @目录项目背景项目任务我的思路Baseline详解数据抽取完整代码星火认知大模型Spark3.5Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看星火......
  • 机器翻译及实践 进阶版:基于Transformer实现机器翻译(日译中)
    机器翻译及实践进阶版:基于Transformer实现机器翻译(日译中)前言一、所需要的前置知识——Transformer1.自注意力机制1.1Query&Key&Value版注意力机制1.1.1什么是Query&Key&Value版注意力机制1.1.2为什么引入Query&Key&Value版注意力机制1.1.3如何实现Query&Key&Value......
  • 机器翻译及实践 初级版:含注意力机制的编码器—解码器模型
    机器翻译及实践初级版:含注意力机制的编码器—解码器模型前言一、什么是机器翻译?二、所需要的前置知识(一).Seq2Seq1.什么是Seq2Seq2.机器翻译为什么要用Seq2Seq3.如何使用Seq2Seq3.1编码器的实现3.2解码器的实现3.3训练模型(二).注意力机制1.什么是注意力机制2.机器翻译为......
  • 【机器学习】Datawhale-AI夏令营分子性质AI预测挑战赛
    #ai夏令营#datawhale#夏令营1.赛事简介还是大家熟悉的预测算法类:分子性质AI预测挑战赛要求选手根据提供的demo数据集,可以基于demo数据集进行数据增强、自行搜集数据等方式扩充数据集,并自行划分数据。运用深度学习、强化学习或更加优秀人工智能的方法预测PROTACs的降解......
  • Windows Security Baselines(安全基线指南) 是由微软提供的一个安全配置集合,旨在帮助组
    安全基线指南-WindowsSecurity|MicrosoftLearnWindowsSecurityBaselines(安全基线)是由微软提供的一个安全配置集合,旨在帮助组织和管理员快速部署一套推荐的安全设置,以增强Windows操作系统及其组件的安全性。这些基线覆盖了操作系统本身、MicrosoftEdge浏览器、Inter......
  • WindowsBaselineAssistant Windows安全基线核查加固助手,WindowsBaselineAssistant Wi
    GitHub-DeEpinGh0st/WindowsBaselineAssistant:Windows安全基线核查加固助手WindowsBaselineAssistantWindows安全基线核查加固助手,WindowsBaselineAssistantWindowsBaselineAssistant(WBA)是一个用于检测和加固Windows安全基线的辅助工具,借助此工具你可以免去繁琐的......
  • 【datawhale打卡】深入剖析大模型原理——Qwen Blog
    教程及参考文档QwenBlog科普神文,一次性讲透AI大模型的核心概念Largelanguagemodels,explainedwithaminimumofmathandjargon0.前置知识由于我没有LLM基础,所以直接上手看文档看的是一头雾水。然后就去补了一下基础知识,这里算是一点简单的个人理解和总结吧。LLM......
  • datawhale-动手学图深度学习task04
    动手学图深度学习图表示学习研究在嵌入空间(EmbeddingSpace,指在高维数据被映射到低维空间的数学结构)表示图的方法,在图上表示学习核嵌入指的是同一件事,“嵌入”是指将网络中的每个节点映射到低维空间(需要深入了解节点的相似性和网络结构),旨在捕捉图结构中的拓扑信息、节点内容信......