首页 > 编程语言 >cross attention的源码实现,并代码详细讲解

cross attention的源码实现,并代码详细讲解

时间:2024-06-23 16:53:57浏览次数:3  
标签:Attention seq attention cross len 源码 np 注意力

 

import numpy as np

def softmax(x, axis=-1):
    """Softmax函数,用于计算注意力权重"""
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    """缩放点积注意力机制,用于计算输出和注意力权重"""
    print(q.shape)
    print(k.transpose().shape)
    matmul_qk = np.matmul(q, k.transpose(0,2,1))  # 计算查询和键的矩阵乘积
    d_k = k.shape[-1]  # 键的维度
    scaled_attention_logits = matmul_qk / np.sqrt(d_k)  # 缩放注意力分数

    if mask is not None:  # 如果有注意力掩码,将其添加到分数上
        scaled_attention_logits += (mask * -1e9)

    attention_weights = softmax(scaled_attention_logits)  # 计算注意力权重
    output = np.matmul(attention_weights, v)  # 计算输出
    return output, attention_weights

def cross_attention(q, k, v, mask=None):
    """Cross-Attention机制"""
    # q, k, v 必须有匹配的前导维度
    # q, k, v 的形状: (batch_size, seq_len, embed_dim)
    # mask 的形状: (batch_size, seq_len_q, seq_len_k)

    # 使用缩放点积注意力机制计算注意力
    output, attention_weights = scaled_dot_product_attention(q, k, v, mask)

    return output, attention_weights

# 测试用例
np.random.seed(0)  # 确保可重复性

# 创建查询、键和值矩阵
batch_size = 2
seq_len_q = 3
seq_len_k = 4
embed_dim = 5
q = np.random.rand(batch_size, seq_len_q, embed_dim)
k = np.random.rand(batch_size, seq_len_k, embed_dim)
v = np.random.rand(batch_size, seq_len_k, embed_dim)

# 创建注意力掩码(可选)
mask = np.zeros((batch_size, seq_len_q, seq_len_k))
mask[:, :, -1:] = 1  # 假设我们想忽略每个序列的最后一个元素

# 计算Cross-Attention
output, attention_weights = cross_attention(q, k, v, mask)

print("Output shape:", output.shape)  # 应该是 (batch_size, seq_len_q, embed_dim)
print("Attention weights shape:", attention_weights.shape)  # 应该是 (batch_size, seq_len_q, seq_len_k)

 

Cross-Attention,也称为自注意力或查询(Query)-键(Key)-值(Value)注意力机制,是一种在Transformer模型中广泛使用的注意力机制。在Cross-Attention中,查询(Query)通常来自于一个序列(如文本序列),而键(Key)和值(Value)来自于另一个序列(如另一个文本序列或图像特征)。

以下是一个简化的Cross-Attention的源码实现,使用Python和NumPy库。这个实现是为了说明Cross-Attention的基本概念,并不是一个高效或完整的实现。在实际应用中,Cross-Attention通常使用更高效的库,如TensorFlow或PyTorch。

代码讲解:

  1. softmax 函数:用于计算注意力权重。它首先从输入矩阵中减去每行的最大值,以增加数值稳定性,然后计算指数,最后将结果归一化为概率分布。

  2. scaled_dot_product_attention 函数:实现缩放点积注意力机制。它首先计算查询(q)和键(k)的转置的矩阵乘积,然后除以键的维度的平方根进行缩放。如果有注意力掩码(mask),将其应用于注意力分数以忽略某些部分。最后,使用softmax函数计算注意力权重,并将其与值(v)相乘以得到输出。

  3. cross_attention 函数:实现Cross-Attention机制。它接受查询(q)、键(k)和值(v)作为输入,以及一个可选的注意力掩码(mask)。它调用scaled_dot_product_attention函数来计算输出和注意力权重,并将其返回。

在实际应用中,Cross-Attention通常使用深度学习框架(如PyTorch或TensorFlow)的内置函数和类来实现,这些实现更加高效和灵活。上述代码仅用于说明Cross-Attention的基本概念。

 

 

标签:Attention,seq,attention,cross,len,源码,np,注意力
From: https://www.cnblogs.com/xiaochouk/p/18263612

相关文章