注:
- 下文中,\(s[l\sim r]\)表示截取字符串\(s\)的第\(l\)个字符到第\(r\)个字符。
- 文字描述的字符串下标从\(1\)开始,但代码实现从\(0\)开始。
我们建出AC自动机后,有一个比较暴力的思路。
我么用\(f[i]\)表示待查找字符串\(t\)的长度为\(i\)前缀是否满足题意。
我们求\(f[i]\),就从匹配到\(t[i]\)时自动机上的指针\(p\)开始,沿着fail树往上跳到根节点。中途经过的状态正是\(t[1\sim i]\)的所有后缀,对于长度为\(j\)的后缀,如果它正好是一个模式串,而且\(f[i-j]=\text{true}\),那么\(f[i-j]\)的状态就可以转移到\(f[i]\)上。
其中\(len[i]\)表示以节点\(i\)结尾的字符串长度,如果没有字符串以\(i\)结尾,则值为\(0\)。
核心代码:
int query(string s){
int p=0,n=s.size(),ans=0;
for(int i=1;i<=n;i++){
p=tr[p][s[i-1]-'a'];
f[i]=0;
for(int j=p;j;j=fail[j]){
if(len[j]&&f[i-len[j]]){
f[i]=1,ans=i;
break;
}
}
}
return ans;
}
优化前的代码 - 95pts TLE
#include<bits/stdc++.h>
#define T 30//模式串个数
#define N 410//模式串总长(节点数)
#define M 2000010//单个主串长度
#define S 26//字符集大小
using namespace std;
int n,m,tr[N][S],fail[N],cnt;
int f[M],len[N];
string s[T];
queue<int> q;
void ins(string s){
int p=0;
for(char i:s){
int c=i-'a';
if(!tr[p][c])
tr[p][c]=++cnt;
p=tr[p][c];
}
len[p]=s.size();
}
void get_fail(){
for(int i=0;i<26;i++)
if(tr[0][i]) q.push(tr[0][i]);
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=0;i<26;i++){
if(tr[u][i])
fail[tr[u][i]]=tr[fail[u]][i],q.push(tr[u][i]);
else tr[u][i]=tr[fail[u]][i];
}
}
}
int query(string s){
int p=0,n=s.size(),ans=0;
for(int i=1;i<=n;i++){
p=tr[p][s[i-1]-'a'];
f[i]=0;
for(int j=p;j;j=fail[j]){
if(f[i-len[j]]&&len[j]){
f[i]=1,ans=i;
break;
}
}
}
return ans;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n>>m;
for(int i=1;i<=n;i++){
cin>>s[i];
ins(s[i]);
}
get_fail();
f[0]=1;
while(m--){
cin>>s[0];
cout<<query(s[0])<<"\n";
}
return 0;
}
建立自动机的复杂度是\(O(n|s||\Sigma|)\),查询单个字符串的复杂度是\(O(|t||s|)\),总时间复杂度是\(O(n|s||\Sigma|+m|t||s|)\)。
显然这样会TLE,考虑如何优化查询。
我们发现\(|s|\le20\),所以除去根节点的话,fail树最多是\(20\)层。
因此我们可以想到状态压缩,只要把每个节点到fail树的根节点路径上的信息压缩成一个整数即可。在构建自动机时,为每个节点\(i\)维护一个\(tlen[i]\),其二进制表示下的第\(j\)位(\(j\in [1,20]\))来表示“状态\(i\)长度为\(j\)的后缀是否是一个完整的字符串”。我们可以在build_fail()
中完成这一过程。
每次匹配字符串的过程中,我们再额外维护一个整数\(x\),对于当前枚举的\(t[i]\),\(x\)的二进制表示下的第\(j\)位表示\(f[i-j]\)(即\(t[1\sim i-j]\)是否合法,规定\(f[0]=1,f[负数]=0\)),每次query()
都应重复该过程。
这样,我们要求\(f[i]\),仅需把\(tlen[i]\)和\(x\)按位与一下,如果结果非零则\(f[i]=1\),否则\(f[i]=0\)。
\(tlen\)和\(x\)都可以\(O(1)\)转移得到:
- \(tlen[i]=tlen[fail[i]]+\big(2^{len[i]}\times end(i)\big)\),其中\(end(i)\)表示节点\(i\)是否是模式串结尾,是则为\(1\),不是则为\(0\)。代码实现有一些区别,但本质是一样的。
- \(x=2\times x+f[i-1]\)。
这个优化的本质就是把不断向上跳比较的过程,转换成\(2\)个整数进行按位与操作。因此时间复杂度优化到了\(O(n|s||\Sigma|+m|t|)\),可以AC。
优化后的代码 - 100pts AC
#include<bits/stdc++.h>
#define T 30//模式串个数
#define N 410//模式串总长(节点数)
#define M 2000010//单个主串长度
#define S 26//字符集大小
using namespace std;
int n,m,tr[N][S],fail[N],cnt;
int f[M],tlen[N];
string s[T];
queue<int> q;
void ins(string s){
int p=0;
for(char i:s){
int c=i-'a';
if(!tr[p][c])
tr[p][c]=++cnt;
p=tr[p][c];
}
tlen[p]=1<<(s.size()-1);//不需要len数组了
}
void get_fail(){
for(int i=0;i<26;i++)
if(tr[0][i]) q.push(tr[0][i]);
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=0;i<26;i++){
if(tr[u][i])
fail[tr[u][i]]=tr[fail[u]][i],q.push(tr[u][i]),
tlen[tr[u][i]]|=tlen[fail[tr[u][i]]];//转移1
else tr[u][i]=tr[fail[u]][i];
}
}
}
int query(string s){
int p=0,x=0,n=s.size();
for(int i=1;i<=n;i++){
p=tr[p][s[i-1]-'a'];
x=((x<<1)|f[i-1])&((1<<20)-1);//转移2
f[i]=(x&tlen[p])!=0;
}
for(int i=n;i>=0;i--)
if(f[i]) return i;
return -1;//这一步理论走不到
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n>>m;
for(int i=1;i<=n;i++){
cin>>s[i];
ins(s[i]);
}
get_fail();
f[0]=1;//初始化别忘记
while(m--){
cin>>s[0];
cout<<query(s[0])<<"\n";
}
return 0;
}