首页 > 编程语言 >LLM大模型: FlagEmbedding-BiEncoderModel原理和源码解析

LLM大模型: FlagEmbedding-BiEncoderModel原理和源码解析

时间:2024-06-19 18:21:21浏览次数:10  
标签:LLM self reps BiEncoderModel 源码 embedding scores query target

  NLP常见的任务之一是高效检索:在大规模语料库中快速检索与查询相关的段落或文档;用户输入query,要在语料库中找到语义最接近、最匹配的回答!此外,还有文本分类、情感分析等下游任务需要先把文本的embedding求出来,这些功能都能通过"双塔结构"(Bi-Encoder)实现!核心思路很简单:用两个不同的encoder分别求出query的embedding和answer的embedding,然后求两种embedding之间的距离(cosin或dot product都行),找到距离topK的embedding作为最合适的answer即可!存储和查找topK的向量可以借助专业的向量数据库,比如FAISS等!

  1、setence转embedding的方法:这里提供了两种方式,求平均和取第一个cls token的embedding代表整个句子的embedding

    def sentence_embedding(self, hidden_state, mask):
        if self.sentence_pooling_method == 'mean':
            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            d = mask.sum(axis=1, keepdim=True).float()
            return s / d
        elif self.sentence_pooling_method == 'cls':
            return hidden_state[:, 0]

  除了上述两种方式,其实还有另外两种可取:

# 最大池化
max_embedding = outputs.last_hidden_state.max(dim=1).values

# 拼接多种表示
concatenated_embedding = torch.cat([cls_embedding, mean_embedding, max_embedding], dim=1)

  2、把input转成embedding向量

    def encode(self, features):
        if features is None:
            return None
        psg_out = self.model(**features, return_dict=True)#先把input通过model的forward求embedding
        p_reps = self.sentence_embedding(psg_out.last_hidden_state, features['attention_mask'])#再求整个句子的embedding
        if self.normlized:#归一化,利于下一步求cosin或dot product
            p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
        return p_reps.contiguous()

  3、求相似度:就是query和passage两个矩阵相乘,本质还是dot product

    def compute_similarity(self, q_reps, p_reps):
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

  4、这个loss更简单了:直接就是cross entropy!

    def compute_loss(self, scores, target):
        return self.cross_entropy(scores, target)

  5、最核心的就是forward方法了:

    def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
        q_reps = self.encode(query)#两个encoder分别求embedding,这是模型叫Bi双塔的原因
        p_reps = self.encode(passage)

        if self.training:
            if self.negatives_cross_device and self.use_inbatch_neg:
                q_reps = self._dist_gather_tensor(q_reps)
                p_reps = self._dist_gather_tensor(p_reps)

            group_size = p_reps.size(0) // q_reps.size(0)
            if self.use_inbatch_neg:#计算两个embedding之间的相似度
                scores = self.compute_similarity(q_reps, p_reps) / self.temperature # B B*G
                scores = scores.view(q_reps.size(0), -1)

                target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
                target = target * group_size
                loss = self.compute_loss(scores, target)#计算loss
            else:
                scores = self.compute_similarity(q_reps[:, None, :,], p_reps.view(q_reps.size(0), group_size, -1)).squeeze(1) / self.temperature # B G

                scores = scores.view(q_reps.size(0), -1)
                target = torch.zeros(scores.size(0), device=scores.device, dtype=torch.long)
                loss = self.compute_loss(scores, target)

        else:
            scores = self.compute_similarity(q_reps, p_reps)
            loss = None
        return EncoderOutput(
            loss=loss,
            scores=scores,
            q_reps=q_reps,
            p_reps=p_reps,
        )

  只看代码感觉很抽象,这里详细介绍一下整个流程:

  (1)假设下面是训练样本:

data = [
    {
        "query": "How does one become an actor in the Telugu Film Industry?", 
        "pos": ["How do I become an actor in film industry?"], 
        "neg": ["What is the story of Moses and Ramesses?", "Does caste system affect economic growth of India?"]
    },
    {
        "query": "Why do some computer programmers develop amazing software or new concepts, while some are stuck with basic programming work?", 
        "pos": ["Why do some computer programmers develops amazing softwares or new concepts, while some are stuck with basics programming works?"], 
        "neg": ["When visiting a friend, do you ever think about what would happen if you did something wildly inappropriate like punch them or destroy their furniture?", "What is the difference between a compliment and flirting?"]
    }
]

  (2)query和回答会被分开单独转成token_ids,回答叫passage,如下:(注意,这里的token编号只是示意,不一定对,只是为了说明流程和原理)!

