首页 > 其他分享 >DL学习-ctc解码

DL学习-ctc解码

时间:2023-08-04 18:35:00浏览次数:47  
标签:DL NEG nb ctc beam prefix blank INF 解码

参考基于CTC的序列模型:https://distill.pub/2017/ctc/

ctc解码方式:

  • Greedy decode,每次都选取概率最大。
  • Beam Search,对规整字符串进行束搜索算法。
  • FST Status Encode

对齐方式:

方案1:为每个输入步骤分配一个输出字符,堆叠重复的字符。

方案2:为字符添加blank用于防止hello被误解码为helo,防止output中的重复字符被折叠。

特点:1.对齐具有单调性;2.X->Y的映射是多个\(x_i\)对应一个\(y_i\);3.Y的长度不能比X长。

流程:

为了方便就直接截图了,后面其实自己结合自己理解画了一个,写完附在最后面吧。

简单的beam-search算法:

输入:

  • 模型的输出y,shapes:[sequence_length,num_outsize]
  • 束的长度beam_size,表示搜索的宽度

输出:

  • 最优的beam_size条路径。

算法思路:
1.对y求log。
2.初始化beam为[([],0)]
3.迭代sequence中的每一个下标,根据log y的数值确定出每个前缀的分数,并采用排序的方式最终获取一个最优的beam_size条路径。

结合动态规划以及Beam-Search进行搜索。
需要实现的子函数:

  • norm_prefix:将相邻两个相同的字符去重,并将非结尾部分的blank删除。例如:[1,1,2]->[1,2] [1,1,]->[1,_]


这个图是演示的相同路径下进行归并,即相同的节点进行合并。
官方给的代码为:

"""
Author: Awni Hannun
This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.
The algorithm is a prefix beam search for a model trained
with the CTC loss function.
For more details checkout either of these references:
  https://distill.pub/2017/ctc/#inference
  https://arxiv.org/abs/1408.2873
"""

import numpy as np
import math
import collections

NEG_INF = -float("inf")
'''这段代码定义了一个名为make_new_beam的函数,该函数没有参数。这个函数的主要功能是创建并返回一个默认字典(defaultdict)。在Python中,defaultdict是内置dict类的一个子类,它重写了一个方法并添加了一个可写的实例变量。其主要特点是在直接读取dict不存在的属性值时,直接返回默认值。

fn = lambda : (NEG_INF, NEG_INF):这行代码定义了一个匿名函数(lambda函数),这个函数没有参数,每次被调用时都会返回一个元组(NEG_INF, NEG_INF)。这里的NEG_INF可能是一个在其他地方定义的常量,表示负无穷大。
return collections.defaultdict(fn):这行代码创建了一个defaultdict,它的默认值是由上面定义的fn函数生成的。也就是说,当你试图访问这个字典中不存在的键时,它会返回(NEG_INF, NEG_INF)。
总的来说,make_new_beam函数返回的是一个默认值为(NEG_INF, NEG_INF)的字典。
'''
def make_new_beam():
  fn = lambda : (NEG_INF, NEG_INF)
  return collections.defaultdict(fn)

'''
这段代码定义了一个名为logsumexp的函数,该函数接收任意数量的参数。这个函数的主要功能是计算所给参数的log-sum-exp,这是一种在处理概率等涉及指数运算的数值时常用的技巧,可以提高数值稳定性,防止因数值过大或过小导致的溢出或下溢。

if all(a == NEG_INF for a in args): return NEG_INF:这行代码检查所有的输入参数是否都等于NEG_INF(可能是一个在其他地方定义的表示负无穷大的常量)。如果所有参数都是NEG_INF,那么函数直接返回NEG_INF。
a_max = max(args):这行代码找出所有输入参数中的最大值。
lsp = math.log(sum(math.exp(a - a_max) for a in args)):这行代码首先计算每个输入参数与最大值的差的指数,然后将这些指数值求和,最后对求和结果取对数。这是log-sum-exp的关键步骤,通过减去最大值,可以防止指数运算的结果过大导致溢出。
return a_max + lsp:最后,函数返回最大值与上一步计算的对数求和结果的和。这就是log-sum-exp的结果。
总的来说,logsumexp函数实现了一种数值稳定的方式来计算一组数的log-sum-exp。
'''
def logsumexp(*args):
  """
  Stable log sum exp.
  """
  if all(a == NEG_INF for a in args):
      return NEG_INF
  a_max = max(args)
  lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
  return a_max + lsp

