1语言模型生成文本的顺序
-
前面我们已经能够实现使用下图的LSTM网络进行语言建模;
-
对于一个已经在语料库上学习好的LSTM模型;如果语料库就只是you say goobye and i say hello;那么当把单词i输入到模型中,Time xxx层的第一个LSTM层将会输出一个概率分布,这个概率分布中概率最大的那个对应的单词应该是say;如下图所示;
-
上图的情况是按照选择概率最大的那个作为当前输出的文本的;因此可以说是“确定的”,只要模型训练的好,他都会倾向于输出那些准确的单词;
-
那么我们就可以接着以这个层的输出的单词作为下一个时刻的输入,输出下一个时刻的概率分布,然后选择概率最大的,如下图所示;以此类推,就可以一直输出下去;我们可以人为控制什么时候停止,也可以设定当模型输出的下一个单词是特殊结束符号标记的时候就终止;如果模型训练的好,模型觉得要说完了,就会输出一个特殊结束符号标记,从而自动停止;
-
另一种方法是以这个概率分布作为准则,随机的从词库中选择单词;就是之前说的
np.choice
方法;这样一来:- 因为具有随机性,因此被选到的单词(被采样到的单词)每次都不一样
- 但是选到的单词又具有倾向性,即概率高的单词容易被选到,概率低的单词难以被选到。
-
这种引入随机性的方法,可以让模型生成训练数据中没有的文本,即新的文本;不过构成这些新文本的单词依然是那么些个单词,毕竟训练集里面只有这些单词;
2使用改进前的LSTM语言模型尝试生成文本
- 这里使用的是未改进的LSTMLM模型;且不使用训练好的权重;只是看一下生成的过程;
- 代码位于:RNN_generate/RNNLM_gen.py · GY/basicNLP - 码云 - 开源中国 (gitee.com);
-
生成的过程可以描述为:
- 给初始的单词,reshape之后输入到模型中;
- 由于Time xxx层是权重共享的,因此可以根据输入的数据的维度调整T值;
- 那么这里每次都输入一个单词,用上图中不带Time的层来输出一个概率分布;
- 以一定概率选择要预测的单词,然后将这个单词再次作为输入,输入到上图中不带Time的层,再输出一个概率分布;以此类推;直到全部输出完;
-
这里直接继承了未改进的
Rnnlm
类,然后实现了生成函数;生成函数代码如下:- model 的
predict()
方法进行的是 mini-batch 处理,所以输入x
必须是二维数组。因此,即使在只输入 1 个单词 ID 的情况下,也要将它的批大小视为 1,并将其整理成形状为 1 × 1 的 NumPy 数组
class RnnlmGen(Rnnlm): def generate(self, start_id, skip_ids=None, sample_size=100): ''' @param start_id: 第一个单词的ID @param skip_ids: 不生成的ID;用于排除一些填充符之类的 @param sample_size: 生成的长度 @return:生成的文本''' word_ids = [start_id] x = start_id while len(word_ids) < sample_size: # x = np.array(x).reshape(1, 1) if GPU: x = cupy.array(x).reshape(1, 1) else: x = np.array(x).reshape(1, 1) score = self.predict(x) # (N,T,V);这里是(1,1,10000) p = softmax(score.flatten()) # score.flatten()展平成一维的;softmax函数中设置了一维的计算方式 if GPU: sampled = cupy.random.choice(len(p), size=1, p=p) else: sampled = np.random.choice(len(p), size=1, p=p) if (skip_ids is None) or (sampled not in skip_ids): x = sampled word_ids.append(int(x[0])) return word_ids
- model 的
-
以下是生成文本;因为没有训练,所以杂乱无序;
you fired designing indianapolis counsel calgary readers reviewed wright shouting underlying existing agip frankfurt depress interstate steelmakers natural weeks begins gatt stiff delivering telesis grounds boards stream louisiana breed sample indexing acquiring commentary hired al philip blast helping dictaphone attention confusion auditors beaten arbitrage ii scholars forecasting monopolies burke fit spacecraft takeover-stock engineering aftershocks arise shipbuilding minivans along recalls bone recreational year may disappears sixth motivated monitors understanding swing previously coupon expects difference plo remain attendants sullivan kansas peninsula patent skeptical fields galileo blackstone battered steps anger fusion mandatory mca trains postal forest-products scrapped faa censorship tea building tests milton
3使用改进后的LSTM语言模型尝试生成文本
这里使用之前训练好的权重;权重位于:BetterRnnlm.pkl · GY/basicNLP - 码云 - 开源中国 (gitee.com);
代码位于:RNN_generate/betterRNNLM_gen.py · GY/basicNLP - 码云 - 开源中国 (gitee.com);
还演示了给模型一句话或者一句话开头几个单词,然后让其续写的方法;
-
这里的改进即前面说的LSTM多层化、embedding层和Affine层的权重共享,以及在纵向上加入dropout层;
-
继承了
BetterRnnlm
类,然后实现generate
函数;这个generate
函数与2使用改进前的LSTM语言模型尝试生成文本
小节的generate
函数一样; -
以下是生成的结果:
- 可以看到,训练好的模型生成的文本流畅多了;
you said. in the event of falling environmental prices and the rapid growth of revenues rate for the third quarter abc has up from almost every time says frozen president of fox 's pro conn. i obtained a courthouse for other new york series on its job virginia to do with mr. spielvogel and the task and the merged bank business. according to the usx spokesman bernard and other major investors have declined to support this profit from personal management. delmed. the matter was held by berry 's family in five years. mr. nadeau
-
目前为止我们只是给模型第一个单词,然后让模型预测之后的单词;那么如果我们希望给模型一句话或者一句话开头几个单词,然后让其续写呢?
-
方法是:先将前几个单词依次输入到模型中;这样在LSTM层进行计算时会将隐藏状态保存在类的成员变量self.h中;但是要记得设置stateful=True,这样才能继承前面计算的隐藏状态;
-
之后,再将最后一个单词输入到模型中,依次获取模型输出即可;
-
核心代码如下:
if __name__ == '__main__': start_words = 'the meaning of life is' start_ids = [word_to_id[w] for w in start_words.split(' ')] for x in start_ids[:-1]: if GPU: x = cupy.array(x).reshape(1, 1) else: x = np.array(x).reshape(1, 1) model.predict(x) # 文本生成 word_ids = model.generate(start_ids[-1], skip_ids) txt = ' '.join([id_to_word[i] for i in word_ids]) txt = txt.replace(' <eos>', '.\n') print(txt)
-
以下是一次输出的文本:
the meaning of life is not a nightmare in many of the newspapers. the solution will be shipped although washington 's future is n't likely to adopt the first changes in the world. our state bailout is very large payments he says. on the in the corporate market the analyst said the arrest by the new york borough president had arranged less than one million shares in the event. and by selling their stake in rico the public home short-term market system is a one-time candidate for an investment bank. it 's time to worry that customers will need