手撕Transformer – Day9 – TransformerTrain
Transformer 网络结构图
目录
TransformerTrain 代码
Part1 库函数
# 该模块是训练transformer的,所以主要的部分在于训练的数据集dataloader怎么设置以及如何进行epoch训练
'''
# Part1主要是引入一些库的函数
'''
import torch
from torch import nn
from transformer import Transformer
from dataset import de_vocab, en_vocab, de_preprocess, en_preprocess, train_dataset, PAD_IDX
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from config import SEQ_MAX_LEN
Part2 实现一个 DeEnDataset数据集,作为一个类
'''
# Part2 设计一个dataset数据集,继承于Dataset,dataset需要实现的功能
# 1. 在初始化就要初始化好训练集和测试集合。目前比较简单只有训练数据,所以初始化的时候只需要设计两个list,作为输入和输出就行。
# 2. 需要设计一些函数,返回数据集的一些信息,比如数据集的长度,等等
'''
class DeEnDataset(Dataset):
def __init__(self):
super().__init__()
# 初始化的时候,要设置好所有的初始数据,这里的数据存储,主要是通过,list来存储的,并且输入和输出是分开存储的。
self.enc_x = []
self.dec_x = []
for de, en in train_dataset:
# 第一步分词,预处理
de_tokens, de_ids = de_preprocess(de)
en_tokens, en_ids = en_preprocess(en)
# 判断序列长度是否超限度,对于超限度的句子直接去除了,这里的目的单纯是因为这种长序列的少见,以及内存,以及训练效果不佳啥的,实际是可以训练的
if len(de_ids) > SEQ_MAX_LEN or len(en_ids) > SEQ_MAX_LEN:
continue
# 一个是decoder_x(输出的x),一个是encoder_x(输入的x)
self.enc_x.append(de_ids)
self.dec_x.append(en_ids)
# 获取长度
def __len__(self):
return len(self.enc_x)
# 获取对应元素
def __getitem__(self, index):
return self.enc_x[index], self.dec_x[index]
Part3 batch处理,Tensor+Padding
def collate_fn(batch):
enc_index_batch = []
dec_index_batch = []
# 遍历tensor化,到list里面去
for enc_x, dec_x in batch:
enc_index_batch.append(torch.tensor(enc_x, dtype=torch.long))
dec_index_batch.append(torch.tensor(dec_x, dtype=torch.long))
# 然后进行padding,因为pad_sequence只能在tensorlist用,应该batch_first表示张量以batch为第一个维度也就是(batch,seq_len)
pad_enc_x = pad_sequence(enc_index_batch, batch_first=True, padding_value=PAD_IDX)
pad_dec_x = pad_sequence(dec_index_batch, batch_first=True, padding_value=PAD_IDX)
# 形状为 (batch, batchmax_seq_len),所以可能存在不同batch,这个句子长度是不一样的。
# 所以为什么 position_emdding 里面有个 seq_max_len = 5000,然后取其前面部分的,因为每个batch可能句子长度不一样,所以位置编码要随时适应句子长度。
return pad_enc_x, pad_dec_x
Part4 测试-训练
'''
# Part4 测试,真正开始训练
'''
if __name__=='__main__':
dataset = DeEnDataset()
dataloader = DataLoader(dataset, batch_size=200, shuffle=True, collate_fn=collate_fn)
# 尝试看看有没有现有模型,如果有现有模型就加载进行后训练,反之则创建一个进行重新训练,这里主要是用于前向传播
try:
transformer = torch.load('checkpoints/model.pth')
except:
transformer = Transformer(
de_vocab_size=len(de_vocab),
en_vocab_size=len(en_vocab),
emd_size=512,
head=8,
q_k_size=64,
v_size=64,
f_size=2048,
nums_encoder_block=6,
nums_decoder_block=6,
dropout=0.1,
seq_max_len=SEQ_MAX_LEN
)
# 初始化损失
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# 初始化优化器,反向传播更新参数用
optimizer = torch.optim.SGD(transformer.parameters(), lr=1e-1, momentum=0.99)
# 开始训练
transformer.train()
EPOCHS = 300
for epoch in range(EPOCHS):
batch_i = 0
loss_sum = 0
for pad_enc_x, pad_dec_x in dataloader:
# 这里一个去掉第一个词,一个去掉最后一个词,有说法的(所以相当于预测每下个词的概率)
real_dec_z = pad_dec_x[:, 1:] # decoder正确输出
pad_enc_x = pad_enc_x
pad_dec_x = pad_dec_x[:, :-1] # decoder实际输入
dec_z = transformer(pad_enc_x, pad_dec_x) # decoder实际输出
batch_i += 1
print(dec_z.size())
loss = loss_fn(dec_z.reshape(-1, dec_z.size()[-1]), real_dec_z.reshape(-1)) # 把整个batch中的所有词拉平
loss_sum += loss.item()
print('epoch:{} batch:{} loss:{}'.format(epoch, batch_i, loss.item()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(transformer, 'checkpoints/model.pth'.format(epoch))
参考
视频讲解:transformer-带位置信息的词嵌入向量_哔哩哔哩_bilibili
github代码库:github.com
标签:__,Transformer,enc,--,TransformerTrain,de,batch,pad,dec From: https://blog.csdn.net/m0_62030579/article/details/145112549