首页 > 其他分享 >深度学习之Transformer网络

深度学习之Transformer网络

时间:2022-12-27 21:34:25浏览次数:39  
标签:dim Transformer self 网络 np shape embedding 深度 tf

【博主使用的python版本:3.6.8】


本次没有额外的资料下载

Packages

ort tensorflow as tf
import pandas as pd
import time
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.layers import Embedding, MultiHeadAttention, Dense, Input, Dropout, LayerNormalization
from transformers import DistilBertTokenizerFast #, TFDistilBertModel
from transformers import TFDistilBertForTokenClassification
from tqdm import tqdm_notebook as tqdm

1 - 位置编码

在顺序到序列任务中,数据的相对顺序对其含义非常重要。当你训练顺序神经网络(如RNN)时,你按顺序将输入输入到网络中。有关数据顺序的信息会自动输入到模型中。但是,在训练转换器网络时,会一次性将数据全部输入到模型中。虽然这大大减少了训练时间,但没有关于数据顺序的信息。这就是位置编码有用的地方 - 您可以专门编码输入的位置,并使用以下正弦和余弦公式将它们传递到网络中:

  • d是词嵌入和位置编码的维度
  • pos是单词的位置。
  • i指位置编码的每个不同维度。

正弦和余弦方程的值足够小(介于 -1 和 1 之间),因此当您将位置编码添加到单词嵌入时,单词嵌入不会明显失真。位置编码和单词嵌入的总和最终是输入到模型中的内容。结合使用这两个方程有助于变压器网络关注输入数据的相对位置。请注意,虽然在讲座中,Andrew 使用垂直向量,但在此作业中,所有向量都是水平的。所有矩阵乘法都应相应调整。

1.1 - 正弦角和余弦角

通过计算正弦和余弦方程的内项,获取用于计算位置编码的可能角度:

 

练习 1 - get_angles

实现函数 get_angles() 来计算正弦和余弦位置编码的可能角度

