首页 > 其他分享 >手撕Transformer -- Day9 -- TransformerTrain

手撕Transformer -- Day9 -- TransformerTrain

时间:2025-01-13 11:59:56浏览次数:3  
标签:__ Transformer enc -- TransformerTrain de batch pad dec

手撕Transformer – Day9 – TransformerTrain

Transformer 网络结构图

目录

在这里插入图片描述

Transformer 网络结构

TransformerTrain 代码

Part1 库函数

# 该模块是训练transformer的,所以主要的部分在于训练的数据集dataloader怎么设置以及如何进行epoch训练
'''
# Part1主要是引入一些库的函数
'''
import torch
from torch import nn
from transformer import Transformer
from dataset import de_vocab, en_vocab, de_preprocess, en_preprocess, train_dataset, PAD_IDX
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from config import SEQ_MAX_LEN

Part2 实现一个 DeEnDataset数据集,作为一个类

'''
# Part2 设计一个dataset数据集,继承于Dataset,dataset需要实现的功能
# 1. 在初始化就要初始化好训练集和测试集合。目前比较简单只有训练数据,所以初始化的时候只需要设计两个list,作为输入和输出就行。
# 2. 需要设计一些函数,返回数据集的一些信息,比如数据集的长度,等等

'''


class DeEnDataset(Dataset):
    def __init__(self):
        super().__init__()

        # 初始化的时候,要设置好所有的初始数据,这里的数据存储,主要是通过,list来存储的,并且输入和输出是分开存储的。
        self.enc_x = []
        self.dec_x = []
        for de, en in train_dataset:
            # 第一步分词,预处理
            de_tokens, de_ids = de_preprocess(de)
            en_tokens, en_ids = en_preprocess(en)
            # 判断序列长度是否超限度,对于超限度的句子直接去除了,这里的目的单纯是因为这种长序列的少见,以及内存,以及训练效果不佳啥的,实际是可以训练的
            if len(de_ids) > SEQ_MAX_LEN or len(en_ids) > SEQ_MAX_LEN:
                continue
            # 一个是decoder_x(输出的x),一个是encoder_x(输入的x)
            self.enc_x.append(de_ids)
            self.dec_x.append(en_ids)


    # 获取长度
    def __len__(self):
        return len(self.enc_x)

    # 获取对应元素
    def __getitem__(self, index):
        return self.enc_x[index], self.dec_x[index]

Part3 batch处理,Tensor+Padding

def collate_fn(batch):
    enc_index_batch = []
    dec_index_batch = []

    # 遍历tensor化,到list里面去
    for enc_x, dec_x in batch:
        enc_index_batch.append(torch.tensor(enc_x, dtype=torch.long))
        dec_index_batch.append(torch.tensor(dec_x, dtype=torch.long))

    # 然后进行padding,因为pad_sequence只能在tensorlist用,应该batch_first表示张量以batch为第一个维度也就是(batch,seq_len)
    pad_enc_x = pad_sequence(enc_index_batch, batch_first=True, padding_value=PAD_IDX)
    pad_dec_x = pad_sequence(dec_index_batch, batch_first=True, padding_value=PAD_IDX)

    # 形状为 (batch, batchmax_seq_len),所以可能存在不同batch,这个句子长度是不一样的。
    # 所以为什么 position_emdding 里面有个 seq_max_len = 5000,然后取其前面部分的,因为每个batch可能句子长度不一样,所以位置编码要随时适应句子长度。
    return pad_enc_x, pad_dec_x

Part4 测试-训练

'''
# Part4 测试,真正开始训练
'''

