Common deep learning interview code: MHA/MQA/GQA/LN/BN/positional encoding

Time: 2023-12-11 17:02:04
Tags: MQA head heads nn LN code model self size

Common deep learning code

1. MHA (MultiHeadAttention) implementation


# 1. MHA implementation
import math  # needed for math.sqrt in the attention score scaling

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleDotProductAttention(nn.Module):
    def __init__(self, ):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, Q, K, V, mask=None):
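        K_T = K.transpose(-1, -2) # transpose of K over its last two dimensions
        d_k = Q.size(-1)
        # 1. Dot product of Q and K^T, divided by sqrt(d_k), gives the attention score matrix
        scores = torch.matmul(Q, K_T) / math.sqrt(d_k)
        # 2. If a mask is given, set the scores at masked positions (mask == 0) to a very large negative value
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 3. Softmax over the last dimension turns the scores into attention weights in [0, 1]
        attn_weights = self.softmax(scores)
        # 4. Multiply the attention weights by V to get the final output
        output = torch.matmul(attn_weights, V)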

        return output, attn_weights
      
    
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention Layer
    Args:
        d_model: dimension of the input embedding vector; also the output dimension of the layer (each head works on d_model // n_head dimensions)
        n_head: number of heads, which is also the number of parallel attention layers; d_model must be divisible by it
    """
    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)  # linear projection for Q
        self.w_k = nn.Linear(d_model, d_model)  # linear projection for K
        self.w_v = nn.Linear(d_model, d_model)  # linear projection for V
        self.fc = nn.Linear(d_model, d_model)   # output linear projection

    def forward(self, q, k, v, mask=None):
        # 1. apply the linear projections (multiply by the weight matrices)
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v) # size is [batch_size, seq_len, d_model]
        # 2. split by number of heads (n_head); size is [batch_size, n_head, seq_len, d_model//n_head]
        q, k, v = self.split(q), self.split(k), self.split(v)
        # 3, compute attention
        sa_output, attn_weights = self.attention(q, k, v, mask)
        # 4, concat attention and linear transformation
        concat_tensor = self.concat(sa_output)
        mha_output = self.fc(concat_tensor)

        return mha_output

    def split(self, tensor):
        """
        split tensor by number of heads (n_head)

        :param tensor: [batch_size, seq_len, d_model]
        :return: [batch_size, n_head, seq_len, d_model//n_head]; the output is 4-D and its second dimension is the head dimension

        # The same idea written out for Q, K, V: reshape splits the last dimension into n_head heads
        # (the transpose(1, 2) in the code below then moves the head dimension to position 1)
        batch_size, seq_len, _ = q.shape
        q = q.reshape(batch_size, seq_len, n_head, d_model // n_head)
        k = k.reshape(batch_size, seq_len, n_head, d_model // n_head)
        v = v.reshape(batch_size, seq_len, n_head, d_model // n_head)
        """

        batch_size, seq_len, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        split_tensor = tensor.view(batch_size, seq_len, self.n_head, d_tensor).transpose(1, 2)
        # it is similar with group convolution (split by number of heads)

        return split_tensor
      
    
    def concat(self, sa_output):
        """ merge multiple heads back together

        Args:
            sa_output: [batch_size, n_head, seq_len, d_tensor]
            return: [batch_size, seq_len, d_model]
        """
        batch_size, n_head, seq_len, d_tensor = sa_output.size()
        d_model = n_head * d_tensor
        concat_tensor = sa_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        return concat_tensor
      
MHA = MultiHeadAttention(8, 2)
q = torch.ones([2, 3, 8])     # batch_size, seq_len, d_model
k = torch.ones([2, 3, 8])
v = torch.ones([2, 3, 8])
MHA(q,k,v)
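
The usage above runs without a mask. As a minimal sketch of the masked case, the block below repeats the four steps of ScaleDotProductAttention.forward with a lower-triangular causal mask and compares the result with PyTorch's built-in F.scaled_dot_product_attention. The tensor sizes and variable names are illustrative assumptions, and the built-in call assumes PyTorch 2.0 or newer.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_head, seq_len, d_head = 2, 2, 4, 8   # illustrative sizes
q = torch.randn(batch_size, n_head, seq_len, d_head)
k = torch.randn(batch_size, n_head, seq_len, d_head)
v = torch.randn(batch_size, n_head, seq_len, d_head)

# Lower-triangular causal mask: 1 = may attend, 0 = masked out.
# Shape [seq_len, seq_len] broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

# Manual scaled dot-product attention, following the same four steps as above
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_head)
scores = scores.masked_fill(causal_mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
manual_out = torch.matmul(attn_weights, v)

# Built-in version (assumes PyTorch >= 2.0)
builtin_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(manual_out.shape)
print((manual_out - builtin_out).abs().max().item())  # largest elementwise difference
```

The two outputs should agree up to floating-point tolerance, and the first query position can only attend to itself.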

MQA (MultiQueryAttention) implementation

# MQA implementation
from typing import Optional

import torch.nn as nn

class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.
    Using torch or triton attention implementation enables user to also use
    additive bias.
    """
    def __init__(self, d_model: int, n_heads: int, device: Optional[str] = None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # [Key point] how Multi-Query Attention builds its projection:
        # every query head gets its own slice, so Q takes d_model columns in total,
        # while key and value each get a single head_dim-sized head shared by all query heads
        self.Wqkv = nn.Linear(
            d_model,
            d_model + 2 * self.head_dim,
            device=device,
        )
        # NOTE: scaled_multihead_dot_product_attention is not defined in this post;
        # it has to be imported from the codebase this class was excerpted from
        self.attn_fn = scaled_multihead_dot_product_attention
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True  # type: ignore
    def forward(self, x):
        # Shape comments assume batch_size=1, seq_len=512, d_model=768, n_heads=8 (head_dim=96)
        qkv = self.Wqkv(x)                                           # (1, 512, 960)
        query, key, value = qkv.split(                               # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],            # key   -> (1, 512, 96)
            dim=2                                                    # value -> (1, 512, 96)
        )
        context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, multiquery=True)
        return self.out_proj(context), attn_weights, past_key_value
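
Because the class above relies on an attention function that is not shown in this post, here is a minimal self-contained sketch of the same idea. SimpleMultiQueryAttention is a hypothetical name, the attention math is written out by hand, and masking, dropout and the KV cache are left out. It is meant as an illustration of the shared K/V head, not as the original implementation.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleMultiQueryAttention(nn.Module):
    """Hypothetical minimal MQA layer: n_heads query heads, one shared K/V head."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # d_model columns for Q, plus one head_dim each for the shared K and V
        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        bsz, seq_len, _ = x.shape
        q, k, v = self.Wqkv(x).split([self.d_model, self.head_dim, self.head_dim], dim=2)
        # q: [bsz, n_heads, seq_len, head_dim]
        q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        # k, v: [bsz, 1, seq_len, head_dim]; the size-1 head axis broadcasts over all query heads
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)          # [bsz, n_heads, seq_len, seq_len]
        context = torch.matmul(attn_weights, v)           # [bsz, n_heads, seq_len, head_dim]
        context = context.transpose(1, 2).reshape(bsz, seq_len, self.d_model)
        return self.out_proj(context), attn_weights


mqa = SimpleMultiQueryAttention(d_model=768, n_heads=8)
x = torch.randn(1, 512, 768)
out, attn = mqa(x)
print(mqa.Wqkv.out_features)   # 768 + 2 * 96 = 960
print(out.shape, attn.shape)
```

With these sizes the fused projection has 960 output features instead of the 3 * 768 = 2304 that MHA would need, which is where the K/V memory saving comes from.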

GQA (GroupQueryAttention) implementation

# 3. GQA implementation
# Reference: llama2 source code
# https://zhuanlan.zhihu.com/p/649756898?utm_id=0

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def repeat_kv(x, n_rep):
  bs, slen, n_kv_heads, head_dim = x.size()
  # expand K/V along the head dimension according to n_rep
  if n_rep == 1:
    return x
  return (x[:,:,:,None,:].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads*n_rep, head_dim))

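class Attention(nn.Module):
  def __init__(self, n_heads, n_kv_heads, dim, head_dim, max_batch_size, max_seq_len, model_parallel_size):
    super().__init__()

    self.n_local_heads = n_heads // model_parallel_size         # number of Q heads [model parallelism]
    self.n_local_kv_heads = n_kv_heads // model_parallel_size   # number of K/V heads
    self.n_rep = self.n_local_heads // self.n_local_kv_heads    # how many times each K/V head is repeated
    self.head_dim = head_dim                                    # used by view() and the score scaling in forward

    # Plain nn.Linear layers are used here, so the view() calls in forward assume model_parallel_size == 1
    self.wq = nn.Linear(dim, n_heads * head_dim)       # e.g. dim=768, n_heads=8, head_dim=96: 768 -> 8 * 96
    self.wk = nn.Linear(dim, n_kv_heads * head_dim)
    self.wv = nn.Linear(dim, n_kv_heads * head_dim)
    self.wo = nn.Linear(n_heads * head_dim, dim)

  def forward(self, x, mask=None):
    bsz, seqlen, _ = x.size()
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    # RoPE positional encoding would be applied to xq and xk here. apply_rotary_emb is not
    # defined in this post, so the call is left commented out.
    # xq, xk = apply_rotary_emb(xq, xk)

    # KV cache omitted in this snippet: use the K/V of the current input directly
    keys, values = xk, xv

    # repeat K/V heads if n_kv_heads < n_heads
    keys = repeat_kv(keys, self.n_rep)          # [bs, slen, n_kv_heads*n_rep, head_dim]
    values = repeat_kv(values, self.n_rep)      # [bs, slen, n_kv_heads*n_rep, head_dim]

    # move the head dimension forward: [bs, n_heads, slen, head_dim]
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

    if mask is not None:
      scores = scores + mask   # additive mask: 0 where attention is allowed, -inf where it is not

    scores = F.softmax(scores, dim=-1)
    output = torch.matmul(scores, values)       # [bs, n_heads, slen, head_dim]
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return self.wo(output)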
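
To make the role of repeat_kv concrete, the sketch below copies the function from above and checks it against torch.repeat_interleave on a small tensor. The sizes (8 query heads, 2 K/V heads) are illustrative assumptions. Setting n_kv_heads equal to n_heads gives ordinary MHA, and n_kv_heads = 1 gives MQA.

```python
import torch


def repeat_kv(x, n_rep):
    """[bs, slen, n_kv_heads, head_dim] -> [bs, slen, n_kv_heads * n_rep, head_dim]."""
    bs, slen, n_kv_heads, head_dim = x.size()
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


n_heads, n_kv_heads, head_dim = 8, 2, 4   # illustrative sizes
n_rep = n_heads // n_kv_heads             # 4 query heads share each K/V head
xk = torch.randn(1, 3, n_kv_heads, head_dim)

keys = repeat_kv(xk, n_rep)
same = torch.equal(keys, torch.repeat_interleave(xk, repeats=n_rep, dim=2))
print(keys.shape, same)

# Query head h reads K/V head h // n_rep
head_map = [h // n_rep for h in range(n_heads)]
print(head_map)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The expand call itself creates a view without copying data; the copy happens in the reshape, so the memory saving of GQA comes from the smaller wk/wv projections and the smaller KV cache rather than from this function.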

KV_Cache

import torch
import torch.nn as nn

class IncrementalAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(IncrementalAttention, self).__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // self.num_heads

        # Linear layers for Q, K, V
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

        # Initialize empty caches for K and V
        self.k_cache = None
        self.v_cache = None

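    def forward(self, q, k, v, mask=None):
        # q, k, v: [batch_size, new_len, d_model], holding only the new tokens
        batch_size = q.size(0)

        # Compute Q and split into heads: [batch_size, num_heads, new_len, depth]
        q = self.WQ(q).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # Compute K and V for the new tokens
        k_new = self.WK(k).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        v_new = self.WV(v).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # Append to the cache along the sequence dimension (dim=2)
        if self.k_cache is not None:
            k = torch.cat([self.k_cache, k_new], dim=2)
            v = torch.cat([self.v_cache, v_new], dim=2)
        else:
            k = k_new
            v = v_new

        # Update the cache for the next iteration
        self.k_cache = k
        self.v_cache = v

        # Attention (simplified for brevity); scores: [batch_size, num_heads, new_len, total_len]
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.depth**0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = self.softmax(scores)
        output = torch.matmul(attn_weights, v)

        # Merge the heads back: [batch_size, new_len, d_model]
        return output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)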
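
Why caching K and V is safe can be checked numerically. The following is a single-head sketch with hypothetical random weight matrices Wq, Wk, Wv and a helper named attend (none of them part of the class above). It feeds one token at a time while growing the cache and compares the result with full causal attention over the whole sequence.

```python
import math

import torch

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)   # one sequence, one head
# Hypothetical projection weights, scaled so the scores stay moderate
Wq, Wk, Wv = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))


def attend(q, k, v, mask=None):
    """Scaled dot-product attention for 2-D inputs [len, d]."""
    scores = q @ k.T / math.sqrt(d)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


# Reference: full causal attention over the whole sequence at once
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
full_out = attend(x @ Wq, x @ Wk, x @ Wv, causal_mask)

# Incremental decoding: feed one token at a time and grow the cache
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
steps = []
for t in range(seq_len):
    x_t = x[t : t + 1]                               # [1, d], the new token only
    k_cache = torch.cat([k_cache, x_t @ Wk], dim=0)  # project only the new token
    v_cache = torch.cat([v_cache, x_t @ Wv], dim=0)
    # No mask needed: the cache holds only positions <= t
    steps.append(attend(x_t @ Wq, k_cache, v_cache))
incremental_out = torch.cat(steps, dim=0)

print(k_cache.shape)
print((full_out - incremental_out).abs().max().item())  # largest elementwise difference
```

The two results should agree up to floating-point tolerance. Each step projects only the new token, so the per-step projection cost stays constant while the cache grows by one row.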

Transformer Embedding implementation


# Embedding implementation: PositionEmbedding + TokenEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.
    """
    def __init__(self, d_model, max_len, device):
        """
        constructor of sinusoid encoding class

        :param d_model: dimension of model
        :param max_len: max sequence length
        :param device: hardware device setting
        """
        super(PositionalEncoding, self).__init__()

        # same size with input matrix (for adding with input matrix)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # we don't need to compute gradient

        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D unsqueeze to represent word's position

        _2i = torch.arange(0, d_model, step=2, device=device).float()
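        # _2i holds the even indices 0, 2, 4, ... of the embedding dimension, i.e. the 2i in the
        # formula (e.g. for d_model = 50 it is [0, 2, ..., 48]); "step=2" is what produces 2 * i.
        # The slice assignments below assume d_model is even.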

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # compute positional encoding to consider positional information of words

    def forward(self, x):
        # self.encoding
        # [max_len = 512, d_model = 512]

        batch_size, seq_len = x.size()
        # [batch_size = 128, seq_len = 30]

        return self.encoding[:seq_len, :]
        # [seq_len = 30, d_model = 512]
        # it will add with tok_emb : [128, 30, 512]
        
        
class TokenEmbedding(nn.Embedding):
    """
    Token Embedding using torch.nn
    maps each token id to a dense vector through a learned weight matrix
    """

    def __init__(self, vocab_size, d_model):
        """
        class for token embedding (positional information is added separately)
        :param vocab_size: size of vocabulary
        :param d_model: dimensions of model
        """
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)

class TransformerEmbedding(nn.Module):
    """
    token embedding + positional encoding (sinusoid)
    positional encoding can give positional information to network
    """

    def __init__(self, vocab_size, max_len, d_model, drop_prob, device):
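        """
        class for word embedding that includes positional information
        :param vocab_size: size of vocabulary
        :param max_len: max sequence length
        :param d_model: dimensions of model
        :param drop_prob: dropout probability
        :param device: hardware device setting
        """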
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb + pos_emb)
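
The table built in PositionalEncoding follows PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model)). A small sketch that rebuilds it as a standalone function (sinusoid_table is a hypothetical name, and d_model is assumed to be even) and spot-checks a few properties:

```python
import math

import torch


def sinusoid_table(max_len, d_model):
    """Sinusoidal position table of shape [max_len, d_model]; d_model is assumed even."""
    pos = torch.arange(max_len).float().unsqueeze(1)     # [max_len, 1]
    two_i = torch.arange(0, d_model, 2).float()          # [d_model / 2], the values 2i
    angle = pos / (10000 ** (two_i / d_model))           # [max_len, d_model / 2]
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(angle)   # even columns: sin(pos / 10000^(2i / d_model))
    table[:, 1::2] = torch.cos(angle)   # odd columns:  cos(pos / 10000^(2i / d_model))
    return table


pe = sinusoid_table(max_len=50, d_model=16)

# Spot-check one entry against the closed-form formula
pos, i = 7, 3
expected_sin = math.sin(pos / 10000 ** (2 * i / 16))
print(abs(pe[pos, 2 * i].item() - expected_sin) < 1e-5)

# Position 0 is [0, 1, 0, 1, ...] because sin(0) = 0 and cos(0) = 1
print(pe[0, :4].tolist())  # [0.0, 1.0, 0.0, 1.0]

# The dot product of two rows depends only on the offset between the positions
dot_a = (pe[3] @ pe[8]).item()     # offset 5
dot_b = (pe[20] @ pe[25]).item()   # offset 5
print(abs(dot_a - dot_b) < 1e-4)
```

The last check follows from sin(a)sin(b) + cos(a)cos(b) = cos(a - b): each sin/cos pair contributes a term that depends only on the position difference.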

LN (LayerNorm) implementation

# LN implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)  # '-1' means the last dimension
        var = x.var(-1, unbiased=False, keepdim=True)  # biased variance (divide by N), as nn.LayerNorm uses

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta

        return out

# NLP Example
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)

# 1. Apply PyTorch's nn.LayerNorm module
layer_norm1 = nn.LayerNorm(embedding_dim)
pytorch_ln_out = layer_norm1(embedding)

# 2. Apply the custom LayerNorm module defined above
layer_norm2 = LayerNorm(embedding_dim)
my_ln_out = layer_norm2(embedding)

# Compare the results
print(torch.allclose(pytorch_ln_out, my_ln_out, rtol=0.1,atol=0.01))  # prints True
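
The loose tolerances in the comparison above can hide a difference in how the variance is computed: x.var defaults to the unbiased estimate (divide by N - 1), while nn.LayerNorm normalizes with the biased one (divide by N). The sketch below, with illustrative variable names and eps set to nn.LayerNorm's default of 1e-5, shows that only the biased version matches at a tight tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(20, 5, 10)   # [batch, sentence_length, embedding_dim]
eps = 1e-5                   # nn.LayerNorm's default eps

mean = x.mean(-1, keepdim=True)
var_biased = x.var(-1, unbiased=False, keepdim=True)    # divide by N
var_unbiased = x.var(-1, unbiased=True, keepdim=True)   # divide by N - 1 (the x.var default)

reference = nn.LayerNorm(10, eps=eps)(x)   # gamma = 1 and beta = 0 at initialization
out_biased = (x - mean) / torch.sqrt(var_biased + eps)
out_unbiased = (x - mean) / torch.sqrt(var_unbiased + eps)

matches_biased = torch.allclose(reference, out_biased, atol=1e-5)
matches_unbiased = torch.allclose(reference, out_unbiased, atol=1e-5)
print(matches_biased, matches_unbiased)  # True False
```

With embedding_dim = 10 the unbiased version is off by a factor of about sqrt(9/10), roughly 5%, which is why it only passes a check with rtol=0.1.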

From: https://www.cnblogs.com/NSEW/p/17894813.html
