Among the hash-table-themed exercises on Codingame, the Markov chain text generation puzzle caught my attention. It combines three topics: Markov chains, state machines, and hash tables. The n-gram Markov chain has been used in the design of text chatbots, which makes it quite instructive; it was a classic technique from before ChatGPT. Below is a short walkthrough of this programming exercise.
Goal
You are making a game in which NPCs talk, even if what they say is nonsense. Since you are too lazy to write out every absurd line yourself, you decide to build a text generator. Fortunately, you have a collection of texts to use as training data. You will build a typical n-gram Markov chain. Look up what a Markov chain is, what an n-gram is, and how to apply them.
An example
Text t = one fish is good and no fish is bad and that is it
With an n-gram depth of 2, the Markov chain is built step by step as follows:
Step 1 : 'one fish' => ['is']
Step 2 : 'fish is' => ['good']
Step 3 : 'is good' => ['and']
Step 4 : 'good and' => ['no']
Step 5 : 'and no' => ['fish']
Step 6 : 'no fish' => ['is']
Note that 'fish is' already appeared in step 2, so step 7 appends the new value to the end of its list:
Step 7 : 'fish is' => ['good','bad']
Step 8 : 'is bad' => ['and']
Continue this way until the end of text t.
Now we can generate text. For an output of length 5 with seed text 'fish is', we might randomly generate either of the following:
- fish is good and no
- fish is bad and that
This is because when we reach 'fish is', we can randomly choose 'good' or 'bad'; every other step is deterministic.
Reproducibility
To make the results reproducible, the next state for a given state of the n-gram Markov chain is "randomly" selected with the following pseudocode:
random_seed = 0
function pick_option_index( num_of_options ) {
    random_seed += 7
    return random_seed % num_of_options
}
In the example above, the first query returns ['good','bad'], which has two options. Calling pick_option_index(2) returns 7 % 2 = 1, so we append 'bad' to the end of the output text. The function is also called when there is only one option.
Solution
As the statement describes, the algorithm has two parts: first, build the n-gram Markov chain from the input text and the depth parameter; second, starting from the seed text, query the chain to append words.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
using namespace std;

int random_seed = 0;

// Pick an option index "randomly" using the predetermined algorithm
// from the problem statement, so results are reproducible
int pick_option_index(int num_of_options) {
    random_seed += 7;
    return random_seed % num_of_options;
}

// Split text into words on single spaces
vector<string> split(const string& text) {
    vector<string> words;
    string word = "";
    for (char c : text) {
        if (c == ' ') {
            words.push_back(word);
            word = "";
        } else {
            word += c;
        }
    }
    words.push_back(word);
    return words;
}

// Generate the Markov chain from the input words: each depth-word n-gram
// maps to the list of words that follow it (duplicates kept, in order)
unordered_map<string, vector<string>> generateMarkovChain(const vector<string>& words, int depth) {
    unordered_map<string, vector<string>> chain;
    // i + depth < size avoids the size_t underflow that i < size - depth
    // would cause when depth >= the number of words
    for (int i = 0; i + depth < (int)words.size(); i++) {
        string ngram = "";
        for (int j = i; j < i + depth; j++) {
            ngram += words[j] + " ";
        }
        ngram.pop_back(); // remove the trailing space
        chain[ngram].push_back(words[i + depth]);
    }
    return chain;
}

// Join words with single spaces
string vectorToString(const vector<string>& vec) {
    string current = "";
    for (const auto& w : vec) {
        current += w + " ";
    }
    if (!current.empty()) current.pop_back();
    return current;
}

string generateOutputText(const unordered_map<string, vector<string>>& chain,
                          int length, const string& seed, int depth) {
    string output = seed;
    auto seed_words = split(seed);
    int seed_num = (int)seed_words.size(); // words already in the output
    // The current state is the last `depth` words of the seed
    // (the statement guarantees the seed has at least depth words)
    vector<string> ngram_words(seed_words.end() - depth, seed_words.end());
    string current = vectorToString(ngram_words);
    for (int i = 0; i < length - seed_num; i++) {
        auto it = chain.find(current);
        if (it == chain.end()) {
            break; // no match for the current n-gram: stop generating
        }
        const vector<string>& options = it->second;
        int index = pick_option_index((int)options.size());
        const string& next = options[index];
        output += " " + next;
        // Slide the n-gram window: drop the oldest word, append the new one
        ngram_words.erase(ngram_words.begin());
        ngram_words.push_back(next);
        current = vectorToString(ngram_words);
    }
    return output;
}

int main() {
    string text;
    getline(cin, text);
    int depth, length;
    cin >> depth >> length;
    string seed;
    cin.ignore();
    getline(cin, seed);

    auto words = split(text);
    auto chain = generateMarkovChain(words, depth);
    string output = generateOutputText(chain, length, seed, depth);
    cout << output << endl;
}
References
https://www.codingame.com/blog/markov-chain-automaton2000/
https://analyticsindiamag.com/hands-on-guide-to-markov-chain-for-text-generation/
https://www.codingame.com/training/hard/code-your-own-automaton2000-step-1