嗨!大家好,这一期我们来看一下夏令营所提供的baseline。
首先,baseline是什么
对于很多第一次参加datawhale夏令营的小伙伴,看到手册里发布的baseline,都会有这样的疑问。
baseline是跑通比赛的第一个代码,里面用到的算法不会很复杂,更基础一些;本次baseline是构建和训练一个基于PyTorch的序列到序列(Seq2Seq)机器翻译模型。
一、数据处理
1.1 TranslationDataset类
这个类就像是一个聪明的助手,帮我们准备翻译所需的数据。在baseline中,他完成了以下的工作:
-
读取数据:从文件中读取英语和中文句子对。
-
制作词典:收集所有英语词和中文字,并给它们编号。
-
特殊词处理:确保专业术语(terminology)被包含在词典中。
-
数字化准备:创建从单词到数字的映射(word2idx)。
class TranslationDataset(Dataset): def __init__(self, filename, terminology): self.data = [] with open(filename, 'r', encoding='utf-8') as f: for line in f: en, zh = line.strip().split('\t') self.data.append((en, zh)) self.terminology = terminology # 创建词汇表 self.en_tokenizer = get_tokenizer('basic_english') self.zh_tokenizer = list # 使用字符级分词 en_vocab = Counter(self.terminology.keys()) zh_vocab = Counter() for en, zh in self.data: en_vocab.update(self.en_tokenizer(en)) zh_vocab.update(self.zh_tokenizer(zh)) # 添加特殊标记和常用词到词汇表 self.en_vocab = ['<pad>', '<sos>', '<eos>'] + list(self.terminology.keys()) + [word for word, _ in en_vocab.most_common(10000)] self.zh_vocab = ['<pad>', '<sos>', '<eos>'] + [word for word, _ in zh_vocab.most_common(10000)] self.en_word2idx = {word: idx for idx, word in enumerate(self.en_vocab)} self.zh_word2idx = {word: idx for idx, word in enumerate(self.zh_vocab)}
1.2 collate_fn函数
这个函数帮助我们把不同长度的句子整理成一批。这个函数做了两件主要的事:
-
收集一批数据中的英语和中文句子。
-
把它们填充到同样的长度(用0填充),计算机能更好地处理它们。
def collate_fn(batch): en_batch, zh_batch = [], [] for en_item, zh_item in batch: en_batch.append(en_item) zh_batch.append(zh_item) en_batch = nn.utils.rnn.pad_sequence(en_batch, padding_value=0, batch_first=True) zh_batch = nn.utils.rnn.pad_sequence(zh_batch, padding_value=0, batch_first=True) return en_batch, zh_batch
二. 模型架构
2.1 编码器(Encoder)
编码器负责理解输入的英语句子,embedding
把每个英语单词变成一串数字,rnn
(GRU)用来理解整个句子的含义,dropout
帮助模型不要"死记硬背",而是真正理解句子。
class Encoder(nn.Module): def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout): super().__init__() self.embedding = nn.Embedding(input_dim, emb_dim) self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True) self.dropout = nn.Dropout(dropout) def forward(self, src): embedded = self.dropout(self.embedding(src)) outputs, hidden = self.rnn(embedded) return outputs, hidden
2.2 解码器(Decoder)
解码器负责生成中文翻译。在解码器中它也使用embedding
来处理中文字,rnn
(GRU)帮助它记住之前翻译的内容,fc_out
用来预测下一个中文字。
class Decoder(nn.Module): def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout): super().__init__() self.output_dim = output_dim self.embedding = nn.Embedding(output_dim, emb_dim) self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True) self.fc_out = nn.Linear(hid_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, input, hidden): input = input.unsqueeze(1) embedded = self.dropout(self.embedding(input)) output, hidden = self.rnn(embedded, hidden) prediction = self.fc_out(output.squeeze(1)) return prediction, hidden
2.3 Seq2Seq模型
这是把编码器和解码器组合在一起的完整翻译模型。这个模型用编码器理解英语句子,用解码器一个字一个字地生成中文翻译,有时候(由teacher_forcing_ratio
控制)会使用正确的中文来指导翻译,这叫"教师强制"。
class Seq2Seq(nn.Module): def __init__(self, encoder, decoder, device): super().__init__() self.encoder = encoder self.decoder = decoder self.device = device def forward(self, src, trg, teacher_forcing_ratio=0.5): batch_size = src.shape[0] trg_len = trg.shape[1] trg_vocab_size = self.decoder.output_dim outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device) _, hidden = self.encoder(src) input = trg[:, 0] for t in range(1, trg_len): output, hidden = self.decoder(input, hidden) outputs[:, t] = output teacher_force = random.random() < teacher_forcing_ratio top1 = output.argmax(1) input = trg[:, t] if teacher_force else top1 return outputs
三、BLEU评分函数
BLEU(Bilingual Evaluation Understudy)是一种广泛使用的机器翻译评估方法。它通过比较机器翻译的结果与人工翻译的参考文本来评估翻译质量。
from sacrebleu.metrics import BLEU def evaluate_bleu(model, dataset, src_file, ref_file, terminology, device): model.eval() src_sentences = load_sentences(src_file) ref_sentences = load_sentences(ref_file) translated_sentences = [] for src in src_sentences: translated = translate_sentence(src, model, dataset, terminology, device) translated_sentences.append(translated) bleu = BLEU() score = bleu.corpus_score(translated_sentences, [ref_sentences]) return score
BLEU怎么评估呢
-
准备工作
model.eval() src_sentences = load_sentences(src_file) ref_sentences = load_sentences(ref_file)
-
model.eval()
告诉模型现在是评估时间,不是训练时间。 -
我们从文件中加载源语言(英语)句子和参考翻译(中文)句子。
-
-
翻译过程
translated_sentences = [] for src in src_sentences: translated = translate_sentence(src, model, dataset, terminology, device) translated_sentences.append(translated)
-
我们用我们的模型翻译每一个英语句子。
-
这就像是给模型一次"考试",看它能把每个句子翻译得多好。
-
-
计算BLEU分数
bleu = BLEU() score = bleu.corpus_score(translated_sentences, [ref_sentences])
-
我们使用BLEU评分系统来给模型的翻译打分。
-
它会比较模型的翻译(
translated_sentences
)和人工翻译(ref_sentences
)。
-
BLEU是如何工作的?
BLEU就像一个严格但公平的老师,它这样给翻译打分:
-
精确度检查:它会看模型翻译中的词(或短语)有多少出现在了人工翻译中。
-
完整性检查:它也会确保模型的翻译不会太短。不能只翻对了一个词就完事儿了。
-
长度惩罚:如果机器翻译比人工翻译短太多或长太多,分数会降低。
-
N-gram匹配:它不仅看单个词,还会看词组。比如"人工智能"这四个字在一起出现,比单独出现"人"、"工"、"智"、"能"要好。
最后,BLEU会给出一个0到100之间的分数。分数越高,说明机器翻译越接近人工翻译,质量越好。
通过这种评分方式,我们可以客观地评估我们的翻译模型性能,并且可以用它来比较不同模型或者跟踪同一个模型在训练过程中的进步。
标签:dim,en,zh,baseline,self,batch,datawhale,机器翻译,sentences From: https://blog.csdn.net/han_qikemeng/article/details/140373424