首页 > 其他分享 >深入探索智能未来:文本生成与问答模型的创新融合

深入探索智能未来:文本生成与问答模型的创新融合

时间:2023-08-17 14:56:03浏览次数:54  
标签:训练 -- 模型 MASK 融合 train 文本 问答

深入探索智能未来:文本生成与问答模型的创新融合

1.Filling Model with T5

1.1背景介绍

该项目用于将句子中 [MASK] 位置通过生成模型还原,以实现 UIE 信息抽取中 Mask Then Filling 数据增强策略

Mask Then Fill 是一种基于生成模型的信息抽取数据增强策略。对于一段文本,我们其分为「关键信息段」和「非关键信息段」,包含关键词片段称为「关键信息段」。下面例子中标粗的为 关键信息片段,其余的为 非关键片段

大年三十 我从 北京 的大兴机场 飞回成都

我们随机 [MASK] 住一部分「非关键片段」,使其变为:

大年三十 我从 北京 [MASK] 飞回成都

随后,将改句子喂给 filling 模型(T5-Fine Tuned)还原句子,得到新生成的句子:

大年三十 我从 北京 首都机场作为起点,飞回成都

  • 环境安装

本项目基于 pytorch + transformers 实现,运行前请安装相关依赖包:

pip install -r ../requirements.txt
  • 数据集准备

项目中提供了一部分示例数据,数据来自DuIE数据集中的文本数据,数据在 data/

若想使用 自定义数据 训练,只需要仿照示例数据构建带 [MASK] 的文本即可,你也可以使用 parse_data.py 快速生成基于 词粒度 的训练数据:

"Bortolaso Guillaume,法国籍[MASK]"中[MASK]位置的文本是:	运动员
"歌曲[MASK]是由歌手海生演唱的一首歌曲"中[MASK]位置的文本是:	《情一动心就痛》
...

每一行用 \t 分隔符分开,第一部分部分为 带[MASK]的文本,后一部分为 [MASK]位置的原始文本(label)

1.2. 模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \
    --pretrained_model "uer/t5-base-chinese-cluecorpussmall" \
    --save_dir "checkpoints/t5" \
    --train_path "data/train.tsv" \
    --dev_path "data/dev.tsv" \
    --img_log_dir "logs" \
    --img_log_name "T5-Base-Chinese" \
    --batch_size 128 \
    --max_source_seq_len 128 \
    --max_target_seq_len 32 \
    --learning_rate 1e-4 \
    --num_train_epochs 20 \
    --logging_steps 50 \
    --valid_steps 500 \
    --device cuda:0

正确开启训练后,终端会打印以下信息:

...
 0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 21.28it/s]
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 350134
    })
    dev: Dataset({
        features: ['text'],
        num_rows: 38904
    })
})
...
global step 2400, epoch: 1, loss: 7.44746, speed: 0.82 step/s
global step 2450, epoch: 1, loss: 7.42028, speed: 0.82 step/s
global step 2500, epoch: 1, loss: 7.39333, speed: 0.82 step/s
Evaluation bleu4: 0.00578
best BLEU-4 performence has been updated: 0.00026 --> 0.00578
global step 2550, epoch: 1, loss: 7.36620, speed: 0.81 step/s
...

logs/T5-Base-Chinese.png 文件中将会保存训练曲线图:

1.3 模型预测

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

 if __name__ == "__main__":
    masked_texts = [
        '"《μVision2单片机应用程序开发指南》是2005年2月[MASK]图书,作者是李宇"中[MASK]位置的文本是:'
    ]
    inference(masked_texts)
python inference.py

得到以下推理结果:

maksed text: 
[
    '"《μVision2单片机应用程序开发指南》是2005年2月[MASK]图书,作者是李宇"中[MASK]位置的文本是:'
]
output: 
[
    ',中国工业出版社出版的'
]

2.问答模型(Text-Generation, T5 Based)

2.1 背景介绍

问答模型是指通过输入一个「问题」和一段「文章」,输出「问题的答案」。

