class Get_gradient_nopadding_rgb(nn.Module):
    """Per-channel gradient magnitude of the first three channels of an image batch.

    Each of input channels 0, 1 and 2 is convolved with two fixed
    central-difference kernels:

    * ``weight_v``: ``out[i, j] = x[i + 1, j] - x[i - 1, j]`` (difference along dim -2)
    * ``weight_h``: ``out[i, j] = x[i, j + 1] - x[i, j - 1]`` (difference along dim -1)

    The result is ``sqrt(g_v**2 + g_h**2 + 1e-6)`` for each channel.

    Despite the "nopadding" in the name, the convolutions use ``padding=1``
    (zero padding). The output therefore has the same spatial size as the
    input, and border pixels see zeros outside the image.
    """

    def __init__(self):
        super().__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Kernels are shaped (out_channels=1, in_channels=1, 3, 3) and never trained.
        # They are registered as buffers (not ``nn.Parameter(...).cuda()``, which yields
        # an unregistered plain tensor) so .to()/.cuda()/.double() move them with the
        # module. persistent=False keeps them out of state_dict, matching the original
        # checkpoint layout.
        self.register_buffer(
            "weight_h",
            torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0),
            persistent=False,
        )
        self.register_buffer(
            "weight_v",
            torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0),
            persistent=False,
        )

    def forward(self, x):
        """Return the gradient magnitude of channels 0-2 of ``x``.

        Args:
            x: 4-D tensor indexed as ``(N, C, H, W)`` with ``C >= 3``. Channels
                beyond the third are ignored.

        Returns:
            Tensor of shape ``(N, 3, H, W)`` with the same dtype and device as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D or has fewer than 3 channels.
        """
        if x.dim() != 4 or x.size(1) < 3:
            raise ValueError(
                f"expected a 4-D (N, C>=3, H, W) tensor, got shape {tuple(x.shape)}"
            )
        n, _, h, w = x.shape
        # Follow the input's device/dtype instead of a device fixed at construction.
        weight_v = self.weight_v.to(device=x.device, dtype=x.dtype)
        weight_h = self.weight_h.to(device=x.device, dtype=x.dtype)

        # Fold the three channels into the batch axis so each channel is convolved
        # independently with the single-channel kernels in one call per kernel.
        planes = x[:, :3].reshape(n * 3, 1, h, w)
        grad_v = F.conv2d(planes, weight_v, padding=1)
        grad_h = F.conv2d(planes, weight_h, padding=1)

        # The 1e-6 keeps sqrt away from 0, where its derivative is infinite.
        magnitude = torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6)
        return magnitude.reshape(n, 3, h, w)
# Tags: unsqueeze, self, torch, x2, x0, x1, grad. From: https://www.cnblogs.com/yyhappy/p/17987168