##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    """Generate an input-conditioned prompt map from a bank of learnable prompts.

    A bank of ``prompt_len`` learnable prompts, each of shape
    ``(prompt_dim, prompt_size, prompt_size)``, is mixed per sample using
    softmax weights predicted from the spatially averaged input. The mixed
    prompt is bilinearly resized to the input's spatial size and passed
    through a 3x3 convolution.

    Args:
        prompt_dim: Channel count of each prompt and of the output map.
        prompt_len: Number of prompts in the learnable bank.
        prompt_size: Spatial side length of each stored prompt.
        lin_dim: Channel count ``C`` of the input. The pooled ``(B, C)``
            embedding is fed to ``nn.Linear(lin_dim, prompt_len)``, so ``C``
            must equal ``lin_dim``.
    """

    def __init__(self, prompt_dim=128, prompt_len=5, prompt_size=96, lin_dim=192):
        super().__init__()
        # Prompt bank, shape (1, prompt_len, prompt_dim, prompt_size, prompt_size).
        # The leading singleton dim lets it broadcast over the batch in forward().
        self.prompt_param = nn.Parameter(
            torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size)
        )
        # Maps the pooled (B, lin_dim) embedding to one logit per prompt.
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(
            prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, x):
        """Return a prompt map of shape ``(B, prompt_dim, H, W)`` for ``x``.

        Args:
            x: 4-D tensor ``(B, C, H, W)`` with ``C == lin_dim``.

        Returns:
            Tensor of shape ``(B, prompt_dim, H, W)``.
        """
        # Unpacking (rather than slicing) keeps the rank-4 requirement enforced.
        _, _, height, width = x.shape
        # Global average pool over the spatial dims -> (B, C).
        emb = x.mean(dim=(-2, -1))
        # Per-sample mixing weights over the prompt bank -> (B, prompt_len).
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        # Weighted sum over the bank: (B, L, 1, 1, 1) broadcasts against the
        # (1, L, D, S, S) parameter, so no explicit per-batch copy is needed.
        # The parameter is used directly (NOT via ``.data``) so that autograd
        # propagates gradients into the prompt bank and it actually trains.
        prompt = (prompt_weights[:, :, None, None, None] * self.prompt_param).sum(dim=1)
        # Resize (B, D, S, S) -> (B, D, H, W). align_corners=False is the
        # PyTorch default, written out explicitly.
        prompt = F.interpolate(
            prompt, (height, width), mode="bilinear", align_corners=False
        )
        return self.conv3x3(prompt)
# Tags: dim, unsqueeze, prompt, self, param, 111111111111, size
# From: https://www.cnblogs.com/yyhappy/p/17502364.html