def get_angles(pos, i, d):
    """
    获取位置编码的角度
    
    Arguments:
        pos -- 包含位置的列向量[[0], [1], ...,[N-1]]
        i --   包含维度跨度的行向量 [[0, 1, 2, ..., M-1]]
        d(integer) -- 编码大小
    
    Returns:
        angles -- (pos, d) 数组
    """
    
    angles = pos/ (np.power(10000, (2 * (i//2)) / np.float32(d)))
    
    
    return angles

我们测试一下:

def get_angles_test(target):
    position = 4
    d_model = 16
    pos_m = np.arange(position)[:, np.newaxis]
    dims = np.arange(d_model)[np.newaxis, :]

    result = target(pos_m, dims, d_model)

    assert type(result) == np.ndarray, "你必须返回一系列数组集合"
    assert result.shape == (position, d_model), f"防止错误我们希望: ({position}, {d_model})"
    assert np.sum(result[0, :]) == 0
    assert np.isclose(np.sum(result[:, 0]), position * (position - 1) / 2)
    even_cols =  result[:, 0::2]
    odd_cols = result[:,  1::2]
    assert np.all(even_cols == odd_cols), "奇数列和偶数列的子矩阵必须相等"
    limit = (position - 1) / np.power(10000,14.0/16.0)
    assert np.isclose(result[position - 1, d_model -1], limit ), f"组后的值必须是 {limit}"

    print("\033[92mAll tests passed")

get_angles_test(get_angles)

# 例如
position = 4
d_model = 8
pos_m = np.arange(position)[:, np.newaxis]
dims = np.arange(d_model)[np.newaxis, :]
get_angles(pos_m, dims, d_model)
All tests passed
Out[9]:
array([[0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00],
       [1.e+00, 1.e+00, 1.e-01, 1.e-01, 1.e-02, 1.e-02, 1.e-03, 1.e-03],
       [2.e+00, 2.e+00, 2.e-01, 2.e-01, 2.e-02, 2.e-02, 2.e-03, 2.e-03],
       [3.e+00, 3.e+00, 3.e-01, 3.e-01, 3.e-02, 3.e-02, 3.e-03, 3.e-03]])

1.2 - 正弦和余弦位置编码

现在,您可以使用计算的角度来计算正弦和余弦位置编码。

 

练习 2 - 位置编码

实现函数 positional_encoding() 来计算正弦和余弦位置编码

  • np.newaxis 有用,具体取决于您选择的实现。就是将矩阵升维
def positional_encoding(positions, d):
    """
    预先计算包含所有位置编码的矩阵
    
    Arguments:
        positions (int) -- 要编码的最大位置数
        d (int) --编码大小 
    
    Returns:
        pos_encoding -- (1, position, d_model)具有位置编码的矩阵
    """
    # 初始化所有角度angle_rads矩阵
    angle_rads = get_angles(np.arange(positions)[:, np.newaxis],
                            np.arange(d)[ np.newaxis,:],
                            d)
  
    # -> angle_rads has dim (positions,d)
    # 将 sin 应用于数组中的偶数索引;2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
  
    # a将 cos 应用于数组中的偶数索引;2i; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    # END CODE HERE
    
    pos_encoding = angle_rads[np.newaxis, ...]
    
    return tf.cast(pos_encoding, dtype=tf.float32)

我们来测试一下:

def positional_encoding_test(target):
    position = 8
    d_model = 16

    pos_encoding = target(position, d_model)
    sin_part = pos_encoding[:, :, 0::2]
    cos_part = pos_encoding[:, :, 1::2]

    assert tf.is_tensor(pos_encoding), "输出不是一个张量"
    assert pos_encoding.shape == (1, position, d_model), f"防止错误,我们希望: (1, {position}, {d_model})"

    ones = sin_part ** 2  +  cos_part ** 2
    assert np.allclose(ones, np.ones((1, position, d_model // 2))), "平方和一定等于1 = sin(a)**2 + cos(a)**2"
    
    angs = np.arctan(sin_part / cos_part)
    angs[angs < 0] += np.pi
    angs[sin_part.numpy() < 0] += np.pi
    angs = angs % (2 * np.pi)
    
    pos_m = np.arange(position)[:, np.newaxis]
    dims = np.arange(d_model)[np.newaxis, :]

    trueAngs = get_angles(pos_m, dims, d_model)[:, 0::2] % (2 * np.pi)
    
    assert np.allclose(angs[0], trueAngs), "您是否分别将 sin 和 cos 应用于偶数和奇数部分?"
 
    print("\033[92mAll tests passed")

    
positional_encoding_test(positional_encoding)
All tests passed
计算位置编码的工作很好!现在,您可以可视化它们。
pos_encoding = positional_encoding(50, 512)

print (pos_encoding.shape)

plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('d')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()
(1, 50, 512)

 

 每一行代表一个位置编码 - 请注意,没有一行是相同的!您已为每个单词创建了唯一的位置编码。

2 - 掩码

构建transformer网络时,有两种类型的掩码很有用:填充掩码和前瞻掩码。两者都有助于softmax计算为输入句子中的单词提供适当的权重。

2.1 - 填充掩码

通常,输入序列会超过网络可以处理的序列的最大长度。假设模型的最大长度为 5,则按以下序列馈送:

[["Do", "you", "know", "when", "Jane", "is", "going", "to", "visit", "Africa"], 
 ["Jane", "visits", "Africa", "in", "September" ],
 ["Exciting", "!"]
]
可能会被矢量化为:
[[ 71, 121, 4, 56, 99, 2344, 345, 1284, 15],
 [ 56, 1285, 15, 181, 545],
 [ 87, 600]
]
将序列传递到转换器模型中时,它们必须具有统一的长度。您可以通过用零填充序列并截断超过模型最大长度的句子来实现此目的:
[[ 71, 121, 4, 56, 99],
 [ 2344, 345, 1284, 15, 0],
 [ 56, 1285, 15, 181, 545],
 [ 87, 600, 0, 0, 0],
]
长度超过最大长度 5 的序列将被截断,零将被添加到截断的序列中以实现一致的长度。同样,对于短于最大长度的序列,它们也将添加零以进行填充。
但是,这些零会影响softmax计算 - 这是填充掩码派上用场的时候!通过将填充掩码乘以 -1e9 并将其添加到序列中,
您可以通过将零设置为接近负无穷大来屏蔽零。我们将为您实现这一点,以便您可以获得构建transformer网络的乐趣!

标签:dim,Transformer,self,网络,np,shape,embedding,深度,tf
From: https://www.cnblogs.com/kk-style/p/17009045.html

相关文章

  • 计算机网络复习——概要
    第一章概述什么是协议和体系结构?了解网络应用的两种模型:C/S和P2P模型什么是资源子网和通信子网?各种网络设备(转发器、集线器、网桥、路由器等)所工作的层次和基本特性......
  • 网络程序设计 实验5 图形化Ping工具
    实验5图形化Ping工具实验目的:用图形界面实现Ping操作。开发语言与工具:VC实验要求:1.使用MFC编程。2.界面上有目标地址栏,信息框和ping按钮。3.使用原始套接字实......
  • 如何写一个深度学习编译器
    编译器本质上是一种提高开发效率的工具,将高级语言转换为低级语言(通常是二进制机器码),使得程序员不需要徒手写二进制。转换过程中,首要任务是保证正确性,同时需要进行优化以提......
  • 一个核心交换机如何安全隔离两个网络?
    网络描述:客户内部电脑分为内网和互联网;由于预算有限,只有一台核心交换机,分别连接互联网和集团内网;客户想要内网交换机的终端只能访问集团内网,接互联网的交换机只访问互联网,看......
  • JVM内存溢出深度分析
    今天,发现游戏逻辑服务器内存溢出问题,每隔一定时间就生成java_pidxxxxxx.hprof,基本1G内存分配不够用了,导致FGC频繁发生。工具:MAT ​​EclipseMemoryAnalyzerTool(MAT)分......
  • Linux网络流量实时监控工具-ifstat
    介绍ifstat工具是个网络接口监测工具,比较简单看网络流量ifstat的安装使用:wget ​​http://distfiles.macports.org/ifstat/ifstat-1.1.tar.gz​​​tarxzvfifstat-1.1.......
  • Kali Linux三种网络攻击方法总结(DDoS、CC和ARP欺骗)
    本文章使用的是KaliLinux的2020-4-installer-amd64版本,其他版本是否兼容我会尽快测试,如果你想更快知道你所用的版本是否兼容,可以在下面留言,我会在看到信息后的第一时间回......
  • TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络
    TensorFlow高阶API:keras教程-使用tf.keras搭建mnist手写数字识别网络目录​​TensorFlow高阶API:keras教程-使用tf.keras搭建mnist手写数字识别网络​​​​1、Keras​​​......
  • 网络监测工具之Zabbix的搭建与测试方法(三) ---Zabbix Agent
    安装客户端在官方网站下载最新版本zabbixagentv6.2.6,然后默认安装,其中配置服务端的界面如下图: 其他一律默认即可。启用发现功能 如上图所示,开启发现规则,默认搜索......
  • 第一期预告|基于深度学习的物体抓取位置估计
    3D视觉工坊的各位小伙伴们,本周我们将迎来首次线上公开课,此次公开课是一次知识分享,希望更多的小伙伴能够加入我们。本周给大家先带来一场关于机械臂抓取的精彩课程。本周公开......