title: GRU Principles and Implementation
date: 2022-10-03 09:31:44
mathjax: true
tags:
- GRU
With the same input and hidden sizes, a GRU has roughly 0.75× as many parameters as an LSTM, because a GRU cell has three stacked weight blocks (reset gate, update gate, candidate) while an LSTM cell has four (input, forget, cell candidate, output).
## Formulas

In the state update, $(1 - z_t)$ keeps the current candidate and $z_t$ keeps part of the previous time step's hidden state; the $*$ in the formulas denotes element-wise multiplication.
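For reference, these are the GRU cell equations in the form used by PyTorch's `nn.GRU` (and followed by the custom implementation later in this post):

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\bigl(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{t-1} + b_{hn})\bigr) \\
h_t &= (1 - z_t) * n_t + z_t * h_{t-1}
\end{aligned}
$$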
## Checking the model's parameter count

Counting the parameters of the two layers confirms that the GRU has roughly 0.75× as many as the LSTM.
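A minimal sketch of such a check (input size 4 and hidden size 5 are arbitrary illustration values):

```python
import torch.nn as nn

i_size, h_size = 4, 5
gru = nn.GRU(i_size, h_size, batch_first=True)
lstm = nn.LSTM(i_size, h_size, batch_first=True)

# Sum the number of elements over every weight and bias tensor.
gru_params = sum(p.numel() for p in gru.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())

print(gru_params, lstm_params, gru_params / lstm_params)
# GRU:  3 * (h_size*i_size + h_size*h_size + 2*h_size) = 165
# LSTM: 4 * (h_size*i_size + h_size*h_size + 2*h_size) = 220
# ratio = 165 / 220 = 0.75
```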
## Input arguments

All inputs are 3-dimensional tensors: with `batch_first=True`, the input sequence has shape `[batch_size, seq_len, input_size]`, and the initial hidden state `h_0` has shape `[num_layers, batch_size, hidden_size]`.

## Output values

The outputs are 3-dimensional as well: `output` has shape `[batch_size, seq_len, hidden_size]`, and the final hidden state `h_n` has shape `[num_layers, batch_size, hidden_size]`.
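A quick sketch to confirm these shapes (the sizes here are illustration values only):

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=4, hidden_size=5, batch_first=True)
x = torch.randn(2, 3, 4)   # [batch_size, seq_len, input_size]
h0 = torch.randn(1, 2, 5)  # [num_layers, batch_size, hidden_size]

output, h_n = gru(x, h0)
print(output.shape)  # torch.Size([2, 3, 5])
print(h_n.shape)     # torch.Size([1, 2, 5])
```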
## Implementation with the PyTorch API
```python
import torch
import torch.nn as nn

batch_size, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(batch_size, T, i_size)  # input sequence, [batch_size, seq_len, input_size]
h0 = torch.randn(batch_size, h_size)        # initial hidden state (single layer)

# Implementation with the PyTorch API
gru_layer = nn.GRU(i_size, h_size, batch_first=True)
output, h_final = gru_layer(input, h0.unsqueeze(0))  # h0 -> [num_layers, batch_size, h_size]
print(output)
```
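With the sizes above (`batch_size=2`, `T=3`, `h_size=5`), `output` has shape `[2, 3, 5]`, one hidden vector per time step, and `h_final` has shape `[1, 2, 5]`, the hidden state of the last time step.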
## Custom implementation

```python
def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    prev_h = initial_states
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] // 3  # only the r, z, n gates have weights, stacked along dim 0

    # w_ih and w_hh are 2-D tensors, while input and initial_states carry a batch dimension,
    # so both weight matrices are expanded with a batch dimension
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)

    output = torch.zeros(bs, T, h_size)  # output of the GRU network

    for t in range(T):
        x = input[:, t, :]  # input feature vector of the GRU cell at time t, [bs, i_size]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_x = w_times_x.squeeze(-1)  # [bs, 3*h_size]
        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [bs, 3*h_size]

        # reset gate
        r_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size]
                            + b_ih[:h_size] + b_hh[:h_size])
        # update gate
        z_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size]
                            + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])
        # candidate state
        n_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size]
                         + r_t * (w_times_h_prev[:, 2*h_size:3*h_size] + b_hh[2*h_size:3*h_size]))
        # hidden-state update: keep part of the previous state via z_t, mix in the candidate via 1 - z_t
        prev_h = (1 - z_t) * n_t + z_t * prev_h
        output[:, t, :] = prev_h

    return output, prev_h


# Call the custom implementation with the weights of the nn.GRU layer above
output_custom, h_final_custom = gru_forward(input, h0,
                                            gru_layer.weight_ih_l0, gru_layer.weight_hh_l0,
                                            gru_layer.bias_ih_l0, gru_layer.bias_hh_l0)
print(output_custom)
```
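The weights are tiled along a batch dimension so that `torch.bmm` can process every sample with a single batched matrix multiply; an equivalent, simpler alternative would be `x @ w_ih.T` and `prev_h @ w_hh.T`, which broadcast over the batch automatically.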
## Checking that the two results agree

```python
print(torch.allclose(output, output_custom))  # True: the custom forward matches nn.GRU
```
From: https://www.cnblogs.com/bzwww/p/16805766.html