def decode(probs, beam_size=100, blank=0):
  """
  Performs inference for the given output probabilities.
  Arguments:
      probs: The output probabilities (e.g. post-softmax) for each
        time step. Should be an array of shape (time x output dim).
      beam_size (int): Size of the beam to use during inference.
      blank (int): Index of the CTC blank label.
  Returns the output label sequence and the corresponding negative
  log-likelihood estimated by the decoder.
  """
  T, S = probs.shape
  probs = np.log(probs)
  '''

在CTC(Connectionist Temporal Classification)损失中,`p_blank`和`p_no_blank`是两个关键的概率值,它们分别表示在某个时间步上预测出空白标签(blank label)和非空白标签(non-blank label)的概率。

- `p_blank`:这个概率值通常由神经网络模型直接输出,表示在当前时间步上预测出空白标签的概率。在CTC中,空白标签是一个特殊的标签,用于表示没有输出或者输出与前一个时间步的输出相同。

- `p_no_blank`:这个概率值通常由1减去所有空白标签的概率得到,表示在当前时间步上预测出任何非空白标签的概率。

这两个概率值的确定通常依赖于你的神经网络模型的输出。具体的计算方法可能会根据你的模型和任务有所不同。例如,如果你的模型输出的是每个标签的概率,那么你可以直接使用这些概率作为`p_blank`和`p_no_blank`。如果你的模型输出的是每个标签的logits(即未经softmax或sigmoid函数处理的原始输出),那么你可能需要先将这些logits转换为概率,然后再计算`p_blank`和`p_no_blank`。

  '''
  # Elements in the beam are (prefix, (p_blank, p_no_blank))
  # Initialize the beam with the empty sequence, a probability of
  # 1 for ending in blank and zero for ending in non-blank
  # (in log space).
  beam = [(tuple(), (0.0, NEG_INF))]

  for t in range(T): # Loop over time

    # A default dictionary to store the next step candidates.
    next_beam = make_new_beam()

    for s in range(S): # Loop over vocab
      p = probs[t, s]

      # The variables p_b and p_nb are respectively the
      # probabilities for the prefix given that it ends in a
      # blank and does not end in a blank at this time step.
      for prefix, (p_b, p_nb) in beam: # Loop over beam

        # If we propose a blank the prefix doesn't change.
        # Only the probability of ending in blank gets updated.
        if s == blank:
'''
在CTC(Connectionist Temporal Classification)解码中,n_p_b, p_b + p, p_nb + p 是用于计算新的空白和非空白概率的中间变量。

n_p_b:这个变量表示新的空白概率,它是由当前时间步的空白概率(p_b)和非空白概率(p_nb)相加得到的。这个变量的计算反映了CTC解码的一个关键思想,即在当前时间步预测出空白标签可以由前一个时间步预测出空白标签或非空白标签两种情况转移得到。

p_b + p:这个变量表示当前时间步预测出空白标签的概率(p_b)和当前时间步的模型输出概率(p)的和。这个变量用于更新p_b,即新的空白概率。

p_nb + p:这个变量表示当前时间步预测出非空白标签的概率(p_nb)和当前时间步的模型输出概率(p)的和。这个变量用于更新p_nb,即新的非空白概率。

这三个变量的计算是CTC解码的关键步骤,它们反映了CTC解码的主要思想,即通过动态规划在所有可能的序列中找到最可能的序列。
'''
          n_p_b, n_p_nb = next_beam[prefix]
          n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
          next_beam[prefix] = (n_p_b, n_p_nb)
          continue

        # Extend the prefix by the new character s and add it to the beam. Only the probability of not ending in blank
        # gets updated.
        end_t = prefix[-1] if prefix else None
        n_prefix = prefix + (s,)
        n_p_b, n_p_nb = next_beam[n_prefix]
        if s != end_t:
          n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
        else:
          # We don't include the previous probability of not ending in blank (p_nb) if s is repeated at the end. The CTC algorithm merges characters not separated by a blank.
          # 不能在结尾连着出来俩blank
          n_p_nb = logsumexp(n_p_nb, p_b + p)
          
        # *NB* this would be a good place to include an LM score.
        next_beam[n_prefix] = (n_p_b, n_p_nb)

        # If s is repeated at the end we also update the unchanged prefix. This is the merging case.
        if s == end_t:
          n_p_b, n_p_nb = next_beam[prefix]
          n_p_nb = logsumexp(n_p_nb, p_nb + p)
          next_beam[prefix] = (n_p_b, n_p_nb)

    # Sort and trim the beam before moving on to the
    # next time-step.
    beam = sorted(next_beam.items(),
            key=lambda x : logsumexp(*x[1]),
            reverse=True)
    beam = beam[:beam_size]

  best = beam[0]
  # 返回的是数值,也可以通过best来获取其中的最优字符串
  return best[0], -logsumexp(*best[1])

