When the input X has shape [batch_size, feat_dim], see the reference: a line-by-line walkthrough of the Center loss PyTorch code.
For input X of shape [batch_size, seq_len, feat_dim], the reference code is adapted as follows; the full code:
class CenterLoss_seq(nn.Module):
    """Center loss for sequence inputs.
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes, feat_dim, use_gpu=True):
        super(CenterLoss_seq, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, seq_len, feat_dim).
            labels: ground truth labels with shape (batch_size, seq_len).
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        # squared norm of every feature vector, broadcast over the class dimension
        x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)
        # squared norm of every center, broadcast over the batch and sequence dimensions
        center_pow = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, self.num_classes)
        distmat = x_pow + center_pow
        # Transpose and expand the centers here (not in __init__) so the matmul
        # always sees the current values of the learnable parameter.
        center_t = self.centers.t().expand(batch_size, self.feat_dim, self.num_classes)
        x_cen_mul = torch.matmul(x.float(), center_t.float())
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x.c
        distmat = distmat - 2 * x_cen_mul
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.cuda()
        # one mask entry per (position, class) pair; True where class == label
        labels = labels.unsqueeze(2).expand(batch_size, seq_len, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, seq_len, self.num_classes))
        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)
        return loss
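A minimal usage sketch (a condensed CPU-only copy of the class is redefined here so the snippet is self-contained; in practice this loss is typically added, with a weight, to the main task loss):

```python
import torch
import torch.nn as nn

class CenterLoss_seq(nn.Module):
    # condensed CPU-only version of the class above, for illustration
    def __init__(self, num_classes, feat_dim):
        super(CenterLoss_seq, self).__init__()
        self.num_classes = num_classes
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, x, labels):
        batch_size, seq_len = x.size(0), x.size(1)
        x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)
        center_pow = torch.pow(self.centers, 2).sum(dim=1).view(1, 1, self.num_classes)
        distmat = x_pow + center_pow - 2 * torch.matmul(x, self.centers.t())
        mask = labels.unsqueeze(2).eq(torch.arange(self.num_classes).view(1, 1, -1))
        dist = distmat * mask.float()
        return dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)

criterion = CenterLoss_seq(num_classes=5, feat_dim=4)
x = torch.randn(1, 2, 4)         # (batch_size, seq_len, feat_dim)
labels = torch.tensor([[0, 1]])  # (batch_size, seq_len)
loss = criterion(x, labels)
loss.backward()                  # gradients flow into the learnable centers
print(loss.shape, criterion.centers.grad is not None)
```

Because centers is an nn.Parameter, it must also be registered with the optimizer (usually with its own learning rate) so the class centers are updated during training.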
Parameters:
num_classes: number of classes in the dataset
feat_dim: feature vector dimension
batch_size: mini-batch size
seq_len: sequence length
The forward code is analyzed below.
As a running example, assume num_classes=5, feat_dim=4, batch_size=1, seq_len=2.
Input X: [batch_size, seq_len, feat_dim] = [1, 2, 4].
labels: [batch_size, seq_len] = [1, 2].
After initialization, centers: [num_classes, feat_dim]. center_t: [batch_size, feat_dim, num_classes] is obtained by transposing centers and expanding along the batch dimension.
Running the line below produces x_pow:
x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)
Here x: [batch_size, seq_len, feat_dim] = [1, 2, 4], where each feature vector has dimension feat_dim=4. Squaring and summing over dim=2 yields each vector's squared norm, which keepdim plus expand broadcasts to x_pow: [batch_size, seq_len, num_classes].
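Concretely, with the example shapes above (a standalone sketch, not part of the original code):

```python
import torch

batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
x = torch.randn(batch_size, seq_len, feat_dim)

# sum of squares over the feature dimension, kept as a trailing size-1 dim
x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True)        # (1, 2, 1)
x_pow = x_pow.expand(batch_size, seq_len, num_classes)  # (1, 2, 5)

print(x_pow.shape)
# every column along the class dimension holds the same squared norm
print(torch.allclose(x_pow[0, 0, 0], (x[0, 0] ** 2).sum()))
```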
The same operation is then applied to centers:
center_pow = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, self.num_classes)
distmat = x_pow + center_pow
This gives center_pow: [batch_size, seq_len, num_classes] = [1, 2, 5], and then distmat: [batch_size, seq_len, num_classes] = [1, 2, 5].
The next two lines use a batched matrix multiplication to add the cross term:
x_cen_mul = torch.matmul(x.float(), center_t.float())
distmat = distmat - 2 * x_cen_mul
x_cen_mul: [batch_size, seq_len, num_classes] = [1, 2, 5].
The final distmat[b, s, c] equals ||x[b, s]||^2 + ||centers[c]||^2 - 2 * x[b, s] . centers[c], i.e. the squared Euclidean distance from each position's feature vector to every class center.
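The three steps above can be checked against torch.cdist (a verification sketch; cdist is not used in the original code):

```python
import torch

batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
x = torch.randn(batch_size, seq_len, feat_dim)
centers = torch.randn(num_classes, feat_dim)

# expansion trick from the forward method
x_pow = (x ** 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, num_classes)
center_pow = (centers ** 2).sum(dim=1).view(1, 1, num_classes)
distmat = x_pow + center_pow - 2 * torch.matmul(x, centers.t())

# same quantity computed directly as squared pairwise distances
expected = torch.cdist(x, centers.unsqueeze(0).expand(batch_size, -1, -1)) ** 2
print(torch.allclose(distmat, expected, atol=1e-5))  # True
```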
The following code builds the mask:
classes = torch.arange(self.num_classes).long()
if self.use_gpu: classes = classes.cuda()
labels = labels.unsqueeze(2).expand(batch_size, seq_len, self.num_classes)
mask = labels.eq(classes.expand(batch_size, seq_len, self.num_classes))
labels: [batch_size, seq_len] = [1, 2]. Suppose the label values are [0, 1]; after expansion, labels has shape [batch_size, seq_len, num_classes], with each position's label repeated num_classes times along the last dimension.
classes = [0, 1, 2, 3, 4] is likewise expanded to shape [batch_size, seq_len, num_classes].
Comparing the two elementwise gives mask: [batch_size, seq_len, num_classes], which is True exactly where the class index matches that position's label.
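The eq-based mask is equivalent to a one-hot encoding of the labels (a standalone sketch):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, num_classes = 1, 2, 5
labels = torch.tensor([[0, 1]])

classes = torch.arange(num_classes).long()
expanded = labels.unsqueeze(2).expand(batch_size, seq_len, num_classes)
mask = expanded.eq(classes.expand(batch_size, seq_len, num_classes))

print(mask.long())
# equivalent one-hot view of the same labels
print(torch.equal(mask.long(), F.one_hot(labels, num_classes)))  # True
```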
dist = distmat * mask.float()
loss = dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)
Multiplying distmat elementwise by mask keeps, at each position, exactly the squared distance between the input feature and its own class center. clamp restricts the values to the range [1e-12, 1e+12] for numerical stability, and the sum is averaged over the batch_size * seq_len positions to give the final loss.
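As a sanity check, the masked-sum loss matches a naive per-position loop (a verification sketch under the running example's shapes):

```python
import torch

batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
x = torch.randn(batch_size, seq_len, feat_dim)
labels = torch.tensor([[0, 1]])
centers = torch.randn(num_classes, feat_dim)

# vectorized loss, as in forward
x_pow = (x ** 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, num_classes)
center_pow = (centers ** 2).sum(dim=1).view(1, 1, num_classes)
distmat = x_pow + center_pow - 2 * torch.matmul(x, centers.t())
mask = labels.unsqueeze(2).eq(torch.arange(num_classes).view(1, 1, -1))
loss = (distmat * mask.float()).clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)

# naive loop: squared distance from each position's feature to its own center
total = 0.0
for b in range(batch_size):
    for s in range(seq_len):
        total = total + ((x[b, s] - centers[labels[b, s]]) ** 2).sum()
naive = total / (batch_size * seq_len)

print(torch.allclose(loss, naive, atol=1e-4))  # True
```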
From: https://blog.csdn.net/weixin_52680123/article/details/140666794