1.引入
1.1. 问题描述
给定一个长度为 \(n(1 \le n \le 10 ^ 5)\) 的字符串和 \(m\) 个模式串 \(s_1, \cdots, s_m\),求问字符串中出现了多少个模式串。
\(\sum s_i \le 10^5\)。
1.2. 解法
考虑当 \(m = 1\) 的情况,直接跑 KMP 就可以了。
但是如果 \(m\) 特别大时,跑 KMP 的复杂度就是 \(\mathcal{O}(n m + \sum s_i)\),无法接受。
现在考虑只询问是否有这些模式串,可以使用字典树实现。
考虑将字典树和 KMP 并在一起,然后就得到新的数据结构,俗称 AC 自动机(Aho–Corasick Automaton)。
2.思想
AC 自动机基于字典树的结构和 KMP 的思想组成。
2.1. 失配指针
对于字典树上每个点,可以看作是一个字符串的前缀,本文中指的前缀就是指字典树上的每个点,现在用 \(trie_{u, c}\) 表示。
失配指针 \(fail_u\) 就是指和点 \(u\) 有最长公共后缀且时真后缀的点 \(v\)
当当前在点 \(trie_{u, c}\) 时,失配指针的具体求法:
- 如果存在 \(trie[fail[u]][c]\),那么 \(fail[u] = trie[fail[f]][c]\)。
- 否则,就同 KMP 一样寻找 \(trie[fail[fail[u]]][c], trie[fail[fail[fail[u]]]][c], \cdots\),直到满足条件 \(1\)。
至于为什么是最长的,这个很好理解。
举个例子,就如下图(稍微有点抽象),插入了 orz
, oh
,mobai
,baixie
图中红色的边即为失配指针。
3.实现
3.1. 求失配指针
建树的过程就没有了,直接是建立失配指针。
void GetFail() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (trie[i][0]) {
q.push(trie[i][0]);
}
}
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if (trie[i][u]) {
fail[trie[i][u]] = trie[i][fail[u]];
q.push(trie[i][u]);
} else {
trie[i][u] = trie[i][fail[u]];
}
}
}
}
实现中有一个优化,没有一直向后跳 fail 指针,为什么这么实现,可以看查询的时候。
发现在求解 fail 指针,如果它是叶子就直接连向了指针指向的点,这样虽然改变了树的结构,但在逻辑上没有影响。感性理解一下就很容易明白。
至于这个的用处主要就是在查询时不用判断一些特殊情况就像当前模式串匹配完之类的问题。
3.2. 多模式串匹配单主串
这个就是最开头问题的解法。
考虑两个串 abbc
,bbc
,在匹配完 abbc
后,后缀 bbc
也匹配上了,但是因为有失配指针,其实上匹配串二,就是从串一通过失配指针跳到串二就可以了。
实现中,因为有 end 标记,所以只需要一边跳,一边加 end 标记就可以了,注意打上标记,表示这个串已被匹配。
int query(const std::string & t) {// t 是主串
memset(vis, 0, sizeof(bool) * (tot + 1));
int ans = 0, u = 0;
for (int i = 0; i < (int)t.size(); i++) {
u = trie[t[i] - 'a'][u];
for (int j = u; j && !vis[j]; j = fail[j]) {
ans += end[j], vis[j] = true;
}
}
return ans;
}
4.效率优化
如果要查询每个模式串在主串出现的次数。
考虑一般的查询中,因为跳失配指针次数太多,导致效率较低。
那么就有一个思想,减少失配指针跳跃次数就很好或者能够让尽可能少的跳跃次数统计到尽可能多的答案。
现在考虑答案时如何统计的,答案是在匹配上一个串后然后向它的失配指针一直跳跃,直到根节点。然后将这些所有经过的点全部标记为出现过。
如果将每个点 \(u\) 向它的失配指针 \(fail[u]\) 连边,最后形成的一定是一棵树,因为失配指针不会连成环,而且每个点(除根节点)都会有一条出边,因此最后会形成一棵内向树。
此时就可以先对每个点打个标记,然后最后从叶子开始遍历一边失配指针树,遍历的同时顺便统计答案即可,代码是 【模板】AC 自动机 的代码。
5.一些模板
下面三个是洛谷上 AC 自动机的模板,代码有部分注释。
因为需要处理字符串,所以用的是 std::string
。
题目分别是 AC 自动机 简单版,AC 自动机(简单版 II),【模板】AC 自动机。
5.1.AC 自动机(简单版)
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using i64 = long long;
const int N = 1e6 + 5;
struct ACAM {
int trie[26][N];
int fail[N], end[N];
bool vis[N];
int tot;
void insert(const std::string & s) {
int u = 0, c = 0;
for (int i = 0; i < (int)s.size(); i++) {
c = s[i] - 'a';
if (!trie[c][u]) {
trie[c][u] = ++tot;
}
u = trie[c][u];
}
++end[u];
}
void GetFail() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (trie[i][0]) {
q.push(trie[i][0]);
}
}
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if (trie[i][u]) {
fail[trie[i][u]] = trie[i][fail[u]];
q.push(trie[i][u]);
} else {
trie[i][u] = trie[i][fail[u]];
}
}
}
}
int query(const std::string & t) {
memset(vis, 0, sizeof(bool) * (tot + 1));
int ans = 0, u = 0;
for (int i = 0; i < (int)t.size(); i++) {
u = trie[t[i] - 'a'][u];
for (int j = u; j && !vis[j]; j = fail[j]) {
ans += end[j], vis[j] = true;
}
}
return ans;
}
}AM;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n; std::cin >> n;
for (int i = 1; i <= n; i++) {
std::string s; std::cin >> s;
AM.insert(s);
}
AM.GetFail();
std::string t; std::cin >> t;
std::cout << AM.query(t);
return 0;
}
5.2.AC 自动机(简单版 II)
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using i64 = long long;
const int N = 1e6 + 5;
std::string s[155];
struct ACAM {
int trie[26][N];
int fail[N], end[N];
int val[N];
int tot;
void insert(const std::string & s, int id) {
int u = 0, c = 0;
for (int i = 0; i < (int)s.size(); i++) {
c = s[i] - 'a';
if (!trie[c][u]) {
trie[c][u] = ++tot;
}
u = trie[c][u];
}
end[u] = id;
}
void GetFail() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (trie[i][0]) {
q.push(trie[i][0]);
}
}
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if (trie[i][u]) {
fail[trie[i][u]] = trie[i][fail[u]];
q.push(trie[i][u]);
} else {
trie[i][u] = trie[i][fail[u]];
}
}
}
}
/*
模式串的长度并不长,所以直接暴力统计,暴力取最大值就可以啦。
*/
int query(const std::string & t) {
int ans = 0, u = 0;
for (int i = 0; i < (int)t.size(); i++) {
u = trie[t[i] - 'a'][u];
for (int j = u; j; j = fail[j]) {
if (end[j]) {
val[end[j]]++, ans = std::max(ans, val[end[j]]);
}
}
}
return ans;
}
void clear() {
tot = 0;
memset(fail, 0, sizeof(fail));
memset(end, 0, sizeof(end));
memset(val, 0, sizeof(val));
memset(trie, 0, sizeof(trie));
}
int size(){return tot;}
int count(int id){return val[id];}
int son(int u, int c){return trie[c][u];}
int IsEnd(int tot){return end[tot];}
}AM;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n;
while (std::cin >> n) {
if (n == 0) {
break;
}
AM.clear();
for (int i = 1; i <= n; i++) {
std::cin >> s[i];
AM.insert(s[i], i);
}
AM.GetFail();
std::string t; std::cin >> t;
int ans = AM.query(t);
std::cout << ans << "\n";
for (int i = 1; i <= n; i++) {
if (AM.count(i) == ans) {
std::cout << s[i] << "\n";
}
}
}
return 0;
}
5.3.【模板】AC自动机
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using i64 = long long;
const int N = 2e5 + 5;
std::string s[N];
struct ACAM {
int trie[26][N];
int fail[N], end[N];
int tag[N], val[N];
int indx[N];
bool vis[N];
int tot;
void insert(const std::string & s, int id) {
int u = 0, c = 0;
for (int i = 0; i < (int)s.size(); i++) {
c = s[i] - 'a';
if (!trie[c][u]) {
trie[c][u] = ++tot;
}
u = trie[c][u];
}
end[id] = u; // 此题有重复的串所以统计每个串结尾标记。
}
void GetFail() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (trie[i][0]) {
q.push(trie[i][0]);
}
}
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if (trie[i][u]) {
fail[trie[i][u]] = trie[i][fail[u]];
q.push(trie[i][u]);
} else {
trie[i][u] = trie[i][fail[u]];
}
}
}
for (int i = 1; i <= tot; i++) {
indx[fail[i]]++; //不需要真的建树,但是为了拓扑排序,所以记录一下入度,以便从叶子开始遍历。也可以用 dfs 遍历。
}
}
int query(const std::string & t) {
int ans = 0, u = 0;
for (int i = 0; i < (int)t.size(); i++) {
u = trie[t[i] - 'a'][u];
++tag[u];
}
return ans;
}
void recycle() { // 回收答案标记
std::queue<int> q;
for (int i = 1; i <= tot; i++) {
if (!indx[i]) {
q.push(i);
}
}
while (!q.empty()) {
int u = q.front(); q.pop();
val[u] += tag[u];
tag[fail[u]] += tag[u], tag[u] = 0; //将根到点的节点全部加。
if (--indx[fail[u]] == 0) {
q.push(fail[u]);
}
}
}
int count(int id){return val[id];}
int size(){return tot;}
int son(int u, int c){return trie[c][u];}
int End(int tot){return end[tot];}
}AM;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n; std::cin >> n;
for (int i = 1; i <= n; i++) {
std::cin >> s[i];
AM.insert(s[i], i);
}
AM.GetFail();
std::string t; std::cin >> t;
AM.query(t);
AM.recycle();
for (int i = 1; i <= n; i++) {
std::cout << AM.count(AM.End(i)) << "\n";
}
return 0;
}
6.AC 自动机优化 dp
想不到吧,AC 自动机上也可以优化 dp.
6.1 例题
6.1.1 JSOI2007 文本生成器
6.1.1.1 题目描述
给定 \(n(1 \le n \le 60)\) 个字符串(\(s_1, s_2, \cdots, s_n\)),\(| s_i | \le 100\),问有多少个长度为 \(m\) 的字串中至少包含一个字符串。
字符集 \(\Sigma = 26\)。
6.1.1.2 过程
直接统计至少包含一个字符串的太麻烦,总方案数又很好求,因此考虑求出没有包含任何串的方案数。
考虑暴力,似乎不太可作,因为需要知道每个串匹配到哪里了。
因为考虑模式串的匹配,所以考虑使用 AC 自动机优化,状态改为 \(f_{i, j}\) 表示前 \(i\) 个字符跳到 AC 自动机上的第 \(j\) 个节点的方案数,转移也是枚举下一个位置填谁,然后考虑是否继承当前方案。
最后答案 \(ans = 26 ^ m - \sum f_{m, i}\),时间复杂度 \(\mathcal{O}(m \sum | s_i |\Sigma)\)。
6.1.1.3 实现
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using i64 = long long;
const int N = 6005;
const int mod = 1e4 + 7;
struct ACAM {
int trie[26][N];
int fail[N], end[N];
int tot;
void insert(const std::string & s) {
int u = 0, c = 0;
for (int i = 0; i < (int)s.size(); i++) {
c = s[i] - 'A';
if (!trie[c][u]) {
trie[c][u] = ++tot;
}
u = trie[c][u];
}
++end[u];
}
void GetFail() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (trie[i][0]) {
q.push(trie[i][0]);
}
}
while (!q.empty()) {
int u = q.front(); q.pop();
end[u] |= end[fail[u]]; // 注意这个地方的细节
for (int i = 0; i < 26; i++) {
if (trie[i][u]) {
fail[trie[i][u]] = trie[i][fail[u]];
q.push(trie[i][u]);
} else {
trie[i][u] = trie[i][fail[u]];
}
}
}
}
int size(){return tot;}
int son(int u, int c){return trie[c][u];}
int isend(int tot){return end[tot];}
}AM;
i64 qpow(i64 a, i64 b) {
i64 ans = 1;
for (; b; b >>= 1, a = a * a % mod) {
if (b & 1) {
ans = ans * a % mod;
}
}
return ans;
}
i64 f[105][N];
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n, m; std::cin >> n >> m;
for (int i = 1; i <= n; i++) {
std::string s; std::cin >> s;
AM.insert(s);
}
AM.GetFail();
f[0][0] = 1;
for (int i = 0; i < m; i++) {
for (int j = 0; j <= AM.size(); j++) {
for (int k = 0; k < 26; k++) {
if (!AM.isend(AM.son(j, k))) {
f[i + 1][AM.son(j, k)] = (f[i + 1][AM.son(j, k)] + f[i][j]) % mod;
}
}
}
}
i64 ans = 0;
for (int i = 0; i <= AM.size(); i++) {
ans = (ans + f[m][i]) % mod;
}
std::cout << (qpow(26, m) - ans + mod) % mod;
return 0;
}
7.总结
- AC 自动机上 dp 一般都要设跳到那个节点。
- AC 自动机主要能够处理多模式串匹配问题。