问答模型分为「抽取式」和「生成式」,抽取式问答可以使用 [UIE] 训练,这个实验中我们将使用「生成式」模型来训练一个问答模型。

我们选用「T5」作为 backbone,使用百度开源的「QA数据集」来训练得到一个生成式的问答模型。

  • 环境安装

本项目基于 pytorch + transformers 实现,运行前请安装相关依赖包:

pip install -r ../requirements.txt

2.2 数据集准备

项目中提供了一部分示例数据,数据是百度开源的问答数据集,数据在 data/DuReaderQG

若想使用自定义数据训练,只需要仿照示例数据构建数据集即可:

{"context": "违规分为:一般违规扣分、严重违规扣分、出售假冒商品违规扣分,淘宝网每年12月31日24:00点会对符合条件的扣分做清零处理,详情如下:|温馨提醒:由于出售假冒商品24≤N<48分,当年的24分不清零,所以会存在第一年和第二年的不同计分情况。", "answer": "12月31日24:00", "question": "淘宝扣分什么时候清零", "id": 203}
{"context": "生长速度 头发是毛发中生长最快的毛发,一般每天长0.27—0.4mm,每月平均生长约1.0cm,一年大概长10—14cm。但是,头发不可能无限制的生长,一般情况下,头发长至50—60cm,就会脱落再生新发。", "answer": "0.27—0.4mm", "question": "头发一天能长多少", "id": 328}
...

每一行为一个数据样本,json 格式。

其中,"context" 代表参考文章,question 代表问题,"answer" 代表问题答案。

2.3 模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \
    --pretrained_model "uer/t5-base-chinese-cluecorpussmall" \
    --save_dir "checkpoints/DuReaderQG" \
    --train_path "data/DuReaderQG/train.json" \
    --dev_path "data/DuReaderQG/dev.json" \
    --img_log_dir "logs/DuReaderQG" \
    --img_log_name "T5-Base-Chinese" \
    --batch_size 32 \
    --learning_rate 1e-4 \
    --max_source_seq_len 256 \
    --max_target_seq_len 32 \
    --learning_rate 5e-5 \
    --num_train_epochs 50 \
    --logging_steps 10 \
    --valid_steps 500 \
    --device "cuda:0"

正确开启训练后,终端会打印以下信息:

...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 650.73it/s]
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 14520
    })
    dev: Dataset({
        features: ['text'],
        num_rows: 984
    })

global step 10, epoch: 1, loss: 9.39613, speed: 1.60 step/s
global step 20, epoch: 1, loss: 9.39434, speed: 1.71 step/s
global step 30, epoch: 1, loss: 9.39222, speed: 1.72 step/s
global step 40, epoch: 1, loss: 9.38739, speed: 1.63 step/s
global step 50, epoch: 1, loss: 9.38296, speed: 1.63 step/s
global step 60, epoch: 1, loss: 9.37982, speed: 1.71 step/s
global step 70, epoch: 1, loss: 9.37385, speed: 1.71 step/s
global step 80, epoch: 1, loss: 9.36876, speed: 1.69 step/s
global step 90, epoch: 1, loss: 9.36209, speed: 1.72 step/s
global step 100, epoch: 1, loss: 9.35349, speed: 1.70 step/s
...

logs/DuReaderQG 文件下将会保存训练曲线图:

2.4 模型推理

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

...

if __name__ == '__main__':
    question = '治疗宫颈糜烂的最佳时间'
    context = '专家指出,宫颈糜烂治疗时间应选在月经干净后3-7日,因为治疗之后宫颈有一定的创面,如赶上月经期易发生感染。因此患者应在月经干净后3天尽快来医院治疗。同时应该注意,术前3天禁同房,有生殖道急性炎症者应治好后才可进行。'
    inference(qustion=question, context=context)

运行推理程序:

python inference.py

得到以下推理结果:

Q: "治疗宫颈糜烂的最佳时间"
C: "专家指出,宫颈糜烂治疗时间应选在月经干净后3-7日,因为治疗之后宫颈有一定的创面,如赶上月经期易发生感染。因此患者应在月经干净后3天尽快来医院治疗。同时应该注意,术前3天禁同房,有生殖道急性炎症者应治好后才可进行。"
A: "答案:月经干净后3-7日"

