首页 > 其他分享 >Whisper语音识别 -- 自回归解码分析

Whisper语音识别 -- 自回归解码分析

时间:2024-06-13 22:58:54浏览次数:9  
标签:layer -- Whisper 解码 cache self token model id

前言

Whisper 是由 OpenAI 开发的一种先进语音识别系统。它采用深度学习技术,能够高效、准确地将语音转换为文本。Whisper 支持多种语言和口音,并且在处理背景噪音和语音变异方面表现出色。其广泛应用于语音助手、翻译服务、字幕生成等领域,为用户提供了更流畅的语音交互体验。作为一个开源项目,Whisper 鼓励开发者和研究人员进一步优化和创新。
在这里插入图片描述
作者将解码过程整理成 简单的python代码进行讲解

核心思想

whisper解码核心是 基于自回归解码的token游戏 ,换句话说他的参数读取是通过传入token id的形式,即采用大语言模型的prompt范式(whisper的解码器一定程度上也是个大语言模型,虽然语音训练样本token数远不及纯文本token数)
h
图中除了识别结果的框框大多数都是prompt工程, 常用的token id 如图:
在这里插入图片描述

自回归解码

在这里插入图片描述

详细解释放在代码中啦

def main():
    
    """
        解码器须构建Deocder的prompt,序列为【SOT,语种,任务】, 本文中是 model.sot_sequence
        其中SOT:50258
        语种:50332,50309,50333,50335,50273,...
        任务:transcribe 转写 50359, translate 翻译 50358
    """


    """
                加载whisper模型
    """
    encoder_onnx_file = './small-encoder.int8.onnx'
    decoder_onnx_file = './small-decoder.int8.onnx'
    tokenizer_file = './small-tokens.txt'
    model = OnnxModel(encoder_onnx_file, decoder_onnx_file)
    token_table = load_tokenizer(tokenizer_file) # token id to char 


    """
                提取MEL特征
    """
    wav_file = "output.wav"
    mel = compute_features(wav_file)


    """
                计算encoder的K/V编码 
    """
    # 交叉注意力 encoder:K/V, with decoder:Q
    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
    # 自注意力 decoder:K/V, with decoder:Q
    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()


    """
                检测语种
    """
    lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
    model.sot_sequence[1] = lang


    """
                任务选择
    """
    # task = model.translate
    task = model.transcribe
    model.sot_sequence[2] = task
    
    
    """
                根据prompt进行首次解码
    """
    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    offset += len(model.sot_sequence)
    logits = logits[0, -1] # token 声学后验
    model.suppress_tokens(logits, is_initial=True) # 无效token后验抑制



    """
                自回归解码
    """
    max_token_id = logits.argmax(dim=-1) # 选择后验中最大输出的token【贪心解码】
    results = []
    sentence = {'start':0,'end':0,'text':b""} 
    sentences = []
    for i in range(model.n_text_ctx):

        # 打印token属性
        if max_token_id.item() == model.sot:
            print("iter:%8s docode token id:%8s [sot]"%(i,max_token_id.item()))
        elif max_token_id.item() == model.eot:
            print("iter:%8s docode token id:%8s [eot]"%(i,max_token_id.item()))
        elif max_token_id.item() >= model.timestamp_begin:
            print("iter:%8s docode token id:%8s [boundary]"%(i,max_token_id.item()))
        else:
            print("iter:%8s docode token id:%8s [char]"%(i,max_token_id.item()))
        
        # eot 结束
        if max_token_id.item() == model.eot:
            print("Finish !!")
            break

        # 检测到时间戳
        if max_token_id.item()>=model.timestamp_begin:
            timestamp = ((max_token_id.item()-model.timestamp_begin)*model.time_precision)
            # 遇到结束符
            if sentence['text']:
                sentence['end'] = timestamp
                sentence['text'] = sentence['text'].decode().strip()
                print(sentence)
                sentences.append(sentence)
                sentence = {'start':0,'end':0,'text':b""}
            # 遇到开始符
            else:
                sentence['start'] = timestamp
        else:
            decode_token = base64.b64decode(token_table[max_token_id.item()])
            sentence['text'] += decode_token


        results.append(max_token_id.item())
        tokens = torch.tensor([[results[-1]]])
        # deocder 单步解码
        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        offset += 1
        logits = logits[0, -1]
        model.suppress_tokens(logits, is_initial=False)
        max_token_id = logits.argmax(dim=-1) # 贪心搜索

