首页 > 编程语言 >HanLP — HMM隐马尔可夫模型 -- 维特比(Viterbi)算法 --完整示例代码

HanLP — HMM隐马尔可夫模型 -- 维特比(Viterbi)算法 --完整示例代码

时间:2024-01-17 15:44:44浏览次数:21  
标签:matrix 示例 -- text self Viterbi states state emit

完成代码

import pickle
from tqdm import tqdm
import numpy as np
import os


def make_label(text_str):
    """从单词到label的转换, 如: 今天 ----> BE  麻辣肥牛: ---> BMME  的 ---> S"""
    text_len = len(text_str)
    if text_len == 1:
        return "S"
    return "B" + "M" * (text_len - 2) + "E"  # 除了开头是 B, 结尾是 E,中间都是M


def text_to_state(train_file, state_file):
    """ 将原始的语料库转换为 对应的状态文件 """
    if os.path.exists(state_file):  # 如果存在该文件, 就直接退出
        os.remove(state_file)
    # 打开文件并按行切分到  all_data 中 , all_data  是一个list
    all_data = open(train_file, "r", encoding="utf-8").read().split("\n")
    with open(state_file, "w", encoding="utf-8") as f:  # 代开写入的文件
        for d_index, data in tqdm(enumerate(all_data)):  # 逐行 遍历 , tqdm 是进度条提示 , data 是一篇文章, 有可能为空
            if data:  # 如果 data 不为空
                state_ = ""
                for w in data.split(" "):  # 当前 文章按照空格切分, w是文章中的一个词语
                    if w:  # 如果 w 不为空
                        state_ = state_ + make_label(w) + " "  # 制作单个词语的label
                if d_index != len(all_data) - 1:  # 最后一行不要加 "\n" 其他行都加 "\n"
                    state_ = state_.strip() + "\n"  # 每一行都去掉 最后的空格
                f.write(state_)


# 定义 HMM类, 其实最关键的就是三大矩阵
class HMM:
    def __init__(self, file_text, file_state):
        self.all_states = open(file_state, "r", encoding="utf-8").read().split("\n")[:200]  # 按行获取所有的状态
        self.all_texts = open(file_text, "r", encoding="utf-8").read().split("\n")[:200]  # 按行获取所有的文本
        self.states_to_index = {"B": 0, "M": 1, "S": 2, "E": 3}  # 给每个状态定义一个索引, 以后可以根据状态获取索引
        self.index_to_states = ["B", "M", "S", "E"]  # 根据索引获取对应状态
        self.len_states = len(self.states_to_index)  # 状态长度 : 这里是4

        # 初始矩阵 : 1 * 4 , 对应的是 BMSE,
        self.init_matrix = np.zeros((self.len_states))
        # 转移状态矩阵:  4 * 4 ,
        self.transfer_matrix = np.zeros((self.len_states, self.len_states))

        # 发射矩阵, 使用的 2级 字典嵌套,
        # # 注意这里初始化了一个  total 键 , 存储当前状态出现的总次数, 为了后面的归一化使用
        self.emit_matrix = {"B": {"total": 0}, "M": {"total": 0}, "S": {"total": 0}, "E": {"total": 0}}

    # 计算 初始矩阵,统计每行第一个字出现的频次
    def cal_init_matrix(self, state):
        self.init_matrix[self.states_to_index[state[0]]] += 1  # BMSE 四种状态, 对应状态出现 1次 就 +1

    # 计算 转移矩阵,当前状态到下一状态的概率
    def cal_transfer_matrix(self, states):
        sta_join = "".join(states)  # 状态转移 从当前状态转移到后一状态, 即 从 sta1 每一元素转移到 sta2 中
        sta1 = sta_join[:-1]
        sta2 = sta_join[1:]
        for s1, s2 in zip(sta1, sta2):  # 同时遍历 s1 , s2  -- (('B', 'E'), ('E', 'B'), ('B', 'E'), ('E', 'S'), ('S', 'B'), ('B', 'E'), ('E', 'S'))
            self.transfer_matrix[self.states_to_index[s1], self.states_to_index[s2]] += 1

    # 计算 发射矩阵,在特定状态下,出现某个字的概率
    def cal_emit_matrix(self, words, states):
        """计算 发射矩阵,在特定状态下,出现某个字的概率
        [
          '今天 天气 真 不错 。',
          '麻辣肥牛 好吃 !',
          '我 喜欢 吃 好吃 的 !'
        ]
        [
          'BE BE S BE S',
          'BMME BE S',
          'S BE S BE S S '
        ]
        {
          'B': {'total': 7, '今': 1, '天': 1, '不': 1, '麻': 1, '好': 2, '喜': 1},
          'M': {'total': 2, '辣': 1, '肥': 1},
          'S': {'total': 7, '真': 1, '。': 1, '!': 2, '我': 1, '吃': 1, '的': 1},
          'E': {'total': 7, '天': 1, '气': 1, '错': 1, '牛': 1, '吃': 2, '欢': 1}
        }
        """
        # print(tuple(zip("".join(words), "".join(states))))
        for word, state in zip("".join(words), "".join(states)):  # 先把words 和 states 拼接起来再遍历, 因为中间有空格
            self.emit_matrix[state][word] = self.emit_matrix[state].get(word, 0) + 1
            self.emit_matrix[state]["total"] += 1  # 注意这里多添加了一个  total 键 , 存储当前状态出现的总次数, 为了后面的归一化使用

    # 将矩阵归一化
    def normalize(self):
        self.init_matrix = self.init_matrix / np.sum(self.init_matrix)
        self.transfer_matrix = self.transfer_matrix / np.sum(self.transfer_matrix, axis=1, keepdims=True)
        self.emit_matrix = {state: {word: t / word_times["total"] * 1000 for word, t in word_times.items() if word != "total"} for state, word_times in
                            self.emit_matrix.items()}

    # 训练开始, 其实就是3个矩阵的求解过程
    def train(self):
        for words, states in tqdm(zip(self.all_texts, self.all_states)):  # 按行读取文件, 调用3个矩阵的求解函数
            words = words.split(" ")  # 在文件中 都是按照空格切分的
            states = states.split(" ")
            self.cal_init_matrix(states[0])  # 初始矩阵,统计每行第一个字出现的频次 [2. 0. 1. 0.]
            self.cal_transfer_matrix(states)  # 转移矩阵,当前状态到下一状态的概率
            self.cal_emit_matrix(words, states)  # 发射矩阵,在特定状态下,出现某个字的概率
        self.normalize()  # 矩阵求完之后进行归一化
        pickle.dump([self.init_matrix, self.transfer_matrix, self.emit_matrix], open("data/three_matrix.pkl", "wb"))  # 保存参数