query = {
    'input_ids': tensor([[101, 2129, 2515, 2028, 2468, 2019, 4449, 1999, 1996, 10165, 2143, 3068, 102], 
                         [101, 2339, 2079, 2070, 3274, 13193, 3285, 12460, 13191, 3021, 1997, 3749, 2135, 102]]),  # 两个查询的示例
    'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}

passage = {
    'input_ids': tensor([
        [101, 2129, 2079, 1045, 2468, 2019, 4449, 1999, 10165, 2143, 3068, 102],  # 第一个查询的正样本
        [101, 2339, 2079, 2070, 3274, 13193, 3285, 12460, 13191, 3021, 1997, 3749, 2135, 102],  # 第二个查询的正样本
        [101, 2054, 2003, 1996, 2466, 1997, 7929, 1998, 10500, 1029, 102],  # 第一个查询的负样本1
        [101, 2515, 9397, 3600, 7462, 3599, 2964, 4100, 3600, 2630, 1997, 2290, 1029, 102],  # 第一个查询的负样本2
        [101, 2043, 6188, 1037, 2767, 1010, 2079, 2017, 2412, 2228, 2055, 2054, 2052, 2490, 2017, 2106, 1037, 10723, 21446, 2066, 7059, 2068, 2030, 5620, 2037, 4192, 1029, 102],  # 第二个查询的负样本1
        [101, 2054, 2003, 1996, 4487, 2090, 1037, 9994, 1998, 18095, 1029, 102]  # 第二个查询的负样本2
    ]),
    'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}

  (3)Bi双塔就体现在这里了:query和passage分别调用encoder求整个句子的embedding

q_reps = self.encode(query)  # 形状 [2, embedding_dim],表示两个query的embedding
p_reps = self.encode(passage)  # 形状 [6, embedding_dim],表示六个passage的embedding

  (4)求query和passage之间的相似度:

scores = self.compute_similarity(q_reps, p_reps) / self.temperature

  这步很关键,假设结果如下(注意:数值不一定对,这里只是说明流程和原理):

scores = tensor([
    [0.8, 0.1, 0.3, 0.2, 0.4, 0.7],  # 第一个query与所有passage的相似度分数
    [0.5, 0.6, 0.4, 0.9, 0.2, 0.3]   # 第二个query与所有passage的相似度分数
])

  上面是query和所有passage的相似度,哪些passage才是pos,哪些是neg了?这个需要区分开来吧,是如下方式做的:

target = torch.arange(2, device=scores.device, dtype=torch.long)  # [0, 1]
target = target * 3  # [0 * 3, 1 * 3] => [0, 3]

  这样一来,target就包含了正确pos回答的位置了,如下;

  • target[0] = 0 表示第一个查询的正确段落是 passage[0]
  • target[1] = 3 表示第二个查询的正确段落是 passage[3]

    (5)上面的所有的铺垫和准备工作都已完成,最后一步就是coss entropy求loss了:scores和target之间要尽量对齐一致,由于target包含了pos正确的回答,所以scores对应的正确回答pos的维度数值要尽量大,其他neg维度数值要尽量小,这就从loss端区分开了pos和neg答案啦

loss = cross_entropy(scores, target)

  最后有个疑问:target包含了正样本pos位置,和scores求cross entropy,本质是通过target选择scores中最合理的维度求极值,这么来看,貌似负样本neg好像没用上?

 

参考:

1、https://github.com/FlagOpen/FlagEmbedding

2、https://www.bilibili.com/video/BV1sQ4y137Ft/?spm_id_from=pageDriver&vd_source=241a5bcb1c13e6828e519dd1f78f35b2

3、https://huggingface.co/BAAI/bge-m3

4、https://www.53ai.com/news/qianyanjishu/816.html

 

标签:LLM,self,reps,BiEncoderModel,源码,embedding,scores,query,target
From: https://www.cnblogs.com/theseventhson/p/18256405

相关文章

  • springboot小型超市商品展销系统-计算机毕业设计源码01635
    摘 要科技进步的飞速发展引起人们日常生活的巨大变化,电子信息技术的飞速发展使得电子信息技术的各个领域的应用水平得到普及和应用。信息时代的到来已成为不可阻挡的时尚潮流,人类发展的历史正进入一个新时代。在现实运用中,应用软件的工作规则和开发步骤,采用Springboot框架建......
  • SSM医院线上线下全诊疗系统-计算机毕业设计源码02210
    目 录摘要1绪论1.1背景及意义1.2研究现状1.3ssm框架介绍1.4论文结构与章节安排2 医院线上线下全诊疗系统系统分析2.1可行性分析2.1.1技术可行性分析2.1.2经济可行性分析2.1.3法律可行性分析2.2系统功能分析2.2.1功能性分析2.2.2非功能......
  • springboot防疫知识科普系统-计算机毕业设计源码03531
    摘 要如今计算机行业的发展极为快速,搭载于计算机软件运行的数据库管理系统在各行各业得到了广泛的运用,其在数据管理方面具有的准确性和高效性为大中小企业的日常运营提供了巨大的帮助。自从2020年新冠疫情爆发以来,防疫成了社会关注的重中之重,在防疫管理中,一开始对防疫的管......
  • SSM图书借阅管理系统-计算机毕业设计源码06780
    摘 要大数据时代下,数据呈爆炸式地增长。为了迎合信息化时代的潮流和信息化安全的要求,利用互联网服务于其他行业,促进生产,已经是成为一种势不可挡的趋势。在图书馆的要求下,开发一款整体式结构的图书借阅管理系统,将复杂的系统进行拆分,能够实现对需求的变化快速响应、系统稳定性......
  • 一文搞定 大语言模型(LLM)微调方法
    引言众所周知,大语言模型(LLM)正在飞速发展,各行业都有了自己的大模型。其中,大模型微调技术在此过程中起到了非常关键的作用,它提升了模型的生成效率和适应性,使其能够在多样化的应用场景中发挥更大的价值。那么,今天这篇文章就带大家深入了解大模型微调。其中主要包括什么是大......
  • 人大这波操作666! 国内首本中文版的LLM大语言模型入门指南!(附PDF)
    我就知道人大还留有后手。自从这篇中文大模型综述发布以后,在全网收到了一致好评。人大这边也一直没闲着,在后续一年之内修改了十多遍,收录了近千篇的参考文献,快马加鞭赶出了这本大语言模型中文版。一经发布就震惊国内高校和研究人员,是更适合中国体制的大模型指南。本书内容......
  • 基于SpringBoot+Vue的高校爱心捐赠系统设计与实现(源码+lw+部署+讲解)
    文章目录前言详细视频演示具体实现截图技术可行性分析技术简介后端框架SpringBoot前端框架Vue系统开发平台系统架构设计业务流程分析为什么选择我们自己的公众号(一点毕设)海量实战案例代码参考数据库参考源码及文档获取前言......
  • useEffect 的原理是什么,怎么使用,源码的逻辑是怎么样的
    useEffect的原理useEffect的原理是基于React组件的生命周期函数。当组件的props或state发生变更时,会触发一个更新循环。在这个更新循环中,会调用useEffect中的函数,即根据组件中获取的变更信息来执行useEffect中定义的操作。useEffect允许开发人员在组件生命周期中执行副作用......
  • 小型企业人事管理系统java ssm mysql|全套源码+文章lw+毕业设计+课程设计+数据库+ppt
    小型企业人事管理系统javassmmysql|全套源码+文章lw+毕业设计+课程设计+数据库+ppt小型企业人事管理系统的设计与实现【摘要】:人才是企业发展的核心力量,所以人事管理是企业管理中一项重要的任务。传统的人事管理系统不仅效率慢而且极易出错,使管理者不能清楚的了解每一位......
  • 【CS.SE】从源码到实践:探索日常对话的生成性语音模型ChatTTS
    文章目录1项目介绍1.1功能与特色2技术分析2.1模型架构3项目实践3.1快速上手4项目总结ReferencesGitcode上有许多优秀的开源项目,今天我们要介绍的是一个令人耳目一新的项目——ChatTTS。ChatTTS是一个基于深度学习的文本转语音(TTS)系统,它的目标是通过先进......