首页 > 编程语言 >[Python急救站]基于Transformer Models模型完成GPT2的学生AIGC学习训练模型

[Python急救站]基于Transformer Models模型完成GPT2的学生AIGC学习训练模型

时间:2024-04-29 18:56:02浏览次数:16  
标签:急救站 __ tokenizer GPT2 模型 token input model self

为了AIGC的学习,我做了一个基于Transformer Models模型完成GPT2的学生AIGC学习训练模型,指在训练模型中学习编程AI。

在编程之前需要准备一些文件:

首先,先win+R打开运行框,输入:PowerShell后

输入:

pip install -U huggingface_hub

下载完成后,指定我们的环境变量:

$env:HF_ENDPOINT = "https://hf-mirror.com"

然后下载模型:

huggingface-cli download --resume-download gpt2 --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2"

这边我的目录是我要下载的工程目录地址

然后下载数据量:

huggingface-cli download --repo-type dataset --resume-download wikitext --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2"

这边我的目录是我要下载的工程目录地址

所以两个地址记得更改成自己的工程目录下(建议放在创建一个名为gpt-2的文件夹)

在PowerShell中下载完这些后,可以开始我们的代码啦

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AdamW,
    get_linear_schedule_with_warmup,
    set_seed,
)
from torch.optim import AdamW

# 设置随机种子以确保结果可复现
set_seed(42)


class TextDataset(Dataset):
    def __init__(self, tokenizer, texts, block_size=128):
        self.tokenizer = tokenizer
        self.examples = [
            self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=block_size) for
            text
            in texts]
        # 在tokenizer初始化后,确保unk_token已设置
        print(f"Tokenizer's unk_token: {self.tokenizer.unk_token}, unk_token_id: {self.tokenizer.unk_token_id}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        item = self.examples[i]
        # 替换所有不在vocab中的token为unk_token_id
        for key in item.keys():
            item[key] = torch.where(item[key] >= self.tokenizer.vocab_size, self.tokenizer.unk_token_id, item[key])
        return item


def train(model, dataloader, optimizer, scheduler, de, tokenizer):
    model.train()
    for batch in dataloader:
        input_ids = batch['input_ids'].to(de)
        # 添加日志输出检查input_ids
        if torch.any(input_ids >= model.config.vocab_size):
            print("Warning: Some input IDs are outside the model's vocabulary.")
            print(f"Max input ID: {input_ids.max()}, Vocabulary Size: {model.config.vocab_size}")

        attention_mask = batch['attention_mask'].to(de)
        labels = input_ids.clone()
        labels[labels[:, :] == tokenizer.pad_token_id] = -100

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()


def main():
    local_model_path = "D:/Pythonxiangmu/PythonandAI/Transformer Models/gpt-2"
    tokenizer = AutoTokenizer.from_pretrained(local_model_path)

    # 确保pad_token已经存在于tokenizer中,对于GPT-2,它通常自带pad_token
    if tokenizer.pad_token is None:
        special_tokens_dict = {'pad_token': '[PAD]'}
        tokenizer.add_special_tokens(special_tokens_dict)
        model = AutoModelForCausalLM.from_pretrained(local_model_path, pad_token_id=tokenizer.pad_token_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(local_model_path)

    model.to(device)

    train_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "In the midst of chaos, there is also opportunity.",
        "To be or not to be, that is the question.",
        "Artificial intelligence will reshape our future.",
        "Every day is a new opportunity to learn something.",
        "Python programming enhances problem-solving skills.",
        "The night sky sparkles with countless stars.",
        "Music is the universal language of mankind.",
        "Exploring the depths of the ocean reveals hidden wonders.",
        "A healthy mind resides in a healthy body.",
        "Sustainability is key for our planet's survival.",
        "Laughter is the shortest distance between two people.",
        "Virtual reality opens doors to immersive experiences.",
        "The early morning sun brings hope and vitality.",
        "Books are portals to different worlds and minds.",
        "Innovation distinguishes between a leader and a follower.",
        "Nature's beauty can be found in the simplest things.",
        "Continuous learning fuels personal growth.",
        "The internet connects the world like never before."
        # 更多训练文本...
    ]

    dataset = TextDataset(tokenizer, train_texts, block_size=128)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=5e-5)
    total_steps = len(dataloader) * 5  # 假设训练5个epoch
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    for epoch in range(5):  # 训练5个epoch
        train(model, dataloader, optimizer, scheduler, device, tokenizer)  # 使用正确的变量名dataloader并传递tokenizer

    # 保存微调后的模型
    model.save_pretrained("path/to/save/fine-tuned_model")
    tokenizer.save_pretrained("path/to/save/fine-tuned_tokenizer")


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main()