def viterbi_t(text, hmm):
    states = hmm.index_to_states
    emit_p = hmm.emit_matrix
    trans_p = hmm.transfer_matrix
    start_p = hmm.init_matrix
    V = [{}]
    path = {}
    for y in states:
        V[0][y] = start_p[hmm.states_to_index[y]] * emit_p[y].get(text[0], 0)
        path[y] = [y]
    for t in range(1, len(text)):
        V.append({})
        newpath = {}

        # 检验训练的发射概率矩阵中是否有该字
        neverSeen = text[t] not in emit_p['S'].keys() and \
                    text[t] not in emit_p['M'].keys() and \
                    text[t] not in emit_p['E'].keys() and \
                    text[t] not in emit_p['B'].keys()
        for y in states:
            emitP = emit_p[y].get(text[t], 0) if not neverSeen else 1.0  # 设置未知字单独成词
            temp = []
            for y0 in states:
                if V[t - 1][y0] > 0:
                    temp.append((V[t - 1][y0] * trans_p[hmm.states_to_index[y0], hmm.states_to_index[y]] * emitP, y0))
            (prob, state) = max(temp)
            # (prob, state) = max([(V[t - 1][y0] * trans_p[hmm.states_to_index[y0],hmm.states_to_index[y]] * emitP, y0)  for y0 in states if V[t - 1][y0] > 0])
            V[t][y] = prob
            newpath[y] = path[state] + [y]
        path = newpath

    (prob, state) = max([(V[len(text) - 1][y], y) for y in states])  # 求最大概念的路径

    result = ""  # 拼接结果
    for t, s in zip(text, path[state]):
        result += t
        if s == "S" or s == "E":  # 如果是 S 或者 E 就在后面添加空格
            result += " "
    return result


if __name__ == "__main__":
    train_file = "data/train_data.txt"
    state_file = "data/train_state.txt"
    text_to_state(train_file, state_file)
    hmm = HMM(train_file, state_file)
    hmm.train()
    text = "今天的天气不错"
    result = viterbi_t(text, hmm)

    print(result) 

