首页 > 编程语言 >NumPyML 源码解析(五)

NumPyML 源码解析(五)

时间:2024-02-16 19:55:59浏览次数:24  
标签:episode self 源码 np NumPyML 解析 grad obs def

numpy-ml\numpy_ml\preprocessing\nlp.py

# 导入必要的库和模块
import re
import heapq
import os.path as op
from collections import Counter, OrderedDict, defaultdict
import numpy as np

# 定义英文停用词列表,来源于"Glasgow Information Retrieval Group"
_STOP_WORDS = set(
    ).split(" "),
)

# 定义用于匹配单词的正则表达式,用于分词
_WORD_REGEX = re.compile(r"(?u)\b\w\w+\b")  # sklearn默认
_WORD_REGEX_W_PUNC = re.compile(r"(?u)\w+|[^a-zA-Z0-9\s]")
_WORD_REGEX_W_PUNC_AND_WHITESPACE = re.compile(r"(?u)s?\w+\s?|\s?[^a-zA-Z0-9\s]\s?")

# 定义用于匹配标点符号的正则表达式
_PUNC_BYTE_REGEX = re.compile(
    r"(33|34|35|36|37|38|39|40|41|42|43|44|45|"
    r"46|47|58|59|60|61|62|63|64|91|92|93|94|"
    r"95|96|123|124|125|126)",
)
# 定义标点符号
_PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# 创建用于去除标点符号的转换表
_PUNC_TABLE = str.maketrans("", "", _PUNCTUATION)

# 定义函数,返回指定长度的n-gram序列
def ngrams(sequence, N):
    """Return all `N`-grams of the elements in `sequence`"""
    assert N >= 1
    return list(zip(*[sequence[i:] for i in range(N)]))

# 定义函数,将字符串按空格分词,可选择是否转为小写、过滤停用词和标点符号
def tokenize_whitespace(
    line, lowercase=True, filter_stopwords=True, filter_punctuation=True, **kwargs,
):
    """
    Split a string at any whitespace characters, optionally removing
    punctuation and stop-words in the process.
    """
    line = line.lower() if lowercase else line
    words = line.split()
    line = [strip_punctuation(w) for w in words] if filter_punctuation else line
    return remove_stop_words(words) if filter_stopwords else words

# 定义函数,将字符串按单词分词,可选择是否转为小写、过滤停用词和标点符号
def tokenize_words(
    line, lowercase=True, filter_stopwords=True, filter_punctuation=True, **kwargs,
):
    """
    Split a string into individual words, optionally removing punctuation and
    stop-words in the process.
    """
    REGEX = _WORD_REGEX if filter_punctuation else _WORD_REGEX_W_PUNC
    words = REGEX.findall(line.lower() if lowercase else line)
    return remove_stop_words(words) if filter_stopwords else words

