首页 > 其他分享 >DashText-进阶使用

DashText-进阶使用

时间:2024-10-30 10:46:51浏览次数:4  
标签:SparseVectorEncoder 进阶 doc over fox encoder DashText score 使用

前置知识

BM25简介

BM25算法(Best Matching 25)是一种广泛用于信息检索领域的排名函数,用于在给定查询(Query)时对一组文档(Document)进行评分和排序。BM25在计算Query和Document之间的相似度时,本质上是依次计算Query中每个单词和Document的相关性,然后对每个单词的相关性进行加权求和。BM25算法一般可以表示为如下形式:

上式中, qd 分别表示用来计算相似度的Query和Document, q *i*表示 q 的第 i 个单词, R(q *i* , d) 表示单词 q *i*和文档 d 的相关性, W *i*表示单词 q *i*的权重,计算得到的 score(q, d) 表示 qd 的相关性得分,得分越高表示 qd 越相似。 W *i*R(q *i* , d) 一般可以表示为如下形式:

其中, N 表示总文档数, N(q *i* ) 表示包含单词 q *i*的文档数, tf(q *i* , d) 表示 q *i*在文档 d 中的词频, L *d*表示文档 d 的长度, L *avg*表示平均文档长度, k *1*b 是分别用来控制 tf(q *i* , d)L *d*对得分影响的超参数。

稀疏向量生成

在检索场景中,为了让BM25算法的Score方便进行计算,通常分别对Document和Query进行编码,然后通过 点积 的方式计算出两者的相似度。得益于BM25原理的特性,其原生支持将Score拆分为两部分Sparse Vector,DashText提供了encode_document以及encode_query两个接口来分别实现这两部分向量的生成,其生成链路如下图所示:

最终生成的稀疏向量可表示为:

Score/距离计算

生成 dq 的稀疏向量后,就可以通过简单的点积进行距离计算,即将相同单词上的值对应相乘再求和,通过稀疏向量计算距离的方式如下所示:

上述计算方式本质上是通过点积来计算的, score 越大表示越相似,如果需要结合Dense Vector一起进行距离度量时,需要对齐距离度量方式。也就是说,在结合Dense Vector+Sparse Vector的场景中,距离计算只支持点积度量方式。

如何自训练模型

考虑到内置的BM25 Model是基于通用语料(中文Wiki语料)训练得到,在特定领域下通常不能表现出最佳的效果。因此,在一些特定场景下,通常建议训练自定义BM25模型。使用DashText来训练自定义模型时一般需要遵循以下步骤:

Step1:确认使用场景

当准备使用SparseVector来进行信息检索时,应提前考虑当前场景下的Query以及Document来源,通常需要提前准备好一定数量Document来入库,这些Document通常需要和特定的业务场景直接相关。

Step2:准备语料

根据BM25原理,语料直接决定了BM25模型的参数。通常应按照以下几个原则来准备语料:

  • 语料来源应尽可能反映对应场景的特性,尽可能让 N(q *i* ) 能够反映对应真实场景的词频信息。

  • 调节合理的语料切片长度和切片数量,避免出现语料当中只有少量长文本的情况。

一般情况下,如无特殊要求或限制,可以直接将Step1准备的一系列Document组织为语料即可。

Step3:准备Tokenizer

Tokenizer决定了分词的结果,分词的结果则直接影响Sparse Vector的生成,在特定领域下使用自定义Tokenizer会达到更好的效果。DashText提供了两种扩展Tokenizer的方式:

  • 使用自定义词表:DashText内置的Jieba Tokenizer支持传入自定义词表。(Java SDK暂不支持该功能)

Python示例:

from dashtext import TextTokenizer, SparseVectorEncoder

my_tokenizer = TextTokenizer.from_pretrained(model_name='Jieba', dict='dict.txt')
my_encoder = SparseVectorEncoder(tokenize_function=my_tokenizer.tokenize)
  • 使用自定义Tokenizer:DashText支持任务自定义的Tokenizer,只需提供一个符合Callable[[str], List[str]]签名的Tokenize函数即可。

Python示例:

from dashtext import SparseVectorEncoder
from transformers import BertTokenizer

