首页 > 其他分享 >BI-LSTM+Attention 的 tensorflow-1.14 实现

BI-LSTM+Attention 的 tensorflow-1.14 实现

时间:2024-07-28 17:00:28浏览次数:16  
标签:1.14 state Attention BI batch step tf hidden size

这里只是用简单例子演示关于 self-attention 的逻辑,判断一句话的情感是正面或者是负面,具体原理自己百度即可。

import tensorflow as tf
import numpy as np
tf.reset_default_graph()

# 词向量维度
dim = 2
# 隐层大小
hidden = 5
# 时间步大小
step = 3
# 情感类别 正面或者负面
N = 2

sentences = ["i love mengjun","i like peipei","she likes damao","she hates wangda","wangda is good","mengjun is bad"]
labels = [1,1,1,0,1,0]

words = list(set(" ".join(sentences).split()))
# 词典大小
V = len(words)
# 单词和索引互相映射
word2idx = {v:k for k,v in enumerate(words)}
idx2word = {k:v for k,v in enumerate(words)}

# 处理输入数据
input_batch = []
for sentence in sentences:
    input_batch.append([word2idx[word] for word in sentence.split()])

# 处理输出目标数据
target_batch = []
for label in labels:
    target_batch.append(np.eye(N)[label]) # 这里要进行独热编码,后面计算损失会用到
    
# 初始化词向量
embedding = tf.Variable(tf.random_normal([V, dim]))
# 输出分类时使用到的向量矩阵
out = tf.Variable(tf.random_normal([hidden * 2, N]))

X = tf.placeholder(tf.int32, [None, step])
# 对输入进行词嵌入
X_embedding = tf.nn.embedding_lookup(embedding, X)
Y = tf.placeholder(tf.int32, [None, N])

# 定义正向和反向的 lstm 
lstm_fw_cell = tf.nn.rnn_cell.LSTMCell(hidden)
lstm_bw_cell = tf.nn.rnn_cell.LSTMCell(hidden)

# 经过双向 lstm 的计算得到结果 
# output : ([batch_size, step, hidden],[batch_size, step, hidden])  
# final_state : (fw:(c:[batch_size, hidden], h:[batch_size, hidden]), bw:(c:[batch_size, hidden], h:[batch_size, hidden]))
output, final_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, X_embedding, dtype=tf.float32)
# 将 output 根据 hidden 维度拼接起来,[batch_size, step, hidden*2]
output = tf.concat([output[0], output[1]], 2)

# 将 final_state 的反方向的 c 和 h 根据 hidden 维度拼接起来, [batch_size, hidden*2]
final_hidden_state = tf.concat([final_state[1][0], final_state[1][1]], 1)
# 增加第三个维度,方便计算 [batch_size, hidden*2, 1]
final_hidden_state = tf.expand_dims(final_hidden_state, 2)

# 计算每个时间步的输出与最后输出状态的相似度 
# [batch_size, step, hidden*2] * [batch_size, hidden*2, 1] = squeeze([batch_size, step, 1]) = [batch_size, step]
attn_weights = tf.squeeze(tf.matmul(output, final_hidden_state), 2)
# 在时间步维度上进行 softmax 得到权重向量
soft_attn_weights = tf.nn.softmax(attn_weights, 1)

# 各时间步输出和对应的权重想成得到上下文矩阵 [batch_size, hidden*2, step] * [batch_size, step, 1] = [batch_size, hidden*2, 1]
context = tf.matmul(tf.transpose(output, [0, 2, 1]), tf.expand_dims(soft_attn_weights, 2))
# squeeze([batch_size, hidden*2, 1]) = [batch_size, hidden*2]
context = tf.squeeze(context, 2)

# 输出概率矩阵 [batch_size, hidden*2] * [hidden*2, N] = [batch_size, N]
model = tf.matmul(context, out)
# 计算损失并优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model ,labels=Y))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
# 预测
hypothesis = tf.nn.softmax(model)
prediction = tf.argmax(hypothesis, 1)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
for epoch in range(5000):
    _, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
    if (epoch+1) % 1000 == 0:
        print('epoch ','%06d'%(epoch+1), ' loss ', '%08f'%loss)
        
test_text = [[word2idx[word] for word in 'she hates wangda'.split()]]
predict = sess.run([prediction], feed_dict={X: test_text})
print('she hates wangda', '-->', predict[0][0])

结果打印:

