首页 > 其他分享 >Transformers包使用记录

Transformers包使用记录

时间:2023-08-19 14:00:25浏览次数:49  
标签:Transformers tokenizer 权重 记录 模型 base 使用 output model

  Transformers是著名的深度学习预训练模型集成库,包含NLP模型最多,CV等其他领域也有,支持预训练模型的快速使用和魔改,并且模型可以快速在不同的深度学习框架间(Pytorch/Tensorflow/Jax)无缝转移。以下记录基于HuggingFace官网教程:https://github.com/huggingface/transformers/blob/main/README_zh-hans.md

任务调用

  直接使用两行代码实现各种任务,以下举例一个情感分析任务:

from transformers import pipeline
# 使用情绪分析流水线
classifier = pipeline('sentiment-analysis', 'distilbert-base-uncased-finetuned-sst-2-english')
classifier('We are very happy to introduce pipeline to the transformers repository.')

  pipeline第一个参数传入实现任务类型,第二个参数传入预训练模型权重名。模型预训练权重名中,distilbert-base表示使用模型蒸馏训练的base bert;uncased表示模型权重无法区分大小写,数据在传入前需要小写处理;finetuned-sst-2-english表示模型权重在英文Stanford Sentiment Treebank 2数据集上进行微调。如果权重名能在当前工作目录中找到,就读取当前工作目录的文件,否则就会去HuggingFace官网下载相应的Repository。如果自动下载失败,distilbert-base-uncased-finetuned-sst-2-english的模型权重和配置文件可以通过以下方式下载:

git lfs install
git clone https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english

  下载下来一个文件夹,其中包含模型结构文件 config.json、模型权重文件 model.safetensors、分词器配置文件 tokenizer_config.json、词表文件 vocab.txt等。文件夹中有时会包含文件分词器文件 tokenizer.json,其中保存了分词到id的映射。tokenizer.json的映射与vocab.txt正好相反,因此没有tokenizer.json照样可以运行。但是除了映射之外,tokenizer.json通常还会保存一些额外的关于特殊token或是未登录词的词频信息,是会影响模型结果的。

  如果通过git模型权重下载失败,可以直接进网站下载单个权重文件并放入文件夹。其中后缀为h5、weights、ckpt、pth、safetensors、bin的文件都是模型权重。比如pth是pytorch常用的权重后缀,h5是Tensorflow的常用的权重后缀。具体保存的格式不细究,只要任意下载一个就行。Transformers默认使用Pytorch,因此通常下载pth、bin或safetensors。

  通过以上API和下载的Repository文件,可以看出Transformers把用到的预训练模型、配置文件、分词等都放在一个repository中,从而在使用时实现模型结构的自动构建以及配套预训练权重的读取,从而无需显式使用Pytorch写好与预训练权重配套的结构代码,加快预训练模型使用流程。

预训练模型调用

  如果要研究模型的推理,而不是实现具体任务。可以实现为以下代码:

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") #1
model = AutoModel.from_pretrained("bert-base-uncased") #2
inp = tokenizer("Hello world!", return_tensors="pt") #3
outp = model(**inp)

  其中#1表示读取bert-base-uncased的分词器,#2表示读取bert-base-uncased的预训练权重并构建模型。如果模型权重只下载了h5,而使用Pytorch作为后端,则需要给from_pretrained添加from_tf=True参数。#3使用分词器对输入句子进行分词,输出pytorch张量。如果设置return_tensors="tf"则分词器输出兼容tensorflow模型的张量,此时model应该使用TFAutoModel来实例化。

  如果要处理批量数据,可以给分词器传入文本列表,如:

texts = ["Hello world!", "Hello, how are you?"]
inp = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

  如果给分词器传入两段文本,分词器将它们合并,并额外生成句子类型id,用于句子顺序判别任务。第一句token标识为0,第二句token标识为1:

texts = ["Hello world!", "Hello, how are you?"]
inp = tokenizer(*texts, return_tensors="pt", padding=True, truncation=True)

自定义模型推理

  观察config.json,其中architectures字段定义了所需预训练权重所需使用的模型结构类,可以发现其它的各字段就是传入该模型结构类的参数,从而能实例化出与预训练模型权重一致的模型结构,然后再读取权重得到预训练模型。那么我们可以根据这些文件以及Transformers内置的模型结构类(继承自nn.Module),来自定义模型的数据通路。将前面的情感分类管道分解如下:

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch import nn

text = "We are very happy to introduce pipeline to the transformers repository."
model_head_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = DistilBertForSequenceClassification.from_pretrained(model_head_name).to('cuda')
tokenizer = DistilBertTokenizer.from_pretrained(model_head_name)
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to('cuda')

