首页 > 其他分享 >CTC的训练与推理之Greedy Decoder, Beam Search,CTC Loss, RNNT Loss

CTC的训练与推理之Greedy Decoder, Beam Search,CTC Loss, RNNT Loss

时间:2023-01-09 22:12:30浏览次数:49  
标签:Loss Search inputs pred decoder token CTC tf inf


模型流程:t时刻输入x(t),t-1时刻输出y`(t-1), t时刻的输出y`(t)为:由x(t)和y`(t-1)作为输入得到的预测值
训练:采用Teacher Forcing的策略,在t时刻,并不是使用上一时刻的预测值y`(t-1)作为输出,而是把实际的label即y(t-1)作为输入。
           输出y`(t)为是所有token的概率分布(embedding),label即y(t)相当于一个One-hot向量, 两者通过交叉熵计算损失。
推理:采用Greedy Decoder, Beam Search搜索策略,在t时刻,使用上一时刻的预测值y`(t-1)和x(t)作为输出,预测输出y`(t)。
          但是由于上一时刻的输出y`(t-1)是token的概率分布(embedding),Greedy Decoder每次预测时使用概率值最大值对应的id,Beam Search每一步都保留beam_size个最大值。

注意

这里的Greedy Decoder, Beam Search好像与tf的tf.nn.ctc_greedy_decoder, tf.nn.ctc_beam_search_decoder不同,tf的输入是所有时刻的输出概率,在该矩阵上进行搜索。

tensorflow:

# 1 tf.nn.ctc_greedy_decoder
inf = float("inf")
logits = tf.constant([[[   0., -inf, -inf],
                       [ -2.3, -inf, -0.1]],
                      [[ -inf, -0.5, -inf],
                       [ -inf, -inf, -0.1]],
                      [[ -inf, -inf, -inf],
                       [ -0.1, -inf, -2.3]]])
seq_lens = tf.constant([2, 3])
outputs = tf.nn.ctc_greedy_decoder(
    logits,
    seq_lens,
    blank_index=1)


#2 tf.nn.ctc_beam_search_decoder
tf.nn.ctc_beam_search_decoder(
    inputs, sequence_length, beam_width=100, top_paths=1
)

实际的GreedyDecoder:

def GreedyDecode(self, inputs, inputs_length):
    
    batch_size = inputs.size(0)

    enc_states, _ = self.encoder(inputs, inputs_length)

    zero_token = torch.LongTensor([[0]])
    if inputs.is_cuda:
        zero_token = zero_token.cuda()

    def decode(enc_state, lengths):
        token_list = []
        dec_state, hidden = self.decoder(zero_token)

        for t in range(lengths):
            logits = self.joint(enc_state[t].view(-1), dec_state.view(-1))
            # out = F.softmax(logits, dim=0)
            # pred = torch.argmax(out, dim=0)
            out = F.log_softmax(logits, dim=0)
            prob, pred = torch.max(out, dim=0)
            pred = int(pred.item())

            if pred != 0:
                token_list.append(pred)
                token = torch.LongTensor([[pred]])

                if enc_state.is_cuda:
                    token = token.cuda()

                dec_state, hidden = self.decoder(token, hidden=hidden)

        return token_list

    results = []
    for i in range(batch_size):
        decoded_seq = decode(enc_states[i], inputs_length[i])
        results.append(decoded_seq)

    return results

 

   

标签:Loss,Search,inputs,pred,decoder,token,CTC,tf,inf
From: https://www.cnblogs.com/3511rjzn/p/17038668.html

相关文章