if __name__=='__main__':
    dataset = DeEnDataset()
    dataloader = DataLoader(dataset, batch_size=200, shuffle=True, collate_fn=collate_fn)

    # 尝试看看有没有现有模型,如果有现有模型就加载进行后训练,反之则创建一个进行重新训练,这里主要是用于前向传播
    try:
        transformer = torch.load('checkpoints/model.pth')
    except:
        transformer = Transformer(
            de_vocab_size=len(de_vocab),
            en_vocab_size=len(en_vocab),
            emd_size=512,
            head=8,
            q_k_size=64,
            v_size=64,
            f_size=2048,
            nums_encoder_block=6,
            nums_decoder_block=6,
            dropout=0.1,
            seq_max_len=SEQ_MAX_LEN
        )

    # 初始化损失
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    # 初始化优化器,反向传播更新参数用
    optimizer = torch.optim.SGD(transformer.parameters(), lr=1e-1, momentum=0.99)

    # 开始训练
    transformer.train()
    EPOCHS = 300
    for epoch in range(EPOCHS):
        batch_i = 0
        loss_sum = 0
        for pad_enc_x, pad_dec_x in dataloader:
            # 这里一个去掉第一个词,一个去掉最后一个词,有说法的(所以相当于预测每下个词的概率)
            real_dec_z = pad_dec_x[:, 1:]  # decoder正确输出
            pad_enc_x = pad_enc_x
            pad_dec_x = pad_dec_x[:, :-1]  # decoder实际输入
            dec_z = transformer(pad_enc_x, pad_dec_x)  # decoder实际输出

            batch_i += 1
            print(dec_z.size())
            loss = loss_fn(dec_z.reshape(-1, dec_z.size()[-1]), real_dec_z.reshape(-1))  # 把整个batch中的所有词拉平
            loss_sum += loss.item()
            print('epoch:{} batch:{} loss:{}'.format(epoch, batch_i, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        torch.save(transformer, 'checkpoints/model.pth'.format(epoch))

参考

视频讲解:transformer-带位置信息的词嵌入向量_哔哩哔哩_bilibili

github代码库:github.com

标签:__,Transformer,enc,--,TransformerTrain,de,batch,pad,dec
From: https://blog.csdn.net/m0_62030579/article/details/145112549

相关文章

  • ESP32模拟IIC,0.96英寸OLED(四针),改编自江科大/江协。
    #喜欢的宝子可以copy。#本文采用ArduinoIDE开发,用到了多文件形式。点击这里可以创建文件。下面直接分享代码和效果图。效果图0.96OLED.ino #include"OLED.h"voidsetup(){OLED_Init();OLED_ShowString(1,1,"sugkug");}intcnt=0;voidloop()......
  • nginx 简单实践:静态资源部署、URL 重写【nginx 实践系列之一】
    〇、前言本文为nginx简单实践系列文章之一,主要简单实践了两个内容:静态资源部署、重写,仅供参考。关于Nginx基础,以及安装和配置详解,可以参考博主过往文章:https://www.cnblogs.com/hnzhengfy/p/Nginx.html 一、静态资源部署当前项目的结构基本上都是前后端分离,前端的相关资......
  • 微软正式开源超强小模型Phi-4 性能测试超越GPT-4o、Llama-3.1
    微软近期在HuggingFace平台上发布了名为Phi-4的小型语言模型,这款模型的参数量仅为140亿,但在多项性能测试中表现出色,超越了众多知名模型,包括OpenAI的GPT-4o及其他同类开源模型如Qwen2.5和Llama-3.1。在之前的在美国数学竞赛AMC的测试中,Phi-4获得了91.8分,显著优......
  • 软件架构中的CS架构和BS架构
    Client/Server        Client/Server,即客户端/服务器架构,是一种典型的两层架构,在计算机网络和软件开发领域有着广泛的应用。    主要特点        -专用客户端应用程序                -C/S架构的客户端通常是安装在用户设备上的......
  • ​Stability AI 推出 SPAR3D:单图像生成 3D 对象一秒钟搞定
    在刚刚结束的CES展会上,StabilityAI宣布推出一种名为SPAR3D(StablePointAware3D)的创新方法,这种两阶段的3D生成技术能够在不到一秒的时间内,从单个图像中生成精确的3D对象。该技术的推出为游戏开发者、产品设计师和环境构建者提供了全新的3D原型设计方式。SPAR3D的......
  • Python异步编程在股票交易系统中的应用:如何减少延迟提升效率
    炒股自动化:申请官方API接口,散户也可以python炒股自动化(0),申请券商API接口python炒股自动化(1),量化交易接口区别Python炒股自动化(2):获取股票实时数据和历史数据Python炒股自动化(3):分析取回的实时数据和历史数据Python炒股自动化(4):通过接口向交易所发送订单Python炒股自动化(5):......
  • 2025毕设python银行账户管理系统程序+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、选题背景关于银行账户管理系统的研究,现有研究主要以大型银行或通用银行系统为主,专门针对python实现的银行账户管理系统的研究较少。因此本选题......
  • linux常用命令(2)[常用快捷键, clear, date, ping, ps, kill, man, help, info ]
    常用快捷键清空命令行界面  clear示例显示时间 datedate和date-R命令的区别如下"CST"表示"ChinaStandardTime",即中国标准时间中国标准时间是协调世界时(UTC)的东八区时间,也就是UTC+8:00网络测试命令  ping用于测试主机之间网络的连通性上面的截图......
  • 爬虫基础之爬取歌曲宝歌曲批量下载
    声明:本案列仅供学习交流使用任何用于非法用途均与本作者无关需求分析:网站:邓紫棋-mp3在线免费下载-歌曲宝-找歌就用歌曲宝-MP3音乐高品质在线免费下载(gequbao.com)     爬取歌曲名 歌曲实现歌手名称下载所有歌曲 本案列所使用的模块requests(发送HTTP......
  • 别再“硬扛”了!稳定性保障主导权切换硬核指南:运维 or QA,何时“换帅”才能止损?
    相信不少朋友都有过这样的经历:线上告警突如其来,团队成员立刻紧张起来,争分夺秒地排查问题、快速止损。在稳定性保障这条道路上,谁来主导,至关重要。我曾身处美团金融团队,深知在应对大流量冲击、快速止损方面的运维主导模式的威力。那种对系统运行状态的精准把握,对预案执行的果断高效......