首页 > 其他分享 >昇思25天学习打卡营第11天|基于MindSpore的GPT2文本摘要

昇思25天学习打卡营第11天|基于MindSpore的GPT2文本摘要

时间:2024-07-09 19:54:56浏览次数:23  
标签:11 25 max dataset num steps import article 打卡

数据集

准备nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

数据需要预处理,如下

原始数据格式:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

预处理后的数据格式:
[CLS] article_context [SEP] summary_context [SEP]

代码示例:

# 下载依赖
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install mindnlp

from mindnlp.utils import http_get

# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()

# 按9:1比例拆分训练集、测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

import json
import numpy as np

# 数据集数据预处理
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

from mindnlp.transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())

模型构建

代码示例:

# 构建GPT2ForSummarization模型
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

# 动态学习率
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

模型训练

代码示例:

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度

trainer.run(tgt_columns="labels")

运行结果:
模型训练结果

模型推理

将向量数据变为中文数据。
代码示例:

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset

test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))

model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break

截图时间
截图时间

标签:11,25,max,dataset,num,steps,import,article,打卡
From: https://blog.csdn.net/leifchen90/article/details/140271225

相关文章

  • Android 11 recovery恢复出厂设置保留某些文件
    /bootable/recovery/recovery.cpprecovery的注释,流程解释!/**Therecoverytoolcommunicateswiththemainsystemthrough/cachefiles.*/cache/recovery/command-INPUT-commandlinefortool,oneargperline*/cache/recovery/log-OUTPUT-combin......
  • Lunaproxy与711Proxy的对比与优劣分析
    今天我们来深入对比两款在市场上备受关注的代理IP服务:Lunaproxy和711Proxy。接下来,我们将从多个角度对这两款服务进行详细分析,帮助大家做出明智的选择。优势分析711Proxy的优势1.性价比高:711Proxy提供多种灵活的套餐选择,价格经济实惠,性价比在所有代理服务商中数一数二,从小......
  • 打卡信奥刷题(276)用Scratch图形化工具信奥P1007[普及组/提高] 独木桥
    独木桥题目背景战争已经进入到紧要时间。你是运输小队长,正在率领运输部队向前线运送物资。运输任务像做题一样的无聊。你希望找些刺激,于是命令你的士兵们到前方的一座独木桥上欣赏风景,而你留在桥下欣赏士兵们。士兵们十分愤怒,因为这座独木桥十分狭窄,只能容纳......
  • Python酷库之旅-第三方库Pandas(011)
    目录一、用法精讲25、pandas.HDFStore.get函数25-1、语法25-2、参数25-3、功能25-4、返回值25-5、说明25-6、用法25-6-1、数据准备25-6-2、代码示例25-6-3、结果输出26、pandas.HDFStore.select函数26-1、语法26-2、参数26-3、功能26-4、返回值26-5、说明26-......
  • 昇思25天学习打卡营第25天|DCGAN生成漫画头像
    使用场景        DCGAN(深度卷积生成对抗网络)被广泛应用于生成图像数据的任务。在本教程中,我们将使用DCGAN生成漫画头像。通过这一教程,您可以了解如何搭建DCGAN网络,设置优化器,计算损失函数,以及初始化模型权重。原理        DCGAN是GAN(生成对抗网络)的扩展版本......
  • P3022 [USACO11OPEN] Odd degrees G
    P3022[USACO11OPEN]OdddegreesG构造每个连通块独立,考虑其中一个如何构造。因为无向图的度数一定是偶数,而每个点的度数是奇数,所以点数为奇数,否则无解。考虑建dfs树,不关心非树边,只考虑树边的取舍构造。自底向上构造,假如当前\(u\)的儿子\(v\)为偶数,那么就不能取\((u,v)......
  • 1154java jsp SSM古董拍卖网站系统(源码+文档+PPT+运行视频+讲解视频)
     项目技术:SSM+Maven+Vue等等组成,B/S模式+Maven管理等等。环境需要1.运行环境:最好是javajdk1.8,我们在这个平台上运行的。其他版本理论上也可以。2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA;3.tomcat环境:Tomcat7.x,8.x,9.x版本均可4.硬件环境:windows7/8/1......
  • Windows7、10、11官方原版ISO镜像下载
    整理好放到夸克网盘了,需要的自取https://pan.quark.cn/s/3efb93ba1306  版本发布日期文件名文件大小sha1Windows1022h2 商业版(含专业、专业工作站、教育、专业教育、企业版)2024年3月19日zh-cn_windows_10_business_editions_version_22h2_updated_mar......
  • c++ primer plus 第15章友,异常和其他:异常,15.3.5 异常规范和 C++11
    c++primerplus第15章友,异常和其他:异常,15.3.5异常规范和C++1115.3.5异常规范和C++11文章目录c++primerplus第15章友,异常和其他:异常,15.3.5异常规范和C++1115.3.5异常规范和C++1115.3.5异常规范和C++11有时候,一种理念看似有前途,但实际的使用效果并......
  • CSCI-GA.2250-001 Scheduler
    ProgrammingAssignment#2(Lab2):Scheduler/DispatcherClassCSCI-GA.2250-001Summ2024In  this  lab  we  explore  the  implementation  and  effects  of different  scheduling policies  discussed  in  class  on  a  set......