if __name__ == "__main__":
  np.random.seed(3)

  time = 50
  output_dim = 20

  probs = np.random.rand(time, output_dim)
  probs = probs / np.sum(probs, axis=1, keepdims=True)

  labels, score = decode(probs)
  print("Score {:.3f}".format(score))

手语翻译采用的束搜索算法:

def search(self,probs,beam_width: int = 10,prune: float = 1e-2,blank: int = 0,lm=None,alpha=0.3):
        if lm is None:
            lm=lambda *_:1
        def mslm(l):
            if len(l)==1:
                return self.is_begining(l[-1])
            a,b=l[-2:]
            if self.is_next(a,b):
                return 1
            elif self.is_exiting(a,b):
                return lm(self.collapse(l))**alpha
            return 0
        p_b = defaultdict(Counter)
        p_nb = defaultdict(Counter)

        p_b[-1][()] = 1
        p_nb[-1][()] = 0

        prefixes = [()]

        for t in range(len(probs)):
            pruned_states, prune_relaxed = [], prune
            while not pruned_states:
                pruned_states = np.where(probs[t] >= prune_relaxed)[0].tolist()
                prune_relaxed /= 2
            pruned_states = set(pruned_states)

            for l in prefixes:
                possible_states = {blank} | pruned_states
                if l:
                    possible_states |= self.successors(l[-1])
                for s in possible_states:
                    p_t_s = probs[t,s]

                    if s == blank:
                        p_b[t][l] += p_t_s * (p_b[t - 1][l] + p_nb[t - 1][l])
                        continue

                    ls = l + (s,)
                    p_lm = mslm(ls)

                    if l and s == l[-1]:
                        # a_ + a = aa
                        p_nb[t][ls] += p_lm * p_t_s * p_b[t - 1][l]

                        # a + a = a
                        p_nb[t][l] += p_t_s * p_nb[t - 1][l]
                    else:
                        # a(_) + b = ab
                        p_nb[t][ls] += p_lm * p_t_s * (p_b[t - 1][l] + p_nb[t - 1][l])

            p = p_b[t] + p_nb[t]

            if len(p) == 0:
                p = p_b[t]  # 0 prob for all prefix

            if len(p) == 0:
                p = p_nb[t]  # 0 prob for all prefix

            prefixes = sorted(p, key=lambda k: p[k], reverse=True)
            prefixes = prefixes[:beam_width]

            # divide by a constant (min_prob) to avoid underflow
            min_prob = np.inf
            for prefix in prefixes:
                if min_prob > p[prefix] and p[prefix] > 0:
                    min_prob = p[prefix]
            for prefix in prefixes:
                # usually, min_prob won't be zero
                p_b[t][prefix] /= min_prob
                p_nb[t][prefix] /= min_prob

            if p[prefixes[0]] == 0:
                raise ValueError("Even the most probable beam has probability 0. ")

        hyp = self.collapse(prefixes[0])

        return hyp        

随手画的笔记









结语

读完还是很有收获的,英语阅读能力有待提高,读的时候累死了,对于ctcloss也不是云里雾里了,估计应付面试没大问题。

标签:DL,NEG,nb,ctc,beam,prefix,blank,INF,解码
From: https://www.cnblogs.com/D876887913/p/17600839.html

