import torch

def promptGating(gating, adding, x):
    '''
    Apply multiplicative gating and an additive bias to the prefix
    positions of x, leaving all other positions unchanged.

    gating: (num_prefix, dim)
    adding: (num_prefix, dim)
    x:      (seq_length, batch_size, dim)
    '''
    if gating is not None:
        # (num_prefix, dim) -> (num_prefix, batch_size, dim)
        gating = gating.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
        # Pad with ones so positions outside the prefix are multiplied by 1:
        # result is (seq_length, batch_size, dim)
        pad = torch.ones(x.size(0) - gating.size(0), x.size(1), x.size(2),
                         dtype=x.dtype, device=x.device)
        gating = torch.cat([gating, pad], dim=0)
        x = x * gating
    if adding is not None:  # equivalent to adding a bias on the prefix positions
        adding = adding.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
        # Pad with zeros so positions outside the prefix get +0 (unchanged)
        pad = torch.zeros(x.size(0) - adding.size(0), x.size(1), x.size(2),
                          dtype=x.dtype, device=x.device)
        adding = torch.cat([adding, pad], dim=0)
        x = x + adding
    return x
if __name__ == "__main__":
    num_prompt, batch_size, seq_length, dim = 2, 8, 22, 1024
    gating = torch.randn(num_prompt, dim)
    adding = torch.randn(num_prompt, dim)
    x = torch.randn(seq_length, batch_size, dim)
    new_x = promptGating(gating, adding, x)
    print(new_x.shape)
    # Output: torch.Size([22, 8, 1024])
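
The demo above only checks shapes with random tensors. In practice, gating and adding would be learnable parameters updated by training. As a minimal sketch (not from the original post; the class name PromptGatingLayer is hypothetical), the function can be wrapped in an nn.Module, with the gate initialized to ones and the bias to zeros so the layer starts as an identity transform:

import torch
import torch.nn as nn

class PromptGatingLayer(nn.Module):
    # Wraps promptGating with learnable per-prefix-token parameters,
    # as in prefix-tuning-style methods that modulate prefix activations.
    def __init__(self, num_prefix, dim):
        super().__init__()
        # Start as an identity: multiply by 1, add 0.
        self.gating = nn.Parameter(torch.ones(num_prefix, dim))
        self.adding = nn.Parameter(torch.zeros(num_prefix, dim))

    def forward(self, x):  # x: (seq_length, batch_size, dim)
        return promptGating(self.gating, self.adding, x)

layer = PromptGatingLayer(num_prefix=2, dim=1024)
out = layer(torch.randn(22, 8, 1024))
print(out.shape)  # torch.Size([22, 8, 1024])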
From: https://www.cnblogs.com/tuyuge/p/17617500.html