源代码: https://gitee.com/VipSoft/VipPython/tree/master/hmm_viterbi
视频:代码讲解 https://www.bilibili.com/video/BV1aP4y147gA?p=11

标签:matrix,示例,--,text,self,Viterbi,states,state,emit
From: https://www.cnblogs.com/vipsoft/p/17970179

相关文章

  • VT-X的学习历程(一)
    学习的目标就是如何实现一个简单VT框架并拦截指令的调用以及EPTHOOK的实现。大概的流程检测是否允许开启VT。a.我们可以从白皮书的24.6DISCOVERINGSUPPORTFORVMX章节中得到这样的信息b.其次就是设置smxc.检测CPUID是否支持VTcpuid第5位是否为1Define.h#prag......
  • [极客大挑战 2019]Knife 1
    [极客大挑战2019]Knife1审题没啥好审的,给出eval($_POST["Syc"]);一句话木马了知识点蚁剑连接一句话木马。做题蚁剑连接测试成功后打开找到flag。......
  • 智能反截屏控制:数据安全防护新利器
    在现代数字化企业中,数据安全已成为关键业务需求。尽管企业在数据保护方面投入了大量资源,但截图泄露敏感信息仍然是难以防范的风险。智能反截屏控制技术为企业提供了全新的数据保护解决方案,通过实时保护敏感内容,大大降低了截图泄露的风险。智能反截屏控制技术通过以下方式实现:智......
  • 练习题1
    使用面向对象的思想,编写自定义描述狗的信息。设定属性包括:品种,年龄,心情,名字;方法包括:叫,跑。要求:1)设置属性的私有访问权限,通过公有的get,set方法实现对属性的访问2)限定心情只能有“心情好”和“心情不好”两种情况,如果无效输入进行提示,默认设置“心情好”。3)设置构造......
  • 解决openssh无法登录的问题
    背景在安装完openssh之后,还是不能解决登录的问题。报错信息如下:ITISPOSSIBLETHATSOMEONEISDOINGSOMETHINGNASTY!Someonecouldbeeavesdroppingonyourightnow(man-in-the-middleattack)!Itisalsopossiblethatahostkeyhasjustbeenchanged.Thefinge......
  • 洛谷题单指南-模拟和高精度-P1042 [NOIP2003 普及组] 乒乓球
    原题链接:https://www.luogu.com.cn/problem/P1042题意解读:分别针对11分制和21分制,输出每局比分。只需要判断一局的结束条件:得分高者如果达到11或者21,且比分间隔大于等于2分,则表示一局结束,可开始下一局,用模拟法即可解决。100分代码:#include<bits/stdc++.h>usingnamespaces......
  • init
    init进程是所有Linux进程的父进程补充说明init命令是Linux下的进程初始化工具,init进程是所有Linux进程的父进程,它的进程号为1。init命令是Linux操作系统中不可缺少的程序之一,init进程是Linux内核引导运行的,是系统中的第一个进程。语法init(选项)(参数)选项-b:不执行相关脚......
  • 常见错误记录之连接MySQL8.0(Navicate Premium 12,出现BigInteger错误)
    一、NavicatePremium12连接MySQL8.0包如下错误: 出错原因:mysql8之前的版本中加密规则为mysql_native_passwordmysql8以后的加密规则为caching_sha2_password解决方法:(1)更新navicat驱动来解决此问题(2)将mysql用户登录的加密规则常用第二种方法:1.用管理员权限打开cmd,输入mysql......
  • SP839Optimal Marks 题解
    part1:建图二进制异或,每一位互不干扰。所以对每一位分开来考虑。然后变成了一个经典的模型。当前每一个未确定点有两个选择:变成\(1\),变成\(0\);已经确定的点只能选它本身的值。于是构造思路非常套路了:构造虚点\(S\)、\(T\)。对于一个点\(u\),从\(S\)连向\(u\)一条边,值为......
  • 构建智算时代的云原生应用平台,2023 云原生产业大会,阿里云在这里!
    2023 信通院云原生产业大会顺利举办。在云原生技术规模化应用的关键时期,云原生前沿技术趋势、云原生技术的应用现代化建设、大模型的云原生算力供给、云原生安全防护新思路、行业应用实践等等都成为从业者的关注焦点。在云原生产业大会主论坛上,阿里云云原生应用平台资深产品总监......