目录
1.导包
#导包
import os
import torch
import dltools
2.读取本地数据
#读取本地数据
with open('./fra-eng/fra.txt', 'r', encoding='utf-8') as f:
raw_text = f.read() #一次读取所有数据
print(raw_text[:75])
Go. Va ! Hi. Salut ! Run! Cours ! Run! Courez ! Who? Qui ? Wow! Ça alors !
3.定义函数:数据预处理
#数据预处理
def preprocess_nmt(text):
#判断标点符号前面是否有空格
def no_space(char, prev_char):
return char in set(',.!?') and prev_char != ' '
#替换识别不了的字符,替换不正常显示的空格,将大写字母变成小写
text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
#在单词和标点之间插入空格
out = [' '+ char if i>0 and no_space(char, text[i-1]) else char for i, char in enumerate(text)]
return ''.join(out) #合并out
#测试:数据预处理
text = preprocess_nmt(raw_text)
print(text[:80])
go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ça alors !
4.定义函数:词元化
#定义函数:词元化
def tokenize_nmt(text, num_examples=None):
"""
text:传入的数据文本
num_examples=None:样本数量为空,判断数据集中剩余的数据量是否满足一批所取的数据量
"""
source, target = [], []
#以换行符号\n划分每一行
for i, line in enumerate(text.split('\n')):
#if num_examples 表示不是空,相当于 if num_examples != None
if num_examples and i > num_examples:
break
#从每一行数据中 以空格键tab分割数据
parts = line.split('\t') #将英文与对应的法语分割开
if len(parts) == 2: #单词文本与标点符号两个元素
source.append(parts[0].split(' ')) #用空格分割开单词文本与标点符号两个元素
target.append(parts[1].split(' '))
return source, target
#测试词元化代码
source, target = tokenize_nmt(text)
source[:6], target[:6]
([['go', '.'], ['hi', '.'], ['run', '!'], ['run', '!'], ['who', '?'], ['wow', '!']], [['va', '!'], ['salut', '!'], ['cours', '!'], ['courez', '!'], ['qui', '?'], ['ça', 'alors', '!']])
5.统计每句话的长度的分布情况
#统计每句话的长度的分布情况
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
dltools.set_figsize() #创建一个适当的画布
_,_,patches = dltools.plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]])
dltools.plt.xlabel(xlabel) #添加x标签
dltools.plt.ylabel(ylabel) #添加y标签
for patch in patches[1].patches: #为patches[1]的柱体添加斜线
patch.set_hatch('/')
dltools.plt.legend(legend) #添加标注
#测试代码:统计每句话的长度的分布情况
show_list_len_pair_hist(['source', 'target'], '# tokens per sequence', 'count', source, target)
6. 获取词汇表
#获取词汇表
src_vocab = dltools.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
len(src_vocab)
10012
7. 截断或者填充文本序列
def truncate_pad(line, num_steps, padding_token):
"""
line:传入的数据
num_steps:子序列长度
padding_token:需要填充的词元
"""
if len(line) > num_steps:
return line[:num_steps] #太长就截断
#太短就补充
return line + [padding_token] * (num_steps - len(line)) #填充
#测试
#source[0]表示英文单词
truncate_pad(src_vocab[source[0]], 10, src_vocab['<pad>'])
[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]
8.将机器翻译的文本序列转换成小批量tensor
def build_array_nmt(lines, vocab, num_steps):
#通过vocab拿到line的索引
lines = [vocab[l] for l in lines]
#每个序列结束之后+一个'eos'
lines = [l + [vocab['eos']] for l in lines]
#对每一行文本 截断或者填充文本序列,再转化为tensor
array = torch.tensor([truncate_pad(l, num_steps, vocab['<pad>']) for l in lines])
#获取有效长度
valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
return array, valid_len
9.加载数据
def load_data_nmt(batch_size, num_steps, num_examples=600):
# 需要返回数据集的迭代器和词表
text = preprocess_nmt(raw_text)
source, target = tokenize_nmt(text, num_examples)
src_vocab = dltools.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = dltools.Vocab(target, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
data_iter = dltools.load_array(data_arrays, batch_size)
return data_iter, src_vocab, tgt_vocab
#测试代码
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)
for X, X_valid_len, Y, Y_valid_len in train_iter:
print('X:', X.type(torch.int32))
print('X的有效长度:', X_valid_len)
print('Y:', Y.type(torch.int32))
print('Y的有效长度:',Y_valid_len)
break
X: tensor([[17, 20, 4, 0, 1, 1, 1, 1], [ 7, 84, 4, 0, 1, 1, 1, 1]], dtype=torch.int32) X的有效长度: tensor([4, 4]) Y: tensor([[ 11, 61, 144, 4, 0, 1, 1, 1], [ 6, 33, 17, 4, 0, 1, 1, 1]], dtype=torch.int32) Y的有效长度: tensor([5, 5])
10.知识点个人理解
标签:vocab,src,text,len,source,num,机器翻译,数据处理 From: https://blog.csdn.net/Hiweir/article/details/142344945