Trie树
Trie树又称字典树、单词查找树,是一种能够高效存储和查找字符串合集的数据结构。可以快速地在集合中查询某个字符串。
Trie树的本质就是利用字符串之间的公共前缀,将重复的前缀合并在一起。
举个例子,有五个字符串,code,cook,five,file,fat,组织成字典树就是下面这个样子:
性质:
- 1.根节点不包含字符,除根节点外每一个节点都只包含一个字符
- 2.从根节点到某一字符,路径上经过的字符连接起来,为该节点对应的字符串
- 3.每个节点的所有子节点包含的字符都不相同
基本操作:
查找、插入
Trie树的插入操作:是将单词的每个字母逐一插进Trie树中。插入前先看该字母对应的节点是否已经存在,若存在共享节点,则不用创建对应的节点,若不存在,就需要创建一个新节点。最后我们在最后一个字母节点进行标记(即图中的黄色节点代表已标记)。
Trie树的查询操作:比如我们查找code,可以将要查找的单词code分割成单个的字母c,o,d,e,然后从Trie树的根节点开始匹配,直到每个节点都存在(需要在一条路径上)且最后一个节点被标记,才算查找成功。
时间复杂度:
假设所有的字符串长度之和为n,构建字典树的时间复杂度为\(O(n)\),
假设要查找的字符串长度为k,查找的时间复杂度为\(O(k)\)
例题1:模板题:Trie字符串统计
题目描述:
维护一个字符串集合,支持两种操作:
"I x"向集合中插入一个字符串x;
"Q x"询问一个字符串在集合中出现了多少次。
共有n个操作,输入仅包含小写字母。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 1e5 + 3;
int son[N][26], cnt[N], idx;
//son[N][26]存的值是儿子节点对应的idx。一维下标是父节点的idx,二维下标是这个父节点的直接子节点的str[i]-'a'的值
//cnt[N]:以字符串的最后一个字符对应的idx作为下标,存储以该idx结尾的字符串的个数
//idx表示当前字符的编号,根节点为0
char str[N];
inline void insert(char str[]) { //插入操作
rg int p = 0;
for (rg int i = 1; str[i]; i++) { //对这个单词的每个字母依次进行插入
rg int u = str[i] - 'a';
if (!son[p][u]) son[p][u] = ++idx; //如果没有就添加
p = son[p][u];
}
cnt[p]++; //以p结尾的字符串多了一个
}
inline int query(char str[]) { //查询出现次数
rg int p = 0;
for (rg int i = 1; str[i]; i++) {
rg int u = str[i] - 'a';
if (!son[p][u]) return qwq; //有某个节点不存在直接结束
p = son[p][u];
}
return cnt[p];
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
while (n--) {
rg char opt;
cin >> opt >> str + 1;
if (opt == 'I') insert(str);
else cout << query(str) << "\n";
}
return qwq;
}
/*
11
I abcd
I abc
I abcdf
Q abcd
I abcd
Q abcd
I bcdf
I abcd
I abc
Q abc
Q abcd
1
2
2
3
*/
例题2:最大异或树
将数字变为二进制存储统计。
我们想要求一个数a的匹配数b使得a xor b最大,自然是希望从最高位开始,b的每一位都和a相同。自然想到,可以用Trie树来存储每一个数的二进制表示,然后从高位查询异或和最大的数。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 1e5 + 3, M = 31 * N;
int n;
int a[N];
int trie[M][2], idx;
inline void insert(int x) { //将x的二进制插入Trie树中
rg int p = 0;
for (rg int i = 30; i >= 0; i--) {
//取出x的第i位数
rg int u = x >> i & 1;
if (!trie[p][u]) trie[p][u] = ++idx;
p = trie[p][u];
}
}
inline int query(int x) {
rg int p = 0;
//res是与x异或和最大的数
rg int res = 0;
for (rg int i = 30; i >= 0; i--) {
//取出x的第i位
rg int u = x >> i & 1;
if (trie[p][!u]) {
p = trie[p][!u];
res = (res << 1) + !u;
} else {
p = trie[p][u];
res = (res << 1) + u;
}
}
return res;
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n;
for (rg int i = 0; i < n; i++) cin >> a[i];
rg int res = 0;
for (rg int i = 0; i < n; i++) {
//先插入查询,目的是解决边界问题,最开始还没有储存任何一个数
insert(a[i]);
rg int t = query(a[i]);
res = max(res, a[i] ^ t);
}
cout << res << "\n";
return qwq;
}
例题3:Remember the Word
解决方法:Trie+dp
对于dp:
- 状态:令f[i]表示\(i\sim len\)的串有多少种表示方法。
- 转移:\(f[i]=\sum f[j+1]\),且要满足\(s[i,\cdots,j+1]\)可以由多个字典拼成。
dp的结果很显然就是\(f[0]\),因为i是倒序枚举的。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 103, M = 1e6 + 3, mod = 20071027;
char str[N], c[M];
int n, idx, sum;
int vis[M], trie[M][26], f[M];
inline void insert(char s[]) {
rg int p = 0;
for (rg int i = 1; s[i]; i++) {
rg int u = s[i] - 'a';
if (!trie[p][u]) trie[p][u] = ++idx;
p = trie[p][u];
}
vis[p] = 1;
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
while (cin >> c + 1) {
idx = 0;
memset(trie, 0, sizeof(trie));
memset(vis, 0, sizeof(vis));
memset(f, 0, sizeof(f));
cin >> n;
for (rg int i = 1; i <= n; i++) {
cin >> str + 1;
insert(str);
}
rg int len = strlen(c + 1);
f[len + 1] = 1;
for (rg int i = len; i >= 1; i--) {
rg int p = 0;
for (rg int j = i; j <= len; j++) {
rg int u = c[j] - 'a';
if (!trie[p][u]) break;
p = trie[p][u];
if (vis[p]) {
f[i] = (f[j + 1] + f[i] + mod) % mod;
}
}
}
cout << "Case " << ++sum << ": " << f[1] << "\n";
}
return qwq;
}
例题4:[省选联考 2020 A 卷] 树
模拟这个过程,我们发现每次都要将子树内的点的w值整体+1,再求子树内的异或和。所以,我们需要一种工具,支持子树+1、插入一个数和查询子树异或和这三种操作。
解决位运算问题,第一要想到的是01trie,便于实现。那么每次的操作是:dfs孩子,将孩子的01trie合并,再把合并后的01trie整体+1,最后插入\(w(x)\),整体求值。
01trie的合并和线段树合并一样,把每个点的权值对应相加即可,复杂度O(log W)。
01trie的整体+1,就是把点x的左儿子和右儿子同时+1,那么左儿子0 +1就变成了1,成了右儿子(规定一个节点的左儿子为0,右儿子为1)。这里,我们每个点记一个结束标记ed。如果点x有结束标记,那么左儿子+1后变成的右儿子,也应该不上一个结束标记。注意:如果点x原先没有左儿子,应该新建一个点进行操作。而右儿子整体+1就变成了左儿子,再递归处理即可。
inline void add1(int rt) {
if (!rt) return ;
swap(t[rt].c[0], t[rt].c[1]);
if (t[rt].ed && !t[rt].c[1]) t[rt].c[1] = ++cur
t[t[rt].c[1]].ed ^= t[rt].ed;
t[t[rt].c[1]].cnt ^= t[rt].ed;
t[rt].ed = 0;
add1(t[rt].c[0]);
update(rt);
}
最后是update。一个点的权值,是左儿子的权值乘2异或右儿子的权值乘2,因为右儿子代表最低位为1,如果右儿子子树内有奇数的串,那么还要把这个点的权值加一,所以还要记一个cnt代表子树内串的个数的奇偶性。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
#define ll long long
using namespace std;
const int N = 525013;
ll ans;
int n, tot, cur;
int w[N], root[N];
struct trie {
int cnt, val, ed;
int c[2];
} t[N * 27];
vector<int> e[N];
inline void pushup(int rt) {
rg int ls = t[rt].c[0], rs = t[rt].c[1];
t[rt].cnt = t[rt].ed ^ t[ls].cnt ^ t[rs].cnt;
t[rt].val = ((t[ls].val << 1) ^ ((t[rs].val << 1) | t[rs].cnt));
}
inline void add1(int rt) {
if (!rt) return ;
swap(t[rt].c[0], t[rt].c[1]);
if (t[rt].ed && !t[rt].c[1]) t[rt].c[1] = ++cur;
t[t[rt].c[1]].ed ^= t[rt].ed;
t[t[rt].c[1]].cnt ^= t[rt].ed;
t[rt].ed = 0;
add1(t[rt].c[0]);
pushup(rt);
}
inline void insert(int &rt, int y) {
if (!rt) rt = ++cur;
if (!y) {
t[rt].ed ^= 1;
t[rt].cnt ^= 1;
return ;
}
insert(t[rt].c[y & 1], y >> 1);
pushup(rt);
}
inline int merge(int x, int y) {
if (!x || !y) return x | y;
t[x].ed ^= t[y].ed;
t[x].cnt ^= t[y].cnt;
t[x].val ^= t[y].val;
t[x].c[0] = merge(t[x].c[0], t[y].c[0]);
t[x].c[1] = merge(t[x].c[1], t[y].c[1]);
return x;
}
inline void dfs(int u) {
insert(root[u], w[u]);
for (rg int i = 0; i < e[u].size(); i++) {
rg int v = e[u][i];
dfs(v);
add1(root[v]);
root[u] = merge(root[u], root[v]);
}
ans += 1ll * t[root[u]].val;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
cin >> n;
for (rg int i = 1; i <= n; i++) {
cin >> w[i];
}
for (rg int i = 2; i <= n; i++) {
rg int a;
cin >> a;
e[a].push_back(i);
}
dfs(1);
cout << ans << "\n";
return qwq;
}
AC自动机
概述
AC自动机是以Trie树的结构为基础,结合KNP的思想建立的。
解释
简单来说,建立一个AC自动机有两个步骤:
- 基础的Trie结构:将所有模式串构成一棵Trie
- KMP的思想:对Trie树上的所有节点构造失配指针
然后就可以利用它进行多模式匹配了。
字典树构建
AC自动机在初始时会将若干个模式串丢到一个Trie里,然后在Trie上建立AC自动机。这个Trie就是普通的Trie,该怎么建怎么建。
这里需要解释一下Trie节点的含义。Trie中的节点表示的是某个模式串的前缀。我们也将其称作状态。一个节点表示一个状态,Trie的边就是状态的转移。
形式化的说,对于若干个模式串\(s_1,s_2\cdots,s_n\),将它们构建成一棵字典树后的所有状态的集合记作\(Q\)。
失配指针
AC自动机利用一个fail指针来辅助多模式串的匹配。
状态u的fail指针指向另一个状态v,其中\(v \in Q\),且v是u的最长后缀(即在若干个后缀状态中取最长的一个作为fail指针)。
这里简单比较一下fail指针与KMP中的next指针:
- 共同点:两者同样是在失配的时候用于跳转的指针。
- 不同点:next指针求的是最长的相同前后缀,而fail指针指向 所有模式串的前缀 中匹配 当前状态的最长后缀。
因为KMP只对一个模式串做匹配,而AC自动机要对多个模式串做匹配。有可能fail指针指向的节点对应着另外一个模式串,两者前缀不同。
总结,AC自动机的fail指针指向当前状态的最长后缀状态,且AC自动机在做匹配时同一位上可匹配多个模式串。
构建指针
下面介绍构建fail指针的基础思想:
考虑将字典树中当前的节点u,u的父亲是p,p通过字符c的边指向u,即\(trie[p][c]=u\)。假设深度小于u的所有节点的fail指针都已求得。
- 1.如果\(trie[fail[p]][c]\)存在:则让u的fail指针指向\(trie[fail[p]][c]\)。相当于在p和\(fail[p]\)后面加一个字符c,分别对应u和\(fail[u]\)。
- 2.如果\(trie[fail[p]][c]\)不存在:那么我们继续找到\(trie[fail[fail[p]]][c]\)。重复1的判断过程,一直跳fail指针直到根节点。
- 3.如果真的没有,就让fail指针指向根节点。
如此即完成了\(fail[u]\)的构建。
例子
对字符串\(i,he,his,she,hers\)组成的字典树构建fail指针:
黄色结点:当前的结点 u。
绿色结点:已经 BFS 遍历完毕的结点。
橙色的边:fail 指针。
红色的边:当前求出的 fail 指针。
我们重点分析结点 6 的 fail 指针构建:
找到6的父结点5,\(fail[5]=10\)。然而10结点没有字母s连出的边;继续跳到 10的fail指针,\(fail[10]=0\)。发现0结点有字母s连出的边,指向7结点;所以 \(fail[6]=7\)。
最后放一张建出来的图:
字典树与字典图
先来看构建函数\(build()\),该函数的目标有两个,一个是构建fail指针,一个是构建自动机。
\(tr[u][c]\):从状态u后加一个字符c到达的状态。
队列q:用于BFS遍历字典树。
\(fail[u]\):节点u的fail指针。
inline void build() {
for (rg int i = 0; i < 26; i++) {
if (trans[0][i]) q.push(trans[0][i]);
}
while (q.size()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trans[u][i]) {
fail[trans[u][i]] = trans[fail[u]][i];
q.push(trans[u][i]);
} else {
trans[u][i] = trans[fail[u]][i];
}
}
}
}
解释
build函数将节点按BFS顺序入队,依次求fail指针。这里的字典树根节点为0,我们将根节点的子节点一一入队。若将根节点入队,则在第一次BFS的时候,会将根节点儿子的fail指针标记为本身。因此我们将根节点的儿子一一入队,而不是将根节点入队。
然后开始BFS:每次取出队首的节点u(\(fail[u]\)在之前的BFS过程中已求得),然后遍历字符集:
- 1.如果\(trans[u][i]\)存在,我们就将\(trans[u][i]\)的fail指针赋值为\(trans[fail[u]][i]\)。根据之前的讲解,这里应该不停的跳指针,但这里通过特殊处理简化了这些代码。
- 2.否则,令\(trans[u][i]\)指向\(trans[fail[u]][i]\)的状态。
这里的处理是,通过else语句的代码修改字典树的结构。没错,它将不存在的字典树的状态链接到了失配指针的对应状态。在原字典树中,每一个节点代表一个字符串S,是某个模式串的前缀。而在修改字典树结构后,尽管增加了许多转移关系,但节点所代表的字符串是不变的。
而\(trans[S][c]\)相当于是在S后添加一个字符c变成另一个状态\(S'\)。如果\(S'\)存在,说明存在一个模式串的前缀是\(S'\),否则我们让\(trans[S][c]\)指向\(trans[fail[S]][c]\)。由于\(fail[S]\)对应的字符串是S的后缀,因此\(trans[fail[S]][c]\)对应的字符串也是\(S'\)的后缀。
换言之在字典树上跳转的时候,我们只会从\(S\)跳转到\(S'\),相当于匹配了一个\(S'\);但在AC自动机上跳转的时候,我们会从S跳转到\(S'\)的后缀,也就是说我们在匹配一个字符c,然后舍弃S的部分前缀。舍弃前缀显然是能匹配的。那么fail指针呢?它也在舍弃前缀!试想一下,如果文本串能匹配S,显然它也能匹配S的后缀。所谓的fail指针其实就是S的一个后缀合集。
这样修改字典树的结构,使得匹配转移更加完善。同时它将fail指针跳转的路径做了压缩(就像并查集的路径压缩),使得本来需要跳很多次的fail指针变成跳一次。
过程
我们将之前的图改一下:
1.蓝色结点:BFS 遍历到的结点 u
2.蓝色的边:当前结点下,AC 自动机修改字典树结构连出的边。
3.黑色的边:AC 自动机修改字典树结构连出的边。
4.红色的边:当前结点求出的 fail 指针
5.黄色的边:fail 指针
6.灰色的边:字典树的边
可以发现,众多交错的黑边将字典树变成了字典图。图中省略了连向根节点的黑边(否则会更乱)。
我们重点分析一下节点5遍历时的情况。我们求\(trans[5][s]=6\)的fail指针:
本来的策略是找fail指针,于是我们跳到\(fail[5]=10\)发现没有s连出的字典树的边,于是跳到\(fail[10]=0\),发现有\(trie[0][s]=7\),于是\(fail[6]=7\);但是有了黑边、蓝边,我们跳到\(fail[5]=10\)之后直接走\(trans[10][s]=7\)就到7号节点了。
这就是build完成的两件事:构建fail指针和建立字典图。这个字典图也会在查询的时候起到关键的作用。
多模式匹配
接下来分析匹配函数\(query()\):
实现:
inline int query(char *t) {
rg int u = 0, res = 0;
for (rg int i = 1; t[i]; i++) {
rg int u = trans[u][t[i] - 'a']; //转移
for (rg int j = u; j && cnt[j] != -1; j = fail[j]) {
res += cnt[j];
cnt[j] = -1;
}
}
return res;
}
解释
这里u作为字典树上当前匹配到的节点,res即返回的答案。循环遍历匹配串,u在字典树上跟踪当前字符。利用fail指针找出所有匹配的模式串,累加到答案中。然后清零。在上文中我们分析过,字典树的结构其实就是一个trans函数,而构建好了这个函数后,在匹配字符串的过程中,我们会舍弃部分前缀达到最低限度的匹配。fail指针则指向了更多的匹配状态。
例题1:AC 自动机(简单版)
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 1e6 + 10;
int trie[N][26];
int cntword[N]; //记录该单词出现次数
int fail[N];
int cnt;
inline void insert(char s[]) {
rg int rt = 0;
for (rg int i = 0; s[i]; i++) {
rg int u = s[i] - 'a';
if (!trie[rt][u]) trie[rt][u] = ++cnt;
rt = trie[rt][u];
}
cntword[rt]++;
}
inline void build() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) {
//fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
inline int query(char s[]) {
rg int u = 0, res = 0;
for (rg int i = 0; s[i]; i++) {
u = trie[u][s[i] - 'a']; //转移
for (rg int j = u; j && cntword[j] != -1; j = fail[j]) {
res += cntword[j];
cntword[j] = -1;
}
}
return res;
}
int n;
char c[N];
int main() {
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
scanf("%s", c);
insert(c);
}
build();
scanf("%s", c);
printf("%d\n", query(c));
return qwq;
}
例题2:AC 自动机(简单版 II)
我们将标记模式串的idx设为当前是第几个模式串。
于是插入时idx[rt]++
变为idx[rt]=id
,id表示该字符是第id个输入的。
查询:
我们开一个数组\(val[i]\),表示第i个字符串出现的次数。因为是重复计算,所以不能标记为-1了。我们每经过一个点,如果有模式串标记,就将val[]++
,然后继续跳fail。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 156, M = N * 80, L = 1e6 + 6;
int tot, trie[M][26];
int fail[M], idx[M], val[M];
int cnt[N]; //记录第i个字符串出现的次数
inline void init() {
memset(fail, 0, sizeof(fail));
memset(trie, 0, sizeof(trie));
memset(val, 0, sizeof(val));
memset(cnt, 0, sizeof(cnt));
memset(idx, 0, sizeof(idx));
tot = 0;
}
inline void insert(char str[], int id) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'a';
if (!trie[rt][u]) trie[rt][u] = ++tot;
rt = trie[rt][u];
}
idx[rt] = id; //以u为结尾的字符串编号为 idx[u]
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
inline int query(char str[]) {
rg int rt = 0, res = 0;
for (rg int i = 0; str[i]; i++) {
rt = trie[rt][str[i] - 'a'];
for (rg int j = rt; j; j = fail[j]) {
val[j]++;
}
}
for (rg int i = 0; i <= tot; i++) {
if (idx[i]) {
res = max(res, val[i]);
cnt[idx[i]] = val[i];
}
}
return res;
}
int n;
char c[N][100], t[L];
int main() {
while (scanf("%d", &n) && n) {
init();
for (rg int i = 1; i <= n; i++) {
scanf("%s", c[i]);
insert(c[i], i);
}
get_fail();
scanf("%s", t);
rg int x = query(t);
printf("%d\n", x);
for (rg int i = 1; i <= n; i++) {
if (cnt[i] == x) printf("%s\n", c[i]);
}
}
return qwq;
}
例题3:【模板】AC 自动机
因为我们是通过不断跳fail指针来减小复杂度,但假如有一个串是\(aaaaaaaa\cdots\),那么每个点的fail指针跳跃后深度只减小1,复杂度上天……
比如这张图,假如我们匹配到4号节点,那么跳fail到7号节点,再跳到9,
然后下次匹配到节点7,又跳到9……
实际上,我们可以把次数先存在自己这里,最后一次性传给后代。
不难发现,如果把fail指针指向的节点当作父亲,自己当作儿子,一遍树形dp就可以解决。
所以AC自动机只需要加第一个节点的值就好了(树上差分)。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 2e5 + 3, M = 2e6 + 3;
int n, id, num[N], mp[N];
char s[N];
char w[M];
vector<int> e[N];
int trie[N][26], fail[N];
inline void insert(char str[], int idx) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'a';
if (!trie[rt][u]) trie[rt][u] = ++id;
rt = trie[rt][u];
}
mp[idx] = rt;
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
inline void query(char c[]) {
rg int rt = 0;
for (rg int i = 0; c[i]; i++) {
rt = trie[rt][c[i] - 'a'];
num[rt]++;
}
}
inline void dfs(int u) {
for (rg int i = 0; i < e[u].size(); i++) {
rg int v = e[u][i];
dfs(v);
num[u] += num[v];
}
}
int main() {
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
scanf("%s", s);
insert(s, i);
}
get_fail();
scanf("%s", w);
query(w);
for (rg int i = 1; i <= id; i++) e[fail[i]].push_back(i);
dfs(0);
for (rg int i = 1; i <= n; i++) printf("%d\n", num[mp[i]]);
return qwq;
}
例题4:[HDU 2222]Keywords Search
同例题1。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 1e4 + 3, M = 1e6 + 3;
int T, n, idx;
int trie[N][26], fail[N * 50], cnt[N * 50];
char c[N], w[M];
inline void init() {
memset(trie, 0, sizeof(trie));
memset(fail, 0, sizeof(fail));
memset(cnt, 0, sizeof(cnt));
idx = 0;
}
inline void insert(char str[], int id) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'a';
if (!trie[rt][u]) trie[rt][u] = ++idx;
rt = trie[rt][u];
}
cnt[rt]++;
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
inline int query(char str[]) {
rg int rt = 0, res = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'a';
rt = trie[rt][u];
for (rg int j = rt; j && !cnt[j]; j = fail[j]) {
res += cnt[j];
cnt[j] = 0;
}
}
return res;
}
int main() {
scanf("%d", &T);
while (T--) {
init();
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
scanf("%s", c);
insert(c, i);
}
get_fail();
scanf("%s", w);
printf("%d\n", query(w));
}
return qwq;
}
例题5:[JSOI2012] 玄武密码
把M个模式串处理成trie树,用文本串\(S\)进行匹配。S遍历trie树时经过的节点以及此节点通过fail指针跳到的节点u,在trie树上1到u的路径上的字符串一定存在于S上,故我们在节点u打上标记,再重新遍历每个模式串,拥有标记即可更新答案。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 1e7 + 3, M = 1e5 + 3;
int n, m;
int trie[N][4], idx, fail[N], cnt[N];
int vis[N];
char c[M][103], w[N];
inline int get_u(char ch) {
if (ch == 'E') return 0;
if (ch == 'S') return 1;
if (ch == 'W') return 2;
if (ch == 'N') return 3;
}
inline void insert(char s[]) {
rg int rt = 0;
for (rg int i = 0; s[i]; i++) {
rg int u = get_u(s[i]);
if (!trie[rt][u]) trie[rt][u] = ++idx;
rt = trie[rt][u];
}
cnt[rt]++;
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 4; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 4; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
rg int rt = 0;
for (rg int i = 0; i < n; i++) {
rg int u = get_u(w[i]);
rt = trie[rt][u];
for (rg int j = rt; j && !vis[j]; j = fail[j]) vis[j] = 1;
}
}
inline int query(char s[]) {
rg int rt = 0, res = 0;
for (rg int i = 0; s[i]; i++) {
rg int u = get_u(s[i]);
rt = trie[rt][u];
if (vis[rt]) res = i + 1;
}
return res;
}
int main() {
scanf("%d%d", &n, &m);
scanf("%s", w);
for (rg int i = 1; i <= m; i++) {
scanf("%s", c[i]);
insert(c[i]);
}
get_fail();
for (rg int i = 1; i <= m; i++) {
printf("%d\n", query(c[i]));
}
return qwq;
}
例题6:[TJOI2013] 单词
和例题3差不多,就是查询对象变为每个模式串,于是就可以在插入的时候同时统计,既节省了空间,又减少了码量。
注意!!!:char数组能用一维就用一维,实在要用二维建议改成string,不然容易炸空间!!!
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 1e6 + 3;
int n, idx;
int cnt[N], h[N], id[N];
int trie[N][26], fail[N];
char s[N];
vector<int> e[N];
inline void insert(char c[], int th) {
rg int rt = 0;
for (rg int i = 0; c[i]; i++) {
rg int u = c[i] - 'a';
if (!trie[rt][u]) trie[rt][u] = ++idx;
rt = trie[rt][u];
cnt[rt]++; //直接在insert时统计
}
id[th] = rt;
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
inline void dfs(int u) {
for (rg int i = 0; i < e[u].size(); i++) {
rg int v = e[u][i];
dfs(v);
cnt[u] += cnt[v];
}
}
int main() {
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
scanf("%s", s);
insert(s, i);
}
get_fail();
for (rg int i = 1; i <= idx; i++) e[fail[i]].push_back(i);
dfs(0);
for (rg int i = 1; i <= n; i++) printf("%d\n", cnt[id[i]]);
}
例题7:[POI2000] 病毒
一般的AC自动机是尽量多的遍历模式串末尾的标记,使总值最大。本题却不太一样,是要构造一个无限长的文本串,使其不到达标记。
于是可以想到,要在trie图上找到一个环,且环上没有任何标记,这样就能一直匹配。这个找环可以通过dfs来实现。
- 1.我们建立两个布尔数组,一个记录每个节点在当前dfs路径上有没有被选中,另一个记录每个节点历史上有没有被访问过。如果形成回路,就找到环了。
- 2.避免标记,也就是说如果下一个节点有标记,就不走那个节点。
- 3.在构造失配指针时一个优化是:如果一个节点拥有了失配指针,它指向的节点如果有危险标记,自己也必然危险。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 3e4 + 3;
int trie[N][2], fail[N];
bool cnt[N];
int idx = 0;
inline void insert(char str[]) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - '0';
if (!trie[rt][u]) {
trie[rt][u] = ++idx;
}
rt = trie[rt][u];
}
cnt[rt] = true;
}
inline void get_fail() {
queue<int> q;
if (trie[0][0]) q.push(trie[0][0]);
if (trie[0][1]) q.push(trie[0][1]);
while (!q.empty()) {
rg int rt = q.front();
q.pop();
for (rg int i = 0; i <= 1; i++) {
if (trie[rt][i]) {
q.push(trie[rt][i]);
rg int nxt = fail[rt];
while (nxt && trie[nxt][i] <= 0) nxt = fail[nxt]; //要么到根节点,要么找到最长匹配后缀段
if (trie[nxt][i] <= 0) fail[trie[rt][i]] = 0; //失配指针转移到根节点
else {
fail[trie[rt][i]] = trie[nxt][i];
if (cnt[trie[nxt][i]]) cnt[trie[rt][i]] = true; //既然自己后缀行不通,自己也危险
}
} else {
trie[rt][i] = trie[fail[rt]][i];
}
}
}
}
bool vis[N], flag[N];
char c[N];
int n;
inline void dfs(int u) { //寻找有没有环
vis[u] = true;
for (rg int i = 0; i <= 1; i++) {
if (vis[trie[u][i]]) { //根据路径标记判断是否有环
printf("TAK\n");
exit(0); //找到环并推出
} else if (!cnt[trie[u][i]] && !flag[trie[u][i]]) {
//只有下一位不为危险节点并且有可能成环,才递归搜索
flag[trie[u][i]] = true;
dfs(trie[u][i]);
}
}
vis[u] = false;
}
int main() {
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
scanf("%s", c);
insert(c);
}
get_fail();
dfs(0);
printf("NIE\n");
return qwq;
}
例题8:[JSOI2007] 文本生成器
在AC自动机上的dp都非常套路,大部分令\(f[i][j]\)表示当前串长为i,在节点j时的情况。有时再加一维表示这个状态里包含了哪些东西。而且AC自动机的dp会经常用矩阵乘法优化。
但这题有很多个单词,我们无法像单个单词那样记录当前文章和每个单词末尾的匹配位数,会MLE。
原题目看起来不好做,但是有“至少一个”这个条件,我们就可以想到容斥,用总方案数减去不合法方案。总方案很好算,为\(26^m\)。
首先建出AC自动机。
令\(dp[i][j]\)表示串长为i,在AC自动机上走到编号为j的节点的合法串个数。
如果走到j的儿子k这个节点的串合法,那么就可以从\((i,j)\)转移到\((i+1,trie[j][k])\),有:
\(dp[i+1][trie[j][k]] += dp[i][j](0 \le k \le 26)\)
初始状态\(dp[0][0]=1\)。答案为所有\(dp[m][i]\)的最大值。
那么如何判断走到点j的串是否合法?想一想fail的性质,我们可以发现:如果从点j不停沿fail往上跳,经过的所有点(包括j)没有船尾的节点,那么j合法,否则不合法。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 63, M = 103, mod = 1e4 + 7;
int n, m, idx;
int trie[N * M][26], dp[M][N * M], fail[N * M];
bool vis[N * M]; //表示节点状态是否包含模式串
char c[M];
inline void insert(char str[]) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'A';
if (!trie[rt][u]) trie[rt][u] = ++idx;
rt = trie[rt][u];
}
vis[rt] = true;
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
vis[trie[u][i]] |= vis[fail[trie[u][i]]]; //如果fail不合法,自己也不合法
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
inline int qpow(int a, int b) {
rg int res = 1;
while (b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
int main() {
scanf("%d%d", &n, &m);
for (rg int i = 1; i <= n; i++) {
scanf("%s", c);
insert(c);
}
get_fail();
dp[0][0] = 1;
for (rg int i = 0; i <= m - 1; i++) {
for (rg int j = 0; j <= idx; j++) {
for (rg int k = 0; k <= 25; k++) {
if (!vis[trie[j][k]]) dp[i + 1][trie[j][k]] = (dp[i + 1][trie[j][k]] + dp[i][j]) % mod;
}
}
}
rg int ans = qpow(26, m);
for (rg int i = 0; i <= idx; i++) {
ans = (ans - dp[m][i] + mod) % mod;
}
printf("%d", ans);
return qwq;
}
例题9:[HNOI2006] 最短母串问题
首先建出Trie图,通过BFS搜索所有的合法方案。
在建Trie图时,把单词i的末节点值state设为\(1 << i\),BFS时用当前状态\(or\)遍历到节点的state值,当BFS到某个状态等于\((1 << n) - 1\)时,就证明所有的单词的末端点都遍历了一次,那这个搜索到的字符串就是我们想要的答案了。
当然了,我们把当前搜索到的字符串都存下来是相当消耗空间的,所以我们可以想出一个优化:因为所有BFS出的状态都是由前面的状态推出来的,所以我们可以每次搜索时只记录这次搜索的字符和它是由哪个状态推导出的,输出时倒回去统计就行了。
另外,可能有相同的字符串,所以设state的初值时也要\(or\)而不是直接赋值。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 603;
int nod, n, idx, tot;
int trie[N][26], fail[N], state[N], ans[N * (1 << 12 | 1)], fa[N * (1 << 12 | 1)];
bool vis[N][1 << 12 | 1];
char c[N], ch[51];
inline void insert(char str[], int id) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'A';
if (!trie[rt][u]) trie[rt][u] = ++idx;
rt = trie[rt][u];
}
state[rt] |= 1 << (id - 1); //有重复的,要用|
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
state[trie[u][i]] |= state[trie[fail[u]][i]]; //它的fail指针包含的字符串它也包含
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
int main() {
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
scanf("%s", c);
insert(c, i);
}
get_fail();
queue<int> q1, q2;
q1.push(0); //在trie中的位置
q2.push(0); //状态压缩,表示当前包含了哪些要求的字符串
vis[0][0] = 1;
rg int tim = 0; //表示当前搜索到的编号
while (!q1.empty()) {
rg int rt = q1.front(), st = q2.front();
q1.pop();
q2.pop();
if (st == ((1 << n) - 1)) {
while (tim) { //递归回去求答案
c[++nod] = ans[tim];
tim = fa[tim];
}
for (rg int i = nod; i > 0; i--) printf("%c", c[i] + 'A');
return qwq;
}
for (rg int i = 0; i < 26; i++) {
if (!vis[trie[rt][i]][st | state[trie[rt][i]]]) {
//找出新的状态
vis[trie[rt][i]][st | state[trie[rt][i]]] = 1;
q1.push(trie[rt][i]);
q2.push(st | state[trie[rt][i]]);
//记录当前搜到的字符,同时建一棵关于答案的树,方便最后查询
fa[++tot] = tim;
ans[tot] = i;
}
}
tim++;
}
return qwq;
}
例题10:[BZOJ 2905]背单词
考虑AC自动机上一个子串是另一个串的子串的条件。如果T是S的子串,那么就会存在一个S的子串S',从S'不断跳fail可以到达T。
那么我们就先建出fail树。然后设\(dp[i]\)表示到第i个串且选择\(S_i\)的最大值。那么:
\(dp[i]=max\{dp[j]\}+w[i]\)
其实\(S_j\)作为\(S_i\)的一个子串,也就是j是i的一个前缀在fail树上的一个父亲。我们暴力枚举每个串以及它的前缀,那么我们就要求出这个当前前缀在fail树上的最大值。求出\(dp[i]\)之后我们要把它加到i点上。
总结一下上面的操作:在fail树上查询从根节点到一个节点的链上的最大值,在fail函数上给一个点加上一个数。所以可以DFS序+线段树维护(树剖也可)。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
using namespace std;
const int N = 3e5 + 3;
int T, n, w[N];
string s[N];
int trie[N][26], fail[N];
int idx, num[N], fa[N];
inline void insert(string str, int id) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'a';
if (!trie[rt][u]) {
trie[rt][u] = ++idx;
fa[idx] = rt;
}
rt = trie[rt][u];
}
num[id] = rt;
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
vector<int> e[N << 1];
int in[N], out[N], tot, dfn[N];
inline void dfs(int u) {
if (u) in[u] = ++tot;
//cout << u << " " << tot << "\n";
for (rg int i = 0; i < e[u].size(); i++) {
rg int v = e[u][i];
dfs(v);
}
out[u] = tot;
}
#define ls rt << 1
#define rs rt << 1 | 1
struct SegmentTree {
int l, r, maxx, lazy;
} t[N << 2];
inline void pushdown(int rt) {
if (t[rt].lazy) {
t[ls].maxx = max(t[ls].maxx, t[rt].lazy);
t[rs].maxx = max(t[rs].maxx, t[rt].lazy);
t[ls].lazy = max(t[ls].lazy, t[rt].lazy);
t[rs].lazy = max(t[rs].lazy, t[rt].lazy);
t[rt].lazy = 0;
}
}
inline void build(int rt, int l, int r) {
t[rt].l = l, t[rt].r = r;
if (l == r) {
t[rt].maxx = 0;
return ;
}
rg int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
t[rt].maxx = max(t[ls].maxx, t[rs].maxx);
}
inline void update(int rt, int l, int r, int val) {
pushdown(rt);
if (t[rt].l == l && t[rt].r == r) {
t[rt].maxx = max(t[rt].maxx, val);
t[rt].lazy = max(t[rt].lazy, val);
return ;
}
rg int mid = (t[rt].l + t[rt].r) >> 1;
if (l > mid) update(rs, l, r, val);
else if (r <= mid) update(ls, l, r, val);
else {
update(ls, l, mid, val);
update(rs, mid + 1, r, val);
}
t[rt].maxx = max(t[ls].maxx, t[rs].maxx);
}
inline int query(int rt, int x) {
pushdown(rt);
if (t[rt].l == t[rt].r && t[rt].l == x) {
return t[rt].maxx;
}
rg int mid = (t[rt].l + t[rt].r) >> 1;
if (x <= mid) return query(ls, x);
else return query(rs, x);
}
inline void init() {
idx = tot = 0;
memset(fa, 0, sizeof(fa));
memset(num, 0, sizeof(num));
memset(trie, 0, sizeof(trie));
memset(fail, 0, sizeof(fail));
memset(in, 0, sizeof(in));
memset(out, 0, sizeof(out));
memset(dfn, 0, sizeof(dfn));
memset(t, 0, sizeof(t));
memset(e, 0, sizeof(e));
}
int main() {
scanf("%d", &T);
while (T--) {
init();
scanf("%d", &n);
for (rg int i = 1; i <= n; i++) {
cin >> s[i] >> w[i];
insert(s[i], i);
}
get_fail();
for (rg int i = 1; i <= idx; i++) {
e[fail[i]].push_back(i);
}
dfs(0);
build(1, 1, idx);
rg int ans = 0;
for (rg int i = 1; i <= n; i++) {
rg int p = num[i], res = 0;
while (p) { //遍历Si的所有子串
res = max(res, query(1, in[p]));
p = fa[p];
}
p = num[i];
res = max(res, res + w[i]);
ans = max(ans, res);
update(1, in[p], out[p], res);
}
printf("%d\n", ans);
}
return qwq;
}
例题11:[JSOI2009] 密码
AC自动机上状压dp。令\(dp[i][j][S]\)表示长度为i,当前在Trie树上的j号节点,观察到的字符串的使用情况为S的方案数。
转移很显然:
\(dp[i+1][k][S|S_k] = \sum dp[i][j][S]\)
对于输出方案:因为只有当方案数小于等于42时才输出具体方案,所以直接暴搜出每种方案再输出即可。
因为字符串都是紧密结合的,不存在自由的可以填26种字母的位置,那么我们只需要预处理两个模式串组合最少的字符长度,\(O(n!)\)枚举模式串排列即可,如果大于L就剪枝。
#include<bits/stdc++.h>
#define rg register
#define qwq 0
#define ll long long
using namespace std;
const int N = 102;
int n, m, idx;
int trie[N][26], fail[N], zt[N];
char s[13][13];
ll dp[30][N][1027];
inline void insert(char str[], int id) {
rg int rt = 0;
for (rg int i = 0; str[i]; i++) {
rg int u = str[i] - 'a';
if (!trie[rt][u]) trie[rt][u] = ++idx;
rt = trie[rt][u];
}
zt[rt] = 1 << (id - 1);
}
inline void get_fail() {
queue<int> q;
for (rg int i = 0; i < 26; i++) {
if (trie[0][i]) q.push(trie[0][i]);
}
while (!q.empty()) {
rg int u = q.front();
q.pop();
for (rg int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
zt[trie[u][i]] |= zt[trie[fail[u]][i]];
q.push(trie[u][i]);
} else {
trie[u][i] = trie[fail[u]][i];
}
}
}
}
int g[20], vis[20], cnt[20][20], tot;
string p[50], tmp;
inline void dfs(int x) {
if (x == m + 1) {
tmp = "";
for (rg int i = 1; i <= m; i++) {
rg int st = 0, len = strlen(s[g[i]]);
if (i != 1) st = cnt[g[i - 1]][g[i]];
for (rg int j = st; j < len; j++) {
tmp += s[g[i]][j];
}
}
if (tmp.length() == n) p[++tot] = tmp;
return ;
}
for (rg int i = 1; i <= m; i++) {
if (!vis[i]) {
g[x] = i;
vis[i] = 1;
dfs(x + 1);
vis[i] = 0;
g[x] = 0;
}
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> m;
for (rg int i = 1; i <= m; i++) {
cin >> s[i];
insert(s[i], i);
}
get_fail();
dp[0][0][0] = 1;
for (rg int i = 0; i <= n; i++) {
for (rg int j = 0; j <= idx; j++) {
for (rg int S = 0; S < (1 << m); S++) {
if (dp[i][j][S]) {
for (rg int c = 0; c < 26; c++) {
rg int k = trie[j][c];
dp[i + 1][k][S | zt[k]] += dp[i][j][S];
}
}
}
}
}
rg ll ans = 0;
for (rg int i = 0; i <= idx; i++) {
ans += dp[n][i][(1 << m) - 1];
}
cout << ans << "\n";
if (ans <= 42) {
for (rg int x = 1; x <= m; x++) {
for (rg int y = 1; y <= m; y++) {
if (x == y) continue;
rg int lenx = strlen(s[x]), leny = strlen(s[y]);
for (rg int len = min(lenx, leny); len >= 0; len--) {
rg bool flag = false;
for (rg int i = 0; i < len; i++) {
rg int j = lenx - len + i;
if (s[x][j] != s[y][i]) {
flag = true;
break;
}
}
if (!flag) {
cnt[x][y] = len; //串s[x]和串s[y]的最长拼接部分
break;
}
}
}
}
dfs(1);
sort(p + 1, p + ans + 1);
for (rg int i = 1; i <= ans; i++) {
cout << p[i] << "\n";
}
}
}
标签:rt,AC,int,++,trie,rg,fail,自动机
From: https://www.cnblogs.com/Baiyj/p/18242906