# 定义函数,将字符串按字节分词
def tokenize_words_bytes(
    line,
    # 设置是否将文本转换为小写
    lowercase=True,
    # 设置是否过滤停用词
    filter_stopwords=True,
    # 设置是否过滤标点符号
    filter_punctuation=True,
    # 设置文本编码格式为 UTF-8
    encoding="utf-8",
    # **kwargs 表示接受任意数量的关键字参数,这些参数会被传递给函数的其他部分进行处理
    **kwargs,
# 将字符串拆分为单词,并在此过程中选择性地删除标点符号和停用词。将每个单词转换为字节列表。
def tokenize_words(
    line,
    lowercase=lowercase,
    filter_stopwords=filter_stopwords,
    filter_punctuation=filter_punctuation,
    **kwargs,
):
    # 对单词进行分词处理,根据参数选择是否转换为小写、过滤停用词和标点符号
    words = tokenize_words(
        line,
        lowercase=lowercase,
        filter_stopwords=filter_stopwords,
        filter_punctuation=filter_punctuation,
        **kwargs,
    )
    # 将单词转换为字节列表,每个字节用空格分隔
    words = [" ".join([str(i) for i in w.encode(encoding)]) for w in words]
    # 返回字节列表
    return words


# 将字符串中的字符转换为字节集合。每个字节用0到255之间的整数表示。
def tokenize_bytes_raw(line, encoding="utf-8", splitter=None, **kwargs):
    # 将字符串中的字符编码为字节,每个字节用空格分隔
    byte_str = [" ".join([str(i) for i in line.encode(encoding)])
    # 如果指定了分隔符为标点符号,则在编码为字节之前在标点符号处进行分割
    if splitter == "punctuation":
        byte_str = _PUNC_BYTE_REGEX.sub(r"-\1-", byte_str[0]).split("-")
    return byte_str


# 将字节(表示为0到255之间的整数)解码为指定编码的字符。
def bytes_to_chars(byte_list, encoding="utf-8"):
    # 将字节列表中的整数转换为十六进制字符串
    hex_array = [hex(a).replace("0x", "") for a in byte_list]
    # 将十六进制字符串连接起来,并在需要时在前面补0
    hex_array = " ".join([h if len(h) > 1 else f"0{h}" for h in hex_array])
    # 将十六进制字符串转换为字节数组,再根据指定编码解码为字符
    return bytearray.fromhex(hex_array).decode(encoding)


# 将字符串中的字符转换为小写,并根据参数选择是否过滤标点符号。
def tokenize_chars(line, lowercase=True, filter_punctuation=True, **kwargs):
    # 将字符串拆分为单个字符,可选择在此过程中删除标点符号和停用词
    """
    # 如果需要转换为小写,则将字符串转换为小写
    line = line.lower() if lowercase else line
    # 如果需要过滤标点符号,则调用函数去除标点符号
    line = strip_punctuation(line) if filter_punctuation else line
    # 使用正则表达式将连续多个空格替换为一个空格,并去除首尾空格,然后将结果转换为字符列表
    chars = list(re.sub(" {2,}", " ", line).strip())
    # 返回字符列表
    return chars
# 从单词字符串列表中移除停用词
def remove_stop_words(words):
    """Remove stop words from a list of word strings"""
    # 返回不在停用词列表中的单词
    return [w for w in words if w.lower() not in _STOP_WORDS]


# 从字符串中移除标点符号
def strip_punctuation(line):
    """Remove punctuation from a string"""
    # 使用_PUNC_TABLE来移除字符串中的标点符号,并去除首尾空格
    return line.translate(_PUNC_TABLE).strip()


#######################################################################
#                          Byte-Pair Encoder                          #
#######################################################################


# 定义一个Byte-Pair编码器类
class BytePairEncoder(object):
    def __init__(self, max_merges=3000, encoding="utf-8"):
        """
        A byte-pair encoder for sub-word embeddings.

        Notes
        -----
        Byte-pair encoding [1][2] is a compression algorithm that iteratively
        replaces the most frequently ocurring byte pairs in a set of documents
        with a new, single token. It has gained popularity as a preprocessing
        step for many NLP tasks due to its simplicity and expressiveness: using
        a base coebook of just 256 unique tokens (bytes), any string can be
        encoded.

        References
        ----------
        .. [1] Gage, P. (1994). A new algorithm for data compression. *C
           Users Journal, 12(2)*, 23–38.
        .. [2] Sennrich, R., Haddow, B., & Birch, A. (2015). Neural machine
           translation of rare words with subword units, *Proceedings of the
           54th Annual Meeting of the Association for Computational
           Linguistics,* 1715-1725.

        Parameters
        ----------
        max_merges : int
            The maximum number of byte pair merges to perform during the
            :meth:`fit` operation. Default is 3000.
        encoding : str
            The encoding scheme for the documents used to train the encoder.
            Default is `'utf-8'`.
        """
        # 初始化参数字典
        self.parameters = {
            "max_merges": max_merges,
            "encoding": encoding,
        }

        # 初始化字节到标记和标记到字节的有序字典。字节以十进制表示为0到255之间的整数。
        # 在255之前,标记和字节表示之间存在一对一的对应关系。
        self.byte2token = OrderedDict({i: i for i in range(256)})
        self.token2byte = OrderedDict({v: k for k, v in self.byte2token.items()})
    # 在给定语料库上训练一个字节对编码表
    def fit(self, corpus_fps, encoding="utf-8"):
        """
        Train a byte pair codebook on a set of documents.

        Parameters
        ----------
        corpus_fps : str or list of strs
            The filepath / list of filepaths for the document(s) to be used to
            learn the byte pair codebook.
        encoding : str
            The text encoding for documents. Common entries are either 'utf-8'
            (no header byte), or 'utf-8-sig' (header byte). Default is
            'utf-8'.
        """
        # 创建一个词汇表对象,用于存储字节对编码表
        vocab = (
            Vocabulary(
                lowercase=False,
                min_count=None,
                max_tokens=None,
                filter_stopwords=False,
                filter_punctuation=False,
                tokenizer="bytes",
            )
            # 在给定语料库上拟合词汇表
            .fit(corpus_fps, encoding=encoding)
            # 获取词汇表中的计数信息
            .counts
        )

        # 迭代地合并跨文档中最常见的字节二元组
        for _ in range(self.parameters["max_merges"]):
            # 获取词汇表中的字节二元组计数信息
            pair_counts = self._get_counts(vocab)
            # 找到出现次数最多的字节二元组
            most_common_bigram = max(pair_counts, key=pair_counts.get)
            # 合并最常见的字节二元组到词汇表中
            vocab = self._merge(most_common_bigram, vocab)

        # 初始化一个空集合,用于存储字节标记
        token_bytes = set()
        # 遍历词汇表中的键
        for k in vocab.keys():
            # 将键按空格分割,筛选包含"-"的字节标记
            token_bytes = token_bytes.union([w for w in k.split(" ") if "-" in w])

        # 遍历字节标记集合
        for i, t in enumerate(token_bytes):
            # 将字节标记转换为元组形式
            byte_tuple = tuple(int(j) for j in t.split("-"))
            # 将字节标记映射到对应的标记索引
            self.token2byte[256 + i] = byte_tuple
            # 将字节标记索引映射到对应的字节标记
            self.byte2token[byte_tuple] = 256 + i

        # 返回当前对象
        return self

    # 获取词汇表中的字节二元组计数信息
    def _get_counts(self, vocab):
        """Collect bigram counts for the tokens in vocab"""
        # 初始化一个默认字典,用于存储字节二元组计数
        pair_counts = defaultdict(int)
        # 遍历词汇表中的单词和计数信息
        for word, count in vocab.items():
            # 生成单词的二元组
            pairs = ngrams(word.split(" "), 2)
            # 遍历单词的二元组
            for p in pairs:
                # 更新字节二元组计数信息
                pair_counts[p] += count
        # 返回字节二元组计数信息
        return pair_counts
    # 将给定的二元组替换为单个标记,并相应更新词汇表
    def _merge(self, bigram, vocab):
        v_out = {}
        # 转义二元组中的单词,用于正则表达式匹配
        bg = re.escape(" ".join(bigram))
        # 创建匹配二元组的正则表达式
        bigram_regex = re.compile(r"(?<!\S)" + bg + r"(?!\S)")
        # 遍历词汇表中的单词
        for word in vocab.keys():
            # 将匹配到的二元组替换为连接符"-"
            w_out = bigram_regex.sub("-".join(bigram), word)
            v_out[w_out] = vocab[word]
        return v_out

    # 将文本中的单词转换为其字节对编码的标记ID
    def transform(self, text):
        """
        Transform the words in `text` into their byte pair encoded token IDs.

        Parameters
        ----------
        text: str or list of `N` strings
            The list of strings to encode

        Returns
        -------
        codes : list of `N` lists
            A list of byte pair token IDs for each of the `N` strings in
            `text`.

        Examples
        --------
        >>> B = BytePairEncoder(max_merges=100).fit("./example.txt")
        >>> encoded_tokens = B.transform("Hello! How are you 

标签:episode,self,源码,np,NumPyML,解析,grad,obs,def
From: https://www.cnblogs.com/apachecn/p/18017417

相关文章

  • NumPyML 源码解析(六)
    numpy-ml\numpy_ml\tests\test_glm.py#禁用flake8检查#导入numpy库并重命名为npimportnumpyasnp#导入statsmodels库中的api模块并重命名为smimportstatsmodels.apiassm#从numpy_ml.linear_models模块中导入GeneralizedLinearModel类fromnumpy_ml......
  • StringUtils使用与源码分析
    在apache的lang3包中有个StringUtils工具类,该工具类为开发中常用的字符串处理工具类 非空判断,isBlank和isEmpty这俩方法的形参都是charSequence字符序列。isEmpty判断这个字符序列是否为null,还有长度是否为0,如果是,则返回true,反之返回falseisBlank在isEmpty之上还有一个,如果长度......
  • Ubuntu 22.04 源码安装ST-Link V2过程详解
    一首先安装依赖工具:A安装预编译库:sudoapt-getinstallgitmakecmakelibusb-1.0-0-devB安装gcc库:sudoapt-getinstallgccbuild-essential二源码安装A下载代码gitclonehttps://github.com/stlink-org/stlink.gitB编译:cmake.makeC复制二进......
  • WPF新境界:MVVM设计模式解析与实战,构建清晰可维护的用户界面
     概述:MVVM是一种在WPF开发中广泛应用的设计模式,通过将应用程序分为模型、视图、和视图模型,实现了解耦、提高可维护性的目标。典型应用示例展示了如何通过XAML、ViewModel和数据绑定创建清晰、可测试的用户界面。什么是MVVM?MVVM(Model-View-ViewModel)是一种用于构建用户界面的......
  • scratch源码下载 | 炮轰僵尸
    程序说明:《炮轰僵尸》是一款基于Scratch平台制作的游戏程序,它采用了植物大战僵尸的经典场景。在游戏中,玩家需要控制一枚大炮来对抗不断入侵的僵尸。通过移动鼠标,玩家可以调整炮筒的方向,并在合适的时机按下鼠标左键发射炮弹,以消灭逼近的僵尸。这款游戏不仅提供了紧张刺激的游戏体......
  • scratch源码下载 | 几何冲刺
    程序说明:《几何冲刺》是一款基于Scratch平台开发的跑酷类游戏程序。在这个游戏中,玩家控制一个黄色的小方块,在快速向前冲刺的过程中躲避各种障碍物。通过按下键盘上的上方向键,玩家可以操作小方块进行跳跃,以避开途中的障碍。游戏的目标是尽可能让黄色小方块跑得更远,挑战玩家的反应......
  • scratch源码下载 | 蜘蛛传说
    程序说明:《蜘蛛传说》是一个通过Scratch平台制作的互动游戏项目。在这个故事中,玩家将扮演一只蜘蛛,其原本和平的生活被一只入侵的壁虎所打破。为了保卫自己的家园,蜘蛛必须运用智慧和勇气与壁虎对抗。游戏通过ADSW键进行移动,F键发射蜘蛛弹来攻击壁虎,但发射蜘蛛弹会消耗体力。玩家需......
  • scratch源码下载 | 飞天厨师
    程序说明:《飞天厨师》是一款使用Scratch平台制作的游戏程序。在这个游戏中,玩家将控制一名厨师角色,他在天空中不断掉落。玩家需要利用方向键左右移动厨师,以便他能够准确地踩在空中的食物上。每当厨师成功踩到食物时,他就会飞得更高。如果厨师在掉落的过程中没有踩到任何食物,游戏就......
  • Scratch源码下载 | 3D钻石
    程序说明:《3D钻石》是一个利用Scratch平台创作的独特艺术作品。此程序在屏幕上呈现一个精致的3D钻石模型,允许用户通过鼠标操作来旋转和查看钻石的不同角度。该程序还提供了修改钻石参数的功能,使用户能够自定义钻石的外观和特性。由于其复杂的3D渲染和交互设计,这个作品的制作难度......
  • idea 通过maven下载源码
    【问】如上图,IDEA中点击DownloadSource(下载源码)后,下载的源码存储到哪了? 【答】先找到此源码所属Jar包在哪;点击DownloadSource(下载源码)后,会发现存储Jar包的位置多了一个后缀带有-sources的Jar包,这就是IDEA为我们下载的源码。【问】如何找此源码所属Jar包在哪......