Preface:
In some sequence generation tasks, such as seq2seq machine translation models or CTC-based captcha recognition, the model produces a probability distribution at every output time step. The final sequence is then obtained by searching for the top-K most probable sequences with algorithms such as beam search or Viterbi. These methods sit between the fully greedy strategy of taking the argmax at each time step and a globally optimal dynamic-programming search (a minimal sketch of greedy decoding and beam search follows the list below).
Some common search strategies:
Locally optimal, greedy strategy: Greedy Decoder
Global search: Beam Search, Viterbi
Random Categorical + Softmax with Temperature
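As a rough illustration of the first two strategies, here is a minimal sketch (not taken from any library). It assumes the per-step log-probabilities are already given as a (T, vocab_size) matrix; a real seq2seq decoder would recompute the distribution at every step conditioned on the prefix chosen so far.

import numpy as np

def greedy_decode(step_log_probs):
    # Greedy decoding: take the argmax independently at every time step.
    return [int(np.argmax(lp)) for lp in step_log_probs]

def beam_search(step_log_probs, beam_width=3):
    # Minimal beam search over a (T, vocab_size) matrix of per-step log-probabilities.
    beams = [([], 0.0)]                       # (token sequence, cumulative log-prob)
    for log_probs in step_log_probs:          # one decoding step at a time
        candidates = [(seq + [tok], score + lp)
                      for seq, score in beams
                      for tok, lp in enumerate(log_probs)]
        # keep only the beam_width highest-scoring prefixes
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams

step_log_probs = np.log([[0.5, 0.4, 0.1],
                         [0.1, 0.5, 0.4],
                         [0.4, 0.3, 0.3]])
print(greedy_decode(step_log_probs))              # [0, 1, 0]
print(beam_search(step_log_probs, beam_width=2))  # top-2 sequences with their scores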
Background:
When an NLP model predicts the next word, it outputs a vector of logits that defines a probability distribution over every token the model can predict. The length of the logits vector is therefore the size of the dictionary (vocabulary): typically 30,000 to 40,000 for BERT, and 50,257 for GPT-2.
A true global search would have to score every possible output sequence, which is not practical.
1 Random Categorical + Softmax with Temperature
import tensorflow as tf

# Uniform logits: all three classes are equally likely to be sampled.
tf.random.set_seed(1)
logits = [[1.0, 1.0, 1.0]]
print('Probabilities:', tf.math.softmax(logits).numpy()[0])
samples = tf.random.categorical(logits=logits, num_samples=10)
tf.print(samples.numpy())
# Probabilities: [0.33333334 0.33333334 0.33333334]
# array([[0, 0, 1, 2, 0, 0, 0, 0, 1, 0]])

# Skewed logits: class 2 dominates and is sampled most of the time.
tf.random.set_seed(1)
logits = [[1.0, 1.0, 3.0]]
print('Probabilities:', tf.math.softmax(logits).numpy()[0])
samples = tf.random.categorical(logits=logits, num_samples=10)
tf.print(samples.numpy())
# Probabilities: [0.10650698 0.10650698 0.78698605]
# array([[2, 0, 2, 2, 2, 0, 1, 2, 2, 0]])
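The "temperature" in the heading refers to scaling the logits before sampling; the scale_factor argument of the sample() function below plays exactly this role. A small sketch of the effect (the values in the comments are approximate):

for scale_factor in [0.5, 1.0, 2.0]:
    # factor > 1 (low temperature) sharpens the distribution,
    # factor < 1 (high temperature) flattens it toward uniform
    probs = tf.math.softmax(tf.constant([[1.0, 1.0, 3.0]]) * scale_factor).numpy()[0]
    print('scale_factor =', scale_factor, '->', probs)
# scale_factor = 0.5 -> approx. [0.21 0.21 0.58]
# scale_factor = 1.0 -> approx. [0.11 0.11 0.79]
# scale_factor = 2.0 -> approx. [0.02 0.02 0.96]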
def sample(model, starting_str, len_generated_text=500,
           max_input_length=40, scale_factor=1.0):
    # encode the seed string into token ids, shape (1, len(starting_str))
    encoded_input = [char2int[s] for s in starting_str]
    encoded_input = tf.reshape(encoded_input, (1, -1))
    generated_str = starting_str
    model.reset_states()
    for i in range(len_generated_text):
        logits = model(encoded_input)
        logits = tf.squeeze(logits, 0)
        # scale_factor acts as an inverse temperature on the logits
        scaled_logits = logits * scale_factor
        new_char_indx = tf.random.categorical(scaled_logits, num_samples=1)
        new_char_indx = tf.squeeze(new_char_indx)[-1].numpy()
        generated_str += str(char_array[new_char_indx])
        # append the sampled character and keep at most the last max_input_length tokens
        new_char_indx = tf.expand_dims([new_char_indx], 0)
        encoded_input = tf.concat([encoded_input, new_char_indx], axis=1)
        encoded_input = encoded_input[:, -max_input_length:]
    return generated_str

tf.random.set_seed(1)
print(sample(model, starting_str='The island'))
The island is open he heard the victory of the Mercy, and brought it into them, and they no longer continue, some on the little man of the felting circle of slopes. The engineer troused, he could not find our companions. Chapter 11 At this position, he might just as if his first true to be finished, and he though not more I can this teles.” “Why shall fear line,” answered the reporter, “what a disposal silence was advanced with them, and in masterspon. Before three heights of the Frenchant Heights
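With the same trained model and character/index mappings, the effect of the temperature can be explored by re-running sample() with other scale_factor values: larger factors make the text more deterministic and repetitive, smaller factors make it more random (the generated text is not reproduced here and will differ from the sample above):

tf.random.set_seed(1)
print(sample(model, starting_str='The island', scale_factor=2.0))   # sharper, near-greedy
tf.random.set_seed(1)
print(sample(model, starting_str='The island', scale_factor=0.5))   # flatter, more random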
2 N-gram deduplication (no-repeat n-gram blocking)
Example 1
Input:
ngram_size = 3
states = [4, 8, 8, 4, 3, 4, 8]  # the next token cannot be 8: the current bigram [4, 8] was already followed by 8 earlier in the sequence
probs = [0.2, 0.6, 0.7, 0.4, 0.2, 0.1, 0.5, 0.3, 0.9, 0.4]  # shape (vocab_size,); banned token indices are set to -inf
Result:
probs = [0.2, 0.6, 0.7, 0.4, 0.2, 0.1, 0.5, 0.3, -inf, 0.4]
Example 2
Input:
ngram_size = 3
states = [1, 2, 1, 2, 5, 1, 2]  # the next token cannot be 1 or 5: the current bigram [1, 2] already appeared followed by 1 (positions 0-2) and followed by 5 (positions 2-4)
probs = [0.2, 0.6, 0.7, 0.4, 0.2, 0.1, 0.5, 0.3, 0.9, 0.4]  # shape (vocab_size,); banned token indices are set to -inf
Result:
probs = [0.2, -inf, 0.7, 0.4, 0.2, -inf, 0.5, 0.3, 0.9, 0.4]
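The rule behind both examples can be sketched as follows (block_repeated_ngrams is my own name; the post itself gives no code): take the last ngram_size - 1 generated tokens as a prefix, find every earlier occurrence of that prefix in states, and mask out each token that completed it, so no n-gram can be generated twice.

import numpy as np

def block_repeated_ngrams(states, probs, ngram_size):
    # states: list of already generated token ids
    # probs:  1-D array of shape (vocab_size,) with next-token scores
    probs = np.asarray(probs, dtype=np.float64).copy()
    if len(states) < ngram_size - 1:
        return probs
    prefix = tuple(states[-(ngram_size - 1):])            # last (n-1) tokens
    # scan every complete n-gram that already occurs in states
    for i in range(len(states) - ngram_size + 1):
        if tuple(states[i:i + ngram_size - 1]) == prefix:
            probs[states[i + ngram_size - 1]] = -np.inf   # ban the completing token
    return probs

# example 1: generating 8 would repeat the trigram [4, 8, 8], so token 8 is banned
print(block_repeated_ngrams([4, 8, 8, 4, 3, 4, 8],
                            [0.2, 0.6, 0.7, 0.4, 0.2, 0.1, 0.5, 0.3, 0.9, 0.4], 3))
# example 2: [1, 2] was previously followed by 1 and by 5, so both are banned
print(block_repeated_ngrams([1, 2, 1, 2, 5, 1, 2],
                            [0.2, 0.6, 0.7, 0.4, 0.2, 0.1, 0.5, 0.3, 0.9, 0.4], 3))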
Example 3
During beam search, the input probs has shape (batch_size, beam_width, vocab_size) and states has shape (batch_size, beam_width, i). For each element of the batch, the beam_width rows are the best candidate continuations of the same sentence kept by beam search, each with its own next-token probability distribution; i is the number of tokens already generated for that sentence, which only needs to be smaller than the final sentence length.
Input:
batch_size = 2
beam_width = 5
ngram_size = 3
states =
[[[4, 8, 8, 4, 3, 4, 8],
[2, 8, 9, 5, 7, 5, 5],
[2, 9, 2, 2, 9, 4, 5],
[5, 7, 5, 5, 8, 3, 4],
[4, 7, 1, 2, 5, 9, 5]],
[[4, 1, 7, 8, 1, 9, 8],
[1, 2, 1, 2, 5, 1, 2],
[7, 9, 2, 1, 3, 5, 6],
[2, 2, 7, 8, 6, 1, 9],
[3, 1, 7, 1, 5, 7, 4]]]
probs =
[[[0.06988439 0.88953143 0.3388861 0.8003634 0.41536968 0.23019396 0.96064572 0.5049323 0.17197636 0.01114495]
[0.26413354 0.0027116 0.48181481 0.67516154 0.98633325 0.41732901 0.3308262 0.32395439 0.93045618 0.73636922]
[0.13610143 0.2878246 0.06102439 0.90753978 0.36238731 0.85826691 0.06957903 0.95829313 0.59740139 0.06498013]
[0.47543078 0.82744895 0.36380256 0.90964788 0.673106 0.0384001 0.84760034 0.22287195 0.10001 0.76660941]
[0.39213985 0.8131306 0.2783355 0.67879836 0.07235815 0.87530527 0.27077796 0.6315962 0.02997279 0.06629175]]
[[0.43232198 0.86879328 0.49806399 0.00536409 0.55031861 0.89780367 0.50707047 0.68293632 0.62203018 0.57622393]
[0.24464949 0.46303155 0.37299519 0.03162689 0.77467507 0.14801423 0.07873908 0.18735025 0.99573359 0.5650082 ]
[0.97870774 0.44381718 0.77745946 0.60062697 0.3214976 0.98093789 0.43354787 0.98535251 0.11714246 0.09250022]
[0.14141 0.07578034 0.9494816 0.82918715 0.4199323 0.42469995 0.90658367 0.96065581 0.83701996 0.93619639]
[0.83922269 0.6241601 0.3264292 0.15436493 0.80988114 0.99374556 0.95503546 0.7865531 0.16768138 0.56712871]]]
Result:
probs =
[[[0.06988439 0.88953143 0.3388861 0.8003634 0.41536968 0.23019396 0.96064572 0.5049323 -inf 0.01114495]
[0.26413354 0.0027116 0.48181481 0.67516154 0.98633325 0.41732901 0.3308262 0.32395439 0.93045618 0.73636922]
[0.13610143 0.2878246 0.06102439 0.90753978 0.36238731 0.85826691 0.06957903 0.95829313 0.59740139 0.06498013]
[0.47543078 0.82744895 0.36380256 0.90964788 0.673106 0.0384001 0.84760034 0.22287195 0.10001 0.76660941]
[0.39213985 0.8131306 0.2783355 0.67879836 0.07235815 0.87530527 0.27077796 0.6315962 0.02997279 0.06629175]]
[[0.43232198 0.86879328 0.49806399 0.00536409 0.55031861 0.89780367 0.50707047 0.68293632 0.62203018 0.57622393]
[0.24464949 -inf 0.37299519 0.03162689 0.77467507 -inf 0.07873908 0.18735025 0.99573359 0.5650082 ]
[0.97870774 0.44381718 0.77745946 0.60062697 0.3214976 0.98093789 0.43354787 0.98535251 0.11714246 0.09250022]
[0.14141 0.07578034 0.9494816 0.82918715 0.4199323 0.42469995 0.90658367 0.96065581 0.83701996 0.93619639]
[0.83922269 0.6241601 0.3264292 0.15436493 0.80988114 0.99374556 0.95503546 0.7865531 0.16768138 0.56712871]]]
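For the beam-search case the same rule is simply applied independently to every (batch, beam) row. A straightforward, non-vectorized sketch that reuses block_repeated_ngrams from above:

def block_repeated_ngrams_batched(states, probs, ngram_size):
    # states: array of shape (batch_size, beam_width, i)  -- generated prefixes
    # probs:  array of shape (batch_size, beam_width, vocab_size)
    probs = np.asarray(probs, dtype=np.float64).copy()
    states = np.asarray(states)
    batch_size, beam_width, _ = probs.shape
    for b in range(batch_size):
        for k in range(beam_width):
            probs[b, k] = block_repeated_ngrams(list(states[b, k]),
                                                probs[b, k], ngram_size)
    return probs

Running it on the states and probs arrays of this example reproduces the result above: only batch 0 / beam 0 (token 8) and batch 1 / beam 1 (tokens 1 and 5) are masked to -inf.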