项目链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/answer_generation/readme.md

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

标签:训练,--,模型,MASK,融合,train,文本,问答
From: https://www.cnblogs.com/ting1/p/17637571.html

相关文章

  • 长文本拆分
    长文本拆分TL;DR企业微信消息长度限制为2048个字节,字符长度不等于字节长度使用字节拆分,会导致中文字符被截断使用文本+字节拆分,无法处理emoji表情使用unicode字符拆分,即可解决以上问题先前在做企业微信的应用接入ChatGPT时遇到一个问题,就是企业微信的消息长度限制为2048......
  • 搜文本搜位置搜图片,1小时玩转Elasticsearch
    加入Elasticsearch训练营,从全文检索到向量检索,搭建高频业务场景,构建进阶向量检索应用。带你拓展技术视野,晋升Elasticsearch搜索实战派。以下为训练营的参营指南,请您仔细阅读便于更顺利地进行训练营打卡。活动地址活动地址:<https://developer.aliyun.com/trainingcamp/53a2ca29e......
  • 使用css样式完成文本超出的部分用省略号代替
    <p>使用css样式完成文本超出的部分用省略号代替</p>第一步要设置宽度,然后使用text-overflow:ellipsis等属性组合使用li{width:140px;background-color:red;overflow:hidden;/*溢出隐藏*/white-s......
  • jquery内容文本值
       ......
  • RPA+智能问答实现微信端智能客服
    背景:由于业务发展迅速,服务的商家越来越多,目前我们售后团队都是通过企业微信群和客户进行沟通,平时客户的相关问题也是在企业微信中来讨论解决;但是我们售后团队资源有限,而且有的问题客户会重复问,周末或者晚上售后同学回复不及时影响体验;最重要的一点商家客服习惯于在微信端进行咨......
  • 在Typora中使用AutoHotkey 2.0实现使用快捷键设置文本颜色
    使用Typora时不能设置文本颜色,总是觉得不方便,于是在网上搜索,发现有个小工具:AutoHotkey,编写脚本后,通过快捷键的方式可以设置Typora的文本颜色。下载软件到https://www.autohotkey.com/这个网址下载AutoHotkey并安装脚本实现网上很多实现方式都是基于AutoHotkeyv1.0、v1.1的,Au......
  • Linux文本三剑客sed
    目录脚本格式sed即StreamEDitorsed是编辑器sed格式sed[选项]...{sed自己的脚本}{输入文件}...sed'脚本语言'sed自己的脚本语言脚本'地址'+'命令'脚本'命令'#没有地址就是全文选项:-n不输出模式空间内容到屏幕,即不自动打印-r使用扩展正则表达式......
  • 视频融合共享平台在“情指勤舆”一体化工作中的应用
    一、方案背景新时代背景下的公安情报指挥体系建设,需立足情报指挥中心,以“情报全面精准、指挥集成统一、勤务协同高效、舆情管控有力”为目标,以情指勤舆一体化指挥调度平台建设为抓手,把情报指挥中心打造成社会治安防控体系的网络核心,健全完善运行模式,理顺规范工作流程,做实做强平台支......
  • 文本转换图片
    unitUnit1;interfaceusesWindows,Messages,SysUtils,Variants,Classes,Graphics,Controls,Forms,Dialogs,StdCtrls,ExtCtrls;typeTForm1=class(TForm)Memo1:TMemo;{用于输入要保存的文本}ComboBox1:TComboBox;{字体}LabeledEdit1:TLabeledEdit;{字号}LabeledEdi......
  • 济南两化融合申报补助和条件是什么
    济南两化融合申报补助和条件是什么  恒标知产刘经理一、两化融合是指什么 定义: 两化融合是信息化和工业化的高层次的深度结合,是指以信息化带动工业化、以工业化促进信息化,走新型工业化道路;两化融合的核心就是信息化支撑,追求可持续发展模式。 两化融合是指电子信息技术广泛应......