首页 > 其他分享 >67自注意力和位置编码

67自注意力和位置编码

时间:2022-08-17 22:57:36浏览次数:64  
标签:hiddens 编码 encoding self torch num d2l 67 注意力

点击查看代码
import math
import torch
from torch import nn
from d2l import torch as d2l


# 自注意力
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()


#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    # num_hiddens 向量长度
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        print('self.P.shape', self.P.shape)
        # 所有batch, 所有numstep, 隔两列
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        # dropout是为了防止对位置编码太敏感
        return self.dropout(X)

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

# d2l.plt.show()


标签:hiddens,编码,encoding,self,torch,num,d2l,67,注意力
From: https://www.cnblogs.com/g932150283/p/16597081.html

相关文章

  • 68多头注意力
    点击查看代码importmathimporttorchfromtorchimportnnfromd2limporttorchasd2l#选择缩放点积注意力作为每一个注意力头#......
  • 64注意力汇聚:Nadaraya-Watson 核回归
    点击查看代码importtorchfromtorchimportnnfromd2limporttorchasd2l#生成数据集n_train=50#训练样本数x_train,_=torch.sort(torch.rand(n_trai......
  • python 中根据RNA序列输出密码子编码的氨基酸序列
     001、(base)root@PC1:/home/test4#lstest.py(base)root@PC1:/home/test4#cattest.py##测试程序#!/usr/bin/pythonrna="AUGGCCAUG......
  • 无法在 DLL“SQLite.Interop.dll”中找到名为“SI7fca2652f71267db”的入口点。
    首先,这个是在操作SQLite数据库,使用System.Data.SQLite包,需要这个文件SQLite.Interop.dll不然会报错在生成项目的时候需要确保有这两个文件夹(可以生成完手动复制,也可以放......
  • 【组成原理-数据】浮点数的编码与运算
    目录1浮点数的格式1.1符号(S)1.2阶码(E)1.3尾数(M)2IEEE754标准2.1短浮点数(float型)短浮点数的解释2.2长浮点数(double型)长浮点数的解释2.3相关例题3尾数的......
  • 【JAVA】URL编码对照表
    转载:https://blog.csdn.net/Danalee_Py/article/details/108083038?spm=1001.2101.3001.6661.1&utm_medium=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7EBlog......
  • P6733 「Wdsr-2」间歇泉——二分答案
    二分可以将优化问题转为判定问题,也可以将\(k\)大问题转为计数问题分析由于已知条件,\(\displaystyleT=\frac{a_ic_i+a_jc_j}{a_i+a_j}\),转为计数问题则是固定T,统计有多少......
  • C语言等长编码压缩和哈夫曼编码压缩
    C语言等长编码压缩和哈夫曼编码压缩利用哈夫曼算法对文件进行压缩及解压缩题目:选择一个英文纯文本文档(不少于3千字,也可以更多),分别利用等长编码和哈夫曼编码对其进行......
  • MySQL字段类型、字符编码与配置文件
    目录字符编码与配置文件存储引擎创建表的完整语法字段类型之整形字段类型之浮点型字段类型之字符类型字段后面的含义字段类型之枚举和集合字段类型之日期类型字段约束条件......
  • 【MySQL】第2回 字符编码和字段类型
    目录1.字符编码与配置文件1.1\S1.2my.ini2.数据库存储引擎2.1定义2.2需要掌握的存储引擎2.3不同存储引擎之间底层文件的区别3.创建表的完整语法4.MySQL字段类型4.......