首页 > 其他分享 >69预训练BERT

69预训练BERT

时间:2022-08-17 22:55:28浏览次数:54  
标签:BERT 训练 devices mlm tokens encoded nsp 69 net

点击查看代码
import torch
from torch import nn
from d2l import torch as d2l

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)

net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],
                    ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,
                    num_layers=2, dropout=0.2, key_size=128, query_size=128,
                    value_size=128, hid_in_features=128, mlm_in_features=128,
                    nsp_in_features=128)
devices = d2l.try_all_gpus()
print(devices)
loss = nn.CrossEntropyLoss()

# 计算遮蔽语言模型和下一句子预测任务的损失
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):
    # 前向传播
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)
    # 计算遮蔽语言模型损失
    # * mlm_weights_X.reshape(-1, 1) 不去计算pad的loss
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
    # ?
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # 计算下一句子预测任务的损失
    nsp_l = loss(nsp_Y_hat, nsp_y)
    # BERT预训练的最终损失是遮蔽语言模型损失和下一句预测损失的和
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    # 当迭代次数或者epoch足够大的时候,我们通常会使用nn.DataParallel函数来用多个GPU来加速训练
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    # animator = d2l.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # 遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,mlm_weights_X, mlm_Y, nsp_y in train_iter:
            # 放到GPU
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])

            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            # animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')

# train_bert(train_iter, net, loss, len(vocab), devices, 20)


def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)
    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
    # print('token_ids', token_ids.device)
    # print('segments', segments.device)
    # print('valid_len', valid_len.device)
    net = net.to(device=devices[0])
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    # 返回tokens_a和tokens_b中所有词元的BERT(net)表示
    return encoded_X


tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
"""encoded_text.shape :  torch.Size([1, 6, 128])"""
"""encoded_text_cls.shape :  torch.Size([1, 128])"""
"""encoded_text_crane[0][:3] :  tensor([-0.1122,  0.1724, -1.8077], device='cuda:0', grad_fn=<SliceBackward0>)"""

tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just','left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
"""encoded_pair.shape :  torch.Size([1, 10, 128])"""
"""encoded_pair_cls.shape :  torch.Size([1, 128])"""
"""encoded_text_crane[0][:3] :  tensor([ 0.3801,  0.4826, -1.7688], device='cuda:0', grad_fn=<SliceBackward0>)"""

标签:BERT,训练,devices,mlm,tokens,encoded,nsp,69,net
From: https://www.cnblogs.com/g932150283/p/16597099.html

相关文章

  • Codeforces1698F Equal Reversal【构造】
    分析:注意到你无论如何都无法改变a[1]的值,而你要改变a[2]的值时,你就必须要选择一个和a[1]相同的值,然后翻转这一段区间。又可以发现,任意两个数的相邻情况是不会改变的。比......
  • 《GB18469-2012》PDF下载
    《GB18469-2012全血及成分血质量要求》PDF下载《GB18469-2012》简介本标准规定了一般血站提供和临床输注用全血及成分血的质量要求;本标准适用于一般血站提供和临床输......
  • leetcode690-员工的重要性
    员工的重要性dfsclassSolution{Map<Integer,Employee>map=newHashMap<>();publicintgetImportance(List<Employee>employees,intid){......
  • Codeforces1699E Three Days Grace【数学】【DP】
    分析:一开始觉得是二分答案,发现行不通之后改为枚举最小值。现在我将这若干个数分解,假设分解完之后得到的最小值为$i$,那么我就是要在最小值为$i$的基础上尽量最小化分解的......
  • 《GB6944-2012》PDF下载
    《GB6944-2012危险货物分类和品名编号》PDF下载《GB6944-2012》简介本标准规定了危险货物品名表的一般要求、结构和危险货物品名表。本标准适用于危险货物运输、储存......
  • autodl-训练HGNN
    报错情况: 一开始以为是yaml版本不对,后来从代码处入手:参考:(92条消息)[报错]yaml.constructor.ConstructorError:couldnotdetermineaconstructorforthetag‘ta......
  • LeetCode 169 Majority Element
    Givenanarraynumsofsizen,returnthemajorityelement.Themajorityelementistheelementthatappearsmorethan⌊n/2⌋times.Youmayassumethatthe......
  • "蔚来杯"2022牛客暑期多校训练营9 G Magic Spells
    原题链接一开始manacher+单哈希wa,样例通过率97%,应该是卡了一手int_64自然溢出换成manacher+双哈希过了#include<bits/stdc++.h>usingnamespacestd;#definefr......
  • Bert bert-base-uncased 模型加载
    1、下载模型相关文件到本地路径https://huggingface.co/bert-base-uncased/tree/main2、修改模型加载,注释为修改前......
  • UPC2022暑期个人训练赛第36场
    多谢两位大佬的帮助,才能勉强完成几个题,这几个题还是挺有意思的问题A:WJ的逃离DFS超时,所以考虑BFS,记得上次炸僵尸也是这个教训,这次忘记了感谢sgjen大佬提供的帮......