没错连时间戳也是token形式~,下面是运行结果感受一下。我们在边界处对句子进行保存
在这里插入图片描述

以上就是whisper解码的基本原理,感兴趣的同学关注走一波

标签:layer,--,Whisper,解码,cache,self,token,model,id
From: https://blog.csdn.net/Ephemeroptera/article/details/139663706

相关文章

  • 公司面试题总结(五)
    25.谈一谈箭头函数与普通函数的区别,箭头函数主要解决什么问题?箭头函数与普通函数的区别:⚫语法简洁性:◼箭头函数使用=>符号定义,省略了function关键字,使得语法更为紧凑。◼对于单行函数体,可以进一步简化,省略花括号和return语句。⚫词法作用域内的this:......
  • fastjson(版本<=1.2.24)复现
    文章目录1.啥是JSON介绍:2.啥是fastjson?3.fastjson序列化/反序列化原理4.fastjson反序列化漏洞原理$复现流程:漏洞影响范围:fastjson<=1.2.24一、漏洞环境搭建二、漏洞验证方法一三、漏洞验证方法二1.啥是JSON介绍:JSON,全称:JavaScriptObjectNotation,作为一个常见的......
  • MySQL安全性管理
    用户权限管理创建和管理用户:使用CREATEUSER和GRANT语句创建和管理用户。例如:CREATEUSER'username'@'host'IDENTIFIEDBY'password';GRANTSELECT,INSERT,UPDATE,DELETEONdatabase.*TO'username'@'host';最小权限原则:只赋予用户执行其任务所需的最......
  • 24-06-13
    是否可以继承String?String类是final类,不能被继承.继承String本身就是一个错误行为,对String类型最好的重写方式是关联关系(Has-A)和依赖关系(Use-A)而不是继承关系重载(overload)和重写(override)的区别?重载的方法能否根据返回类型进行区分?方法的重载和重写都是实现多态的方式,区别在于......
  • 弹性云服务器使用公网NAT网关和直接绑定弹性公网IP有区别吗
    公网NAT网关提供SNAT和DNAT功能,可允许多台弹性云服务器共享弹性公网IP。弹性云服务器直接绑定弹性公网IP为独占IP的方式。 当同一个弹性云服务器同时设置了SNAT和弹性公网IP时,会优先使用弹性公网IP进行转发。当同一个弹性云服务器同时设置了DNAT和弹性公网IP时,入云方向的......
  • php反序列化个人笔记
    反序列化什么是反序列化?格式转换序列化:对象转换为字符串或者数组等格式反序列化:将数组或字符串转换成对象为什么会出现安全漏洞?魔术方法如何利用漏洞?通过构造pop链,找到代码的逻辑漏洞,进行getshell,rce等操作反序列化利用分为三类魔术方法的调用逻辑语言原生类的调用逻......
  • PHP正则表达式
    PHP正则表达式函数PHP正则表达式介绍正则表达式允许您搜索和替换字符串中的模式。安装PHP正则表达式函数是PHP核心的一部分。无需安装即可使用这些功能。运行时配置php.ini中的这些设置可用于限制计算正则表达式时使用的时间或资源量。名称默认值描述Changea......
  • Linux脚本语言入门.md
    0、shell介绍1)Shell是什么?Shell是一个命令行解释器,它为用户提供一个详Linux内核发送请求以便运行程序的界面系统级程序,用户可以用Shell来启动、挂起、停止甚至是编写一些程序。Shell还是一个功能相当强大的编程语言,易编写,易调试,灵活性较强。Shell是解释执行的脚本语言,在Shell中......
  • kubernetes-PV与PVC 的关系与绑定的条件
    PV:声明这个资源是一个持久卷(PV)。PVC:声明这个资源是一个持久卷声明(PVC)。创建yaml配置apiVersion:v1kind:PersistentVolume#PV是集群中的一块存储,可以由PVC请求并使用。-虚拟存储-实体机的存储、不是容器中的存储metadata:name:postgresql-pvnamespace:......
  • 第十六周周四
    以“事后诸葛亮”为模板总结会议1、我们的软件要解决什么问题?是否定义的很清楚?是否对典型用户和典型场景有清晰的描述?@主要是要方便老师学生的生活,少跑一趟取快递时间可用做其他事情,而取快递的人可以通过拿一次快递,挣一顿饭钱,方便自己方便他人;@定义得较为清楚;@主要将典型用......