my_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
my_encoder = SparseVectorEncoder(tokenize_function=my_tokenizer.tokenize)

Java示例:

import com.aliyun.dashtext.common.DashTextException;
import com.aliyun.dashtext.common.ErrorCode;
import com.aliyun.dashtext.encoder.SparseVectorEncoder;
import com.aliyun.dashtext.tokenizer.BaseTokenizer;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class Main {
    public static class MyTokenizer implements BaseTokenizer {
        @Override
        public List<String> tokenize(String s) throws DashTextException {
            if (s == null) {
                throw new DashTextException(ErrorCode.INVALID_ARGUMENT);
            }

            // 使用正则表达式将文本按空白符和标点符号分割,并转换为小写
            return Arrays.stream(s.split("\\s+|(?<!\\d)[.,](?!\\d)"))
                    .map(String::toLowerCase)
                    .filter(token -> !token.isEmpty())   // 过滤掉空字符串
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) {
        SparseVectorEncoder encoder = new SparseVectorEncoder(new MyTokenizer());
    }
}

Step4:训练模型

实际上,这里的"训练"本质上是一个"统计"参数的过程。由于训练自定义模型的过程中包含着大量Tokenizing/Hashing过程,所以可能会耗费一定的时间。DashText提供了SparseVectorEncoder.train接口可以用来训练模型。

Step5:调参优化(可选)

模型训练完成后,可以准备部分验证数据集以及通过微调 k *1*b 来达到最佳的召回效果。调节k1和b一般需要遵循以下原则:

  • 调节 k *1*(1.2 < k *1* < 2)可控制Document词频对Score的影响, k *1*越大Document的词频对Score的贡献越小。

  • 调节 b (0 < b < 1)可控制文档长度对Score的影响, b 越大表示文档长度对权重的影响越大

一般情况下,如无特殊要求或限制,不需要调整 k *1*b

Step6:Finetune模型(可选)

实际场景下,可能会存在需要补充训练语料来增量式地更新BM25模型参数的情况。DashText的SparseVectorEncoder.train接口原生支持模型的增量更新。需要注意的是,模型更改之后,使用旧模型进行编码并已入库的向量就失去了时效性,一般需要重新入库。

示例代码

以下是一个简单完整的自训练模型示例。
Python示例:

from dashtext import SparseVectorEncoder
from pydantic import BaseModel
from typing import Dict, List


class Result(BaseModel):
    doc: str
    score: float


def calculate_score(query_vector: Dict[int, float], document_vector: Dict[int, float]) -> float:
    score = 0.0
    for key, value in query_vector.items():
        if key in document_vector:
            score += value * document_vector[key]
    return score


# 创建空SparseVectorEncoder(可以设置自定义Tokenizer)
encoder = SparseVectorEncoder()


# step1: 准备语料以及Documents
corpus_document: List[str] = [
    "The quick brown fox rapidly and agilely leaps over the lazy dog that lies idly by the roadside.",
    "Never jump over the lazy dog quickly",
    "A fox is quick and jumps over dogs",
    "The quick brown fox",
    "Dogs are domestic animals",
    "Some dog breeds are quick and jump high",
    "Foxes are wild animals and often have a brown coat",
]


# step2: 训练BM25 Model
encoder.train(corpus_document)


# step3: 调参优化BM25 Model
query: str = "quick brown fox"
print(f"query: {query}")
k1s = [1.0, 1.5]
bs = [0.5, 0.75]
for k1, b in zip(k1s, bs):
    print(f"current k1: {k1}, b: {b}")
    encoder.b = b
    encoder.k1 = k1
    query_vector = encoder.encode_queries(query)
    results: List[Result] = []
    for idx, doc in enumerate(corpus_document):
        doc_vector = encoder.encode_documents(doc)
        score = calculate_score(query_vector, doc_vector)
        results.append(Result(doc=doc, score=score))
    results.sort(key=lambda r: r.score, reverse=True)

    for result in results:
        print(result)


# step4: 选择最优参数并保存模型
encoder.b = 0.75
encoder.k1 = 1.5
encoder.dump("./model.json")


# step5: 后续使用时可以加载模型
new_encoder = SparseVectorEncoder()
bm25_model_path = "./model.json"
new_encoder.load(bm25_model_path)


# step6: 对模型进行finetune并保存
extra_corpus: List[str] = [
    "The fast fox jumps over the lazy, chubby dog",
    "A swift fox hops over a napping old dog",
    "The quick fox leaps over the sleepy, plump dog",
    "The agile fox jumps over the dozing, heavy-set dog",
    "A speedy fox vaults over a lazy, old dog lying in the sun"
]

new_encoder.train(extra_corpus)
new_bm25_model_path = "new_model.json"
new_encoder.dump(new_bm25_model_path)

Java示例:

import com.aliyun.dashtext.encoder.SparseVectorEncoder;

import java.io.*;
import java.util.*;

public class Main {

    public static class Result {
        public String doc;
        public float score;

        public Result(String doc, float score) {
            this.doc = doc;
            this.score = score;
        }

        @Override
        public String toString() {
            return String.format("Result(doc=%s, score=%f)", doc, score);
        }
    }

    public static float calculateScore(Map<Long, Float> queryVector, Map<Long, Float> documentVector) {
        float score = 0.0f;
        for (Map.Entry<Long, Float> entry : queryVector.entrySet()) {
            if (documentVector.containsKey(entry.getKey())) {
                score += entry.getValue() * documentVector.get(entry.getKey());
            }
        }
        return score;
    }

    public static void main(String[] args) throws IOException {
        // 创建空SparseVectorEncoder(可以设置自定义Tokenizer)
        SparseVectorEncoder encoder = new SparseVectorEncoder();

        // step1: 准备语料以及Documents
        List<String> corpusDocument = Arrays.asList(
                "The quick brown fox rapidly and agilely leaps over the lazy dog that lies idly by the roadside.",
                "Never jump over the lazy dog quickly",
                "A fox is quick and jumps over dogs",
                "The quick brown fox",
                "Dogs are domestic animals",
                "Some dog breeds are quick and jump high",
                "Foxes are wild animals and often have a brown coat"
        );

        // step2: 训练BM25 Model
        encoder.train(corpusDocument);

        // step3: 调参优化BM25 Model
        String query = "quick brown fox";
        System.out.println("query: " + query);
        float[] k1s = {1.0f, 1.5f};
        float[] bs = {0.5f, 0.75f};
        for (int i = 0; i < k1s.length; i++) {
            float k1 = k1s[i];
            float b = bs[i];
            System.out.println("current k1: " + k1 + ", b: " + b);
            encoder.setB(b);
            encoder.setK1(k1);

            Map<Long, Float> queryVector = encoder.encodeQueries(query);
            List<Result> results = new ArrayList<>();
            for (String doc : corpusDocument) {
                Map<Long, Float> docVector = encoder.encodeDocuments(doc);
                float score = calculateScore(queryVector, docVector);
                results.add(new Result(doc, score));
            }

            results.sort((r1, r2) -> Float.compare(r2.score, r1.score));

            for (Result result : results) {
                System.out.println(result);
            }
        }

        // step4: 选择最优参数并保存模型
        encoder.setB(0.75f);
        encoder.setK1(1.5f);
        encoder.dump("./model.json");

        // step5: 后续使用时可以加载模型
        SparseVectorEncoder newEncoder = new SparseVectorEncoder();
        newEncoder.load("./model.json");

        // step6: 对模型进行finetune并保存
        List<String> extraCorpus = Arrays.asList(
                "The fast fox jumps over the lazy, chubby dog",
                "A swift fox hops over a napping old dog",
                "The quick fox leaps over the sleepy, plump dog",
                "The agile fox jumps over the dozing, heavy-set dog",
                "A speedy fox vaults over a lazy, old dog lying in the sun"
        );
        newEncoder.train(extraCorpus);
        newEncoder.dump("./new_model.json");
    }
}

API参考

DashText API详情可参考:https://pypi.org/project/dashtext/

标签:SparseVectorEncoder,进阶,doc,over,fox,encoder,DashText,score,使用
From: https://www.cnblogs.com/DashVector/p/18515422

相关文章

  • Jenkins使用maven打包项目
    Jenkins使用maven打包项目作为一名软件测试工程师,在日常工作中,我们经常需要使用Jenkins进行持续集成和持续部署(CI/CD)。而Maven作为Java项目的构建工具,更是不可或缺。今天,我将向大家介绍如何在Jenkins中使用Maven打包项目。一、准备工作登录Jenkins后,点击ManageJenkins->Tool......
  • react.js中何时使用useCallback
    useMemo用于记住值,减少重新渲染组件所需的时间。useCallback用于记住函数,通常是为了防止组件的重新渲染举例子组件接收回调函数作为 props父组件引入子组件:constgetList=useCallback(()=>fetch(`http://example.com/api/${userId}`),[userId],);return(<buttonon......
  • 网站有多个域名,使用哪种类型的SSL证书?
    当网站拥有多个域名时,可以选择以下几种类型的SSL证书来满足安全需求:一、多域名SSL证书(SAN证书)定义:多域名SSL证书,也被称为SAN(SubjectAlternativeName)证书或UCC(UnifiedCommunicationsCertificate)证书,是一种特殊的SSL证书类型,可以保护一个主域名以及多个其他附属域名。这些......
  • idea从新建一个maven项目到打包成可运行jar包全流程供接口测试签名使用
     1创建maven项目点击new-project 选择左侧的mavenArchetype修改Name,JDK,Catalog,Archetype(org.apache.maven.archetypes:maven-archetype-webapp)为下图中配置 修改地址(自选),版本号(自选),之后点击create 2配置maven在settings中找到下图中maven的位置,并自定义maven包,......
  • GitLab代码仓管理安装配置使用
    Gitlab介绍GitLab是一个基于Git的开源项目管理工具,它集成了版本控制、代码审查、持续集成(CI)/持续部署(CD)、自动化测试等多种功能,是一个完整的DevOps平台。以下是对GitLab的详细介绍:一、主要特点和功能版本控制系统:GitLab的核心是基于Git的版本控制系统,支持代码的版本管理、分......
  • 《使用Gin框架构建分布式应用》阅读笔记:p251-p271
    《用Gin框架构建分布式应用》学习第14天,p251-p271总结,总21页。一、技术总结1.Docker&DockerComposeversion:"3.9"services:api:image:apienvironment:-MONGO_URI=mongodb://admin:password@mongodb:27017/test?authSource=admin&readPreference=p......
  • 使用最小二乘法进行线性回归(Python)
    已知测得某块地,当温度处于15至40度之间时,数得某块草地上小花朵的数量和温度值的数据如下表所示。现在要来找出这些数据中蕴含的规律,用来预测其它未测温度时的小花朵的数量。测得数据如下图所示:importmatplotlib.pyplotaspltimportnumpyasnptemperatures=[15,20,......
  • 如何使用MD5校验系统文件完整性?
    1、首先,我们先了解一下什么是MD5?很多朋友并不是很了解MD5是什么,针对这个问题,我们来做一下简单的介绍。MD5为计算机安全领域广泛使用的一种散列函数,用以提供文件的完整性保护。简单来说就是用来校验文件在下载过程中是否损坏。2、为什么要对系统文件进行MD5校验呢?经常碰到......
  • 内网穿透:基本概念和使用技巧
    一、为什么要使用内网穿透:内网穿透也称内网映射,简单来说就是让外网可以访问你的内网:把自己的内网(主机)当做服务器    让外网访问简而言之,就是我们在自己计算机上运行的程序,别人也可以通过公网直接访问,这样可以在项目发布到云服务器前,提供一个公网地址给用户进行体......
  • 在Windows环境下使用AMD显卡运行Stable Diffusion
    现在用的电脑是21年配的,当时并没有AI相关的需求,各种各样的原因吧,抉择后选择了AMD的显卡,但在2024年的今天,使用AI进行一些工作已不再是什么罕见的需求,所以我也想尝试一下,但发现AMD显卡却处处碰壁,研究后发现,经过各方面的努力,AMD显卡在AI方面的支持已经有了很大的进步,......