这个代码只训练了5个epoch,有一些实例文本,记得调成直接的路径后,运行即可啦。

如果有什么问题可以随时在评论区或者是发个人邮箱:[email protected]

标签:急救站,__,tokenizer,GPT2,模型,token,input,model,self
From: https://www.cnblogs.com/python120/p/18166491

相关文章

  • OSI七层模型
    OSI(OpenSystemsInterconnection)模型是计算机网络体系结构的一种标准化框架,由国际标准化组织(ISO)制定,用于定义计算机网络通信的不同层次和功能。OSI模型将网络通信分解为七个抽象的层次,每个层次都有其特定的功能和责任,通过层次间的交互和协作,实现了网络通信的可靠性、安全性和高效......
  • 一分钟部署 Llama3 中文大模型,没别的,就是快
    前段时间百度创始人李彦宏信誓旦旦地说开源大模型会越来越落后,闭源模型会持续领先。随后小扎同学就给了他当头一棒,向他展示了什么叫做顶级开源大模型。美国当地时间4月18日,Meta在官网上发布了两款开源大模型,参数分别达到80亿(8B)和700亿(70B),是目前同体量下性能最好的开......
  • python使用langchain调用本地大模型
    参考https://www.cnblogs.com/scarecrow-blog/p/17875127.html模型下载之前说过一次https://www.cnblogs.com/qcy-blog/p/18165717也可直接去官网,把所有文件都点一遍fromlangchainimportPromptTemplate,LLMChainimporttorchfromtransformersimportAutoTokenizer,A......
  • PYTHON 用几何布朗运动模型和蒙特卡罗MONTE CARLO随机过程模拟股票价格可视化分析耐克
    原文链接:http://tecdat.cn/?p=27099最近我们被客户要求撰写关于蒙特卡罗的研究报告,包括一些图形和统计输出。金融资产/证券已使用多种技术进行建模。该项目的主要目标是使用几何布朗运动模型和蒙特卡罗模拟来模拟股票价格。该模型基于受乘性噪声影响的随机(与确定性相反)变量该项......
  • PYTHON用时变马尔可夫区制转换(MARKOV REGIME SWITCHING)自回归模型分析经济时间序列|附
    全文下载链接:http://tecdat.cn/?p=22617最近我们被客户要求撰写关于MRS的研究报告,包括一些图形和统计输出。本文提供了一个在统计模型中使用马可夫转换模型模型的例子,来复现Kim和Nelson(1999)中提出的一些结果。它应用了Hamilton(1989)的滤波器和Kim(1994)的平滑器  %matplot......
  • python大模型下载HuggingFace的镜像hf-mirror
    hf-mirror.com的包如何下载pipinstall-Uhuggingface_hub设置环境变量以使用镜像站:exportHF_ENDPOINT=https://hf-mirror.com对于WindowsPowershell,使用:$env:HF_ENDPOINT="https://hf-mirror.com"使用huggingface-cli下载模型:huggingface-clidownload--resum......
  • 带你开发一个视频动态手势识别模型
    本文分享自华为云社区《CNN-VIT视频动态手势识别【玩转华为云】》,作者:HouYanSong。CNN-VIT视频动态手势识别人工智能的发展日新月异,也深刻的影响到人机交互领域的发展。手势动作作为一种自然、快捷的交互方式,在智能驾驶、虚拟现实等领域有着广泛的应用。手势识别的任务是,当......
  • 倾斜摄影三维模型数据在模型调色应用分析
    倾斜摄影三维模型数据在模型调色应用分析 倾斜摄影三维模型数据是一种通过倾斜摄影技术获取的具有高精度的地表三维模型数据。它可以提供丰富的地理和地形信息,广泛应用于城市规划、土地管理、环境保护等领域。在模型调色应用分析中,倾斜摄影三维模型数据可以发挥重要的作用......
  • EPAI手绘建模APP资源管理和模型编辑器2
    g)矩形  图26模型编辑器-矩形i.修改矩形的中心位置。ii.修改矩形的长度和宽度。h)正多边形图27模型编辑器-内接正多边形图28模型编辑器-外切正多边形i.修改正多边形的中心位置。ii.修改正多边形中心距离端点的长度。iii.修改正多边形的阶数。阶数为3,表示......
  • ThinkPHP6 多模型关联查询操作记录
    新入职后组长安排了一个小的管理项目来检验能力,后发现自身对于ThinkPHP框架中的模型关联属于一窍不通,故被终止项目叫楼主去恶补ThinkPHP6框架知识。对于多联表查询之前本人一直使用join方法,但是此方法对于代码效率和维护都有较大影响,故在此尝试使用ThinkPHP框架内置的模型......