# 获取模型内 bert 主体的输出
distilbert_output = model.distilbert(**inputs)
# 使用 bert 输出的第一个token [CLS] 计算情感分类概率
hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0]  # (bs, dim)
pooled_output = model.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = model.dropout(pooled_output)  # (bs, dim)
logits = model.classifier(pooled_output)  # (bs, num_labels)
print("Positive rate: ", nn.Softmax(1)(logits)[0,1].detach().cpu().numpy())

 

标签:Transformers,tokenizer,权重,记录,模型,base,使用,output,model
From: https://www.cnblogs.com/qizhou/p/17640915.html

相关文章

  • 23.8.13米哈游秋招笔试题记录
    第一题签到题easy第二题//给出一颗有根树,树上有n个节点和n-1条边,边的距离为1.根节点编号为1.//根据上述构建出这棵有根树。//然后,进行任意次操作://操作内容:对于树的叶子节点添加一个叶子节点,新添加边长度也是1.//问经过操作以后,使得这棵树中所有节点与根节点的距离不......
  • 【LeetCode1384. 按年度列出销售总额】MySQL使用with recursive根据开始日期和结束日
    题目地址https://leetcode.cn/problems/total-sales-amount-by-year/description/代码WITHRECURSIVEDateSeriesAS(SELECTproduct_id,period_startASsale_date,period_end,average_daily_salesFROMSales--Assumingyourtablenameissales_dataUN......
  • 【LeetCode2199. 找到每篇文章的主题】字符串处理题,使用MySQL里的group_concat和LOCAT
    题目地址https://leetcode.cn/problems/finding-the-topic-of-each-post/description/代码witht1as(selectp.*,k.*fromPostspleftjoinKeywordskonLOCATE(LOWER(CONCAT('',word,'')),LOWER(CONCAT('',conte......
  • 【Freertos基础入门】队列(queue)的使用
    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档@TOC前言本系列基于stm32系列单片机来使用freerotsFreeRTOS是一个广泛使用的开源实时操作系统(RTOS),它提供了丰富的功能和特性,使嵌入式系统的开发更加简单和高效。队列是FreeRTOS中常用的一种通信机制,它用于在任务之间传......
  • 如何使用Git LFS下载大模型权重
    如何使用GitLFS下载大模型权重大语言模型的权重文件通常比较大,直接从浏览器中下载的话不太方便。我们可以使用GitLFS获得更好的下载体验。GitLFS(大文件存储)是Git的一个扩展,允许我们更高效地处理大文件。安装gitlfsinstall或者sudoapt-getinstallgit-lfs下载以清......
  • 开源.NetCore通用工具库Xmtool使用连载 - 加密解密篇
    【Github源码】《上一篇》详细介绍了Xmtool工具库中的正则表达式类库,今天我们继续为大家介绍其中的加密解密类库。在开发过程中我们经常会遇到需要对数据进行加密和解密的需求,例如密码的加密、接口传输数据的加密等;当前类库中只封装了Base64、AES两种加密解密方法,因为C#提供了几......
  • 【Maven】打包补充依赖的操作记录
    题外话每次搞maven环境,总是觉得很痛苦,痛苦的根源源于,无从下手。要说maven有多难,自然也不能这样说,究竟也是因为没有系统地去学习,和没有把踩过的坑积累成经验,以至于每一次都踩差不多的坑,浪费相当的时间,打击相当的信心,于是觉得这是一道铜墙铁壁。每每遇到这些环境问题,首先就觉得,......
  • 免费HTTP代理IP使用须知
       免费的HTTP代理IP可以用于一些基本的网络爬虫、数据采集、简单的网页浏览等业务。但是需要注意的是,由于免费的HTTP代理IP质量不稳定,可能会被封禁或者速度较慢,不适合一些对稳定性和速度要求较高的业务,例如在线视频播放、在线游戏等。对于一些需要高质量代理IP的业务,建议选......
  • 骚操作:使用RxJava实现ImageView的拖动、旋转和缩放
    本文介绍一种使用Rxjava实现图片交互操作的方法。支持单指拖动,双指旋转缩放,效果如下:自定义View首先自定义TrsImageView继承ImageView,设置ScaleType为Matrix,我们使用矩阵计算最终的translate,rotate和scale。publicclassTrsImageViewextendsImageView{publicTrsImageVi......
  • 【LeetCode1225. 报告系统状态的连续日期】MySQL使用lag,lead得到连续段的:开始标志,结束
    目录题目地址题目描述代码题目地址https://leetcode.cn/problems/report-contiguous-dates/description/题目描述Asystemisrunningonetaskeveryday.Everytaskisindependentoftheprevioustasks.Thetaskscanfailorsucceed.Writeasolution toreportth......