相关文章

  • 架设传奇技术教程同目录下无法找到DLL文件"KERNELBASE"处理办法
    同目录下无法找到DLL文件:"KERNELBASE"】.请与作者联系.的弹窗办法和解决架设传奇版本启动引擎或者启动没多久的时候经常遇到弹窗提示【同目录下无法找到DLL文件:"KERNELBASE"】.请与作者联系.的弹窗,如上图所示,下面我来给大家介绍下如何解决这个问题。一般出现这个问题都是windows200......
  • 基于Aidlux平台实现手机摄像头实时Canny检测
    第一步:通过Github查找作者TommyZihao,在其aidlux_tutorial工程下找到“用手机摄像头玩转OpenCV”这个项目,并以压缩包的形式下载下来。 第二步:从手机端登录Aidlux,根据Cloud_ip,获取IP地址,在电脑端进行输入,远程登录Aidlux桌面。默认密码:aidlux 第三步: 远程传输代码文件......
  • Qt 调用倍福TwinCAT通讯模块(TcAdsDll)
    Qt实现TwinCAT通讯目前这种方式是通过调用TwinCAT提供的AdsApi与倍福PLC通讯的。要求本机安装TwinCAT(无需作为主机,但是可能这个api依赖TwinCAT的一些服务)。关于AdsApi的官方资料请看这里,有函数的详细解释,还有例子。你值得拥有。https://infosys.beckhoff.com/english.php?conte......
  • tlflearn 编码解码器 ——数据降维用
     #-*-coding:utf-8-*-"""AutoEncoderExample.UsinganautoencoderonMNISThandwrittendigits.References:Y.LeCun,L.Bottou,Y.Bengio,andP.Haffner."Gradient-basedlearningappliedtodocumentrecognition."......
  • .Net Core MiddleWare
    目录作用Use第一种第二种UseMiddleWareCustomMiddleWare.csProgram.csMapMapWhen作用中间件是一种装配到应用管道以处理请求和响应的软件。每个组件:选择是否将请求传递到管道中的下一个组件。可在管道中的下一个组件前后执行工作。请求委托用于生成请求管道。请求委托......
  • 如何把.net应用程序防止他人反编译,dll打包并搭建成一个合格的安装包
    背景知识:在理论上,任何.NET程序集(.dll文件或.exe文件)都可以被反编译。C#是一种托管语言,其代码编译成中间语言(IL)或称为CIL(CommonIntermediateLanguage),然后在.NET运行时中执行。反编译工具可以将IL代码还原回C#源代码,使得原本的C#代码可以被查看和修改。 最......
  • 配置 Forwarded Headers Middleware
    来自微软的说明:ConfigureASP.NETCoretoworkwithproxyserversandloadbalancers|MicrosoftLearn。通过该中间件,会更新:HttpContext.Connection.RemoteIpAddress:使用 X-Forwarded-For 请求头的值.其它的配置会影响到中间件如何设置 RemoteIpAddress的值.消费......
  • IDEA超强XSD文件编辑插件-XSD / WSDL Visualizer
    前言XSD/WSDLVisualizer可以简化XML架构定义(XSD)和WSDL文件编辑过程;通过使用与IntelliJ无缝集成的可视化编辑器,转换处理XSD和WSDL文件的方式。告别导航复杂和难以阅读的代码的挫败感,迎接流线型和直观的体验。插件安装在线安装IntelliJIDEA可通过在线安装的方式,安装时......
  • 遇到:nodejs unhandledPromiseRejectionWarning 错误应该如何解决
    遇到"unhandledPromiseRejectionWarning"错误是因为在Node.js中,一个Promise被rejected了,但是没有被处理(handled)。这可能是因为你没有使用适当的错误处理机制,导致Promise的rejected状态没有被捕获。要解决这个问题,你可以考虑以下几个步骤:使用catch方法捕获错误:在你的......
  • c#的dllimport使用方法详解
    关于“C#的DllImport使用方法详解”的攻略如下:简介DllImport是C#中一个用于调用非托管代码的方法。它可以让我们在C#代码中调用一些使用一些C++或Win32API等编写的代码。使用方法DllImport的用法非常简单,我们只需要使用指定DllImport特性来声明一个需要调用的函数,然后在代码......