首页 > 其他分享 >构造中心损失----pytorch详解

构造中心损失----pytorch详解

时间:2024-07-27 20:27:47浏览次数:19  
标签:num seq self batch ---- pytorch 详解 classes size

当输入数据X维度为[num_classes,feat_dim]时,参考链接: Center loss-pytorch代码详解.

对于输入数据X类型为[batch_size,seq_len,feat_dim],对参考链接代码进行调整,整个代码如下:

class CenterLoss_seq(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, batch_size, num_classes, feat_dim, use_gpu=True):
        super(CenterLoss_seq, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        self.batch_size = batch_size

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

        # 计算中心矩阵的转置,以便在forward方法中使用矩阵乘法进行高效计算
        self.center_t = self.centers.t().expand(self.batch_size, self.feat_dim, self.num_classes).contiguous()

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, seq_len, feat_dim).
            labels: ground truth labels with shape (batch_size, seq_len).
        """
        batch_size = x.size(0)
        seq_len = x.size(1)

        x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)

        center_pow = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, self.num_classes)

        # distmat = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes) + \
        #           torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat = x_pow + center_pow

        # center_t = self.centers.t().expand(batch_size, self.feat_dim, self.num_classes)
        # x_cen_mul = torch.bmm(x.float(), center_t.float())

        x_cen_mul = torch.matmul(x.float(), self.center_t.float())  # 使用矩阵乘法进行高效计算

        distmat = distmat - 2 * x_cen_mul

        # distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(2).expand(batch_size, seq_len, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, seq_len, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)

        return loss

参数
num_classes:为数据集类别数
feat_dim:为特征向量维度
batch_size:为小批量样本数目
seq_len:为序列长度

下面对forward代码进行解析:
举例说明,这里假设num_classes=5。feat_dim=4,batch_size=1,seq_len=2。
输入X:[batch_size,seq_len,feat_dim]=[1,2,4]。
labels:[batch_size,seq_len]=[1,2]。
经过初始化,centers:[num_classes,feat_dim]。centers_t:[batch_size,feat_dim,num_classes]是centers经过转置并扩展得到。
经过下面代码运行,得到x_pow。

		x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)

对上行代码进行解析:
x:[batch_size,seq_len,feat_dim]=[1,2,4]

其中每个向量的维度为feat_dim=4。
得到x_pow :[batch_size,seq_len,num_classes]

接着同理对centers进行操作:

		center_pow = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, self.num_classes)

        # distmat = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes) + \
        #           torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat = x_pow + center_pow

得到center_pow:[batch_size,seq_len, num_classes]=[1,2,5]。

然后得到distmat:[batch_size,seq_len, num_classes]=[1,2,5]。
在这里插入图片描述
再经过如下两行代码进行矩阵运算,得到:

 		x_cen_mul = torch.matmul(x.float(), self.center_t.float())  # 使用矩阵乘法进行高效计算

        distmat = distmat - 2 * x_cen_mul

x_cen_mul:[batch_size,seq_len,num_classes]=[1,2,5]。

得到最终的distmat为:
在这里插入图片描述
经过如下代码来得到mask。

		classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(2).expand(batch_size, seq_len, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, seq_len, self.num_classes))

标签为labels:[batch_size,seq_len]=[1,2]。假设便签值为[0,1],经历过扩展后,labels:[batch_size,seq_len,num_classes],标签值为:

classes=[0,1,2,3,4]经过扩展后,其维度为:[batch_size,seq_len,num_classes]。

得到mask:[batch_size,seq_len,num_classes]。

 		dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)

将distmat与mask对应位置相乘,正好得到输入数据X与其对应中心的距离平方,最后得到平均距离之和,clamp使其在[1e-12,1e+12]的范围。内。

标签:num,seq,self,batch,----,pytorch,详解,classes,size
From: https://blog.csdn.net/weixin_52680123/article/details/140666794

相关文章

  • 从输入 URL 到页面展示到底发生了什么?
    在浏览器输入网址后,浏览器会先解析URL,解析出域名、资源路径、端口等信息,准备发送HTTP请求,检查浏览器缓存是否有缓存该资源,如果有就直接返回;没有的话就进入下一步网路请求;接着进行DNS域名解析,来获取请求域名的IP地址,如果请求协议是HTTPS,那么还会需要建立TLS连接,DNS......
  • CentOS配置NTP服务
     更改配置文件[root@Controller~]#vim/etc/chrony.conf重启服务并设置为开机自启动[root@Controller~]#systemctlrestartchronyd.service[root@Controller~]#systemctlenablechronyd.service在另一台CentOS测试更改配置文件[root@Compute~]#vim/etc......
  • Android中Service学习记录
    目录一概述二生命周期2.1启动服务startService()2.2绑定服务bindService()2.3先启动后绑定2.4先绑定后启动三使用3.1本地服务(启动式)3.2可通信的服务(绑定式)3.3前台服务3.4IntentService总结参考一概述Service组件一般用来执行长期在后台的任务,如播放音......
  • 基于SSM技术的珠宝销售系统的设计与实现/在线销售系统/在线购物平台/Java
    摘 要随着互联网技术和国内珠宝行业持续快速地发展,管理员为了可以更为便捷地管理珠宝销售线上服务,珠宝销售系统被开发出的目地是为了可以更为便捷管理珠宝销售线上服务。该系统介绍了珠宝销售系统的功能和特点。可以帮助珠宝在线商城管理销售、订单和客户信息等。它还可以......
  • 计算机Java项目|基于SpringBoot的智能无人仓库管理的设计与实现
    作者主页:编程指南针作者简介:Java领域优质创作者、CSDN博客专家、CSDN内容合伙人、掘金特邀作者、阿里云博客专家、51CTO特邀作者、多年架构师设计经验、多年校企合作经验,被多个学校常年聘为校外企业导师,指导学生毕业设计并参与学生毕业答辩指导,有较为丰富的相关经验。期待与......
  • 邦布带你从零开始实现图书管理系统(java版)
    今天我们来从零开始实现图书管理系统。图书管理系统来看我们的具体的实现,上述视频。我们首先来实现框架,我们要实现图书管理系统,首先要搭框架。我们首先定义一个书包,在书包中定义一个书类和一个书架类,再定义一个用户包,其中包含用户类,管理者类,普通用户类,在定义一个工具包......
  • 基于springboot的球鞋销售及鞋迷交流系统的开发与实现 /WEB
    摘要计算机网络与信息化管理相配合,可以有效地提高管理人员的工作效能和改进工作的质量。良好的球鞋销售及鞋迷交流系统可以使管理员工作得到更好的管理和应用,并有助于管理员更好地管理球鞋销售及鞋迷交流,并有助于解决人力管理中出现的差错等问题。因此一套好的球鞋销售及鞋......
  • 基于微信小程序的生态农场系统设计与实现 /农场管理平台
    摘  要近年来,随着网络产业的飞速发展,人们的日常生活和工作方式也随之发生变化。许多生态农场正在把常规的工作方式与因特网相融合,借助因特网的力量来提升管理者的工作能力。当前很多生态农场系统工作都有很多问题,所以针对生态农场系统的实际情况,提出可以针对生态农场系统的......
  • SpringbBoot的运动鞋交易系统/交易网站/Java/web
    摘要近年来,随着网络产业的飞速发展,人们的日常生活和工作方式也随之发生变化。各行各业正在把常规的工作方式与因特网相融合,于是,网上交易系统亦应运而生。与传统的店铺销售相比,网上运动鞋店具有方便、快捷、信息畅通的特点,交易环节的缩减,使交易成本大为降低,消费者选择购物的......
  • 算法板子:滑动窗口——应用单调队列,找到窗口中的最小值与最大值
    #include<iostream>usingnamespacestd;constintN=1e6+10;inta[N];//q数组模拟单调队列;q数组存储原数组元素的下标;//递增单调队列的队头始终维护窗口中的最小值;队头存的是窗口中最小值的下标//递减单调队列的队头始终维护窗口中的最大值;队头存的......