epoch  001000  loss  0.001645
epoch  002000  loss  0.000279
epoch  003000  loss  0.000106
epoch  004000  loss  0.000052
epoch  005000  loss  0.000029
she hates wangda --> 0  

标签:1.14,state,Attention,BI,batch,step,tf,hidden,size
From: https://blog.csdn.net/wang7075202/article/details/140745731

相关文章

  • 使用 Docker Compose 部署 RabbitMQ 的一些经验与踩坑记录
    前言RabbitMQ是一个功能强大的开源消息队列系统,它实现了高效的消息通信和异步处理。本文主要介绍其基于Docker-Compose的部署安装和一些使用的经验。特点成熟,稳定消息持久化灵活的消息路由高性能,高可用性,可扩展性高支持插件系统:RabbitMQ具有丰富的插件系统,可以通......
  • 在 Google Colab 上运行 Django:错误 403 Forbidden
    我正在尝试对我的Python程序的Colab进行一些测试并使用Django。我按照此链接中的说明进行操作。我确保在settings.py中设置了此设置ALLOWED_HOSTS=['*']运行此命令以获取链接https://randomstrings.colab.googleusercontent.com/fromgo......
  • 操作系统Bit位数操作类 - C#小函数类推荐
          此文记录的是检测当前操作系统的位数的函数。/***操作系统Bit位数操作类AustinLiu刘恒辉ProjectManagerandSoftwareDesignerE-Mail:[email protected]:http://lzhdim.cnblogs.comDate:2024-01-1515:18:00使......
  • Vcpkg + cmake + pybind 问题“无法找到平台独立库 <前缀>”
    我发现了vcpkgerlier,它看起来很有趣,但是易于使用。据我了解,经过一天的调查,vcpkgpybind11与vcpkgpython搭配使用。但是当我启动一个简单的程序时,它被中止并出现以下输出无法找到平台独立库<前缀>这是一个已知问题,但不适用于vcpkgpython。我不知道为什么?不......
  • flutter中使用rabbitmq
    依赖dart_amqp:^0.3.1#rabbitMq接收发送消息工具封装import'package:dart_amqp/dart_amqp.dart';///封装RabbitMQ的服务类classRabbitMQService{lateConnectionSettings_settings;//RabbitMQ连接设置lateClient_client;//RabbitMQ客户端late......
  • Linux捣鼓记录:debian12日志警告:firmware: failed to load iwl-debug-yoyo.bin (-2)
    问题现象:网卡为intelax200,系统为debian12蓝牙wifi使用功能一切正常,根据wiki检查了驱动也都已经安装,但每次开机后,查看cockpit日志会看到警告:firmware:failedtoloadiwl-debug-yoyo.bin(-2)......问题分析:检索网络得到初步结论:iwl-debug-yoyo.bin是一个intel网卡相关的de......
  • 经典CNN模型(九):MobileNetV3(PyTorch详细注释版)
    一.MobileNetV3神经网络介绍MobileNetV3是MobileNet系列的第三代模型,由Google在2019年提出,旨在进一步优化模型的效率和性能,特别是在移动设备和边缘计算设备上。与前一代相比,MobileNetV3引入了多项改进,包括使用神经架构搜索(NeuralArchitectureSearch,NAS)、自适......
  • Linux捣鼓记录:debian配置语言环境
    1.安装区域设置sudoaptupdatesudoaptinstalllocales2.配置语言环境sudodpkg-reconfigurelocales按空格多选,选中en_US.UTF-8和zh_CN.UTF-8这里多选择了英文,可以避免有些软件比如steamcmd报警告:WARNING:setlocale('en_US.UTF-8')failed,usinglocale:'C'.......
  • Linux捣鼓记录:debian12安装xfce桌面环境
    在Debian12上安装Xfce桌面第1步。在安装任何软件包之前,建议更新软件包列表以确保您安装的是最新版本的软件包。您可以通过在终端中运行以下命令来执行此操作:sudoaptupdate&&sudoaptupgrade此命令将刷新存储库,允许您安装最新版本的软件包。第2步。在Debian12......
  • 【独家首发】Matlab实现凌日优化算法TSOA优化Transformer-BiLSTM实现负荷数据回归预测
    %假设您有负荷数据load_data和相应的回归标签regression_labels%1.数据预处理%在这一步中,您需要对负荷数据进行适当的预处理,例如归一化、序列化等操作%2.划分数据集为训练集和测试集%这里假设您将数据划分为train_data,train_labels,test_data,test_label......