这里是模板题 P6139。
进行了一个广义 SAM 的学习。离线部分 OiWiki 讲得很好,但是在线部分没有。我做一些补充。
首先 SAM 是什么?简单来说,
-
SAM 是一个有向无环图,节点是状态,对应一个 endpos(\(S\) 子串 \(t\) 的结尾位置集合)。边标有字符。
-
SAM 是有 \(t_0\) 初始状态,若干个结尾状态,是 \(\mathcal{O}(n)\) 的,\(\le 4\) 常数。
-
本质不同字串个数=路径条数,可以 dp。
-
一个状态中,定义 \(len(v)\) 为 \(t_0\) 到 \(v\) 的最长路径长度。
-
一个状态中,定义 \(fa(v)\) 为 \(v\) 最长的后缀,使出现位置多余 \(v\)。
-
endpos 要么无交要么包含,可以构成后缀树。同一个 endpos 对应的字符串长度是一个区间,\([len(fa(v))+1,len(v)]\)。
我们当作你会 SAM 了(你去 OiWiki 看一下模板代码),广义 SAM 是什么呢?
首先看一下代码。
pos[0]=1;
for (int j=1; j<s.size(); j++){
pos[j]=ins(s[j]-'a',pos[j-1]);
}
这个是主函数里的。因为是多个串,\(lst\) 是什么你只能是上一个的结尾,不能直接 ++tot
求得。
int ins(int c,int lst){
// (1)
if (t[lst].ch[c] && t[t[lst].ch[c]].len==t[lst].len+1){
return t[lst].ch[c];
}
//
int p=lst,np=++tot,fl=0;
t[np].len=t[p].len+1;
while (p && !t[p].ch[c]){
t[p].ch[c]=np;
p=t[p].fa;
}
if (!p){
t[np].fa=1;
return np;
}
int q=t[p].ch[c];
if (t[q].len==t[p].len+1){
t[np].fa=q;
return np;
}
// (2)
if (p==lst){
fl=1;
np=0;
tot--;
}
//
int nq=++tot;
t[nq]=t[q];
t[nq].len=t[p].len+1;
t[q].fa=t[np].fa=nq;
while (p && t[p].ch[c]==q){
t[p].ch[c]=nq;
p=t[p].fa;
}
// (3)
return fl?nq:np;
//
}
这个是拓展的部分。发现和普通的 SAM 相比,有 \(3\) 个不同的部分。为什么呢?
(1)处是因为如果有一个连续的转移,没必要新建了。
(2)处是因为:这个是有 \(c\) 的儿子,但是不是连续的,那么,我们只需要建新状态,不需要建 \(lst\rightarrow c\)。这个清空 \(np\) 然后 \(tot--\) 就可以了。
(3)处是因为:很容易理解,你删了 \(np\) 就得返回 \(nq\)。
然后就做完了。
Code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 2e6+6;
struct sam {
int fa,len,ch[26];
} t[N];
int tot=1,pos[N];
int ins(int c,int lst){
if (t[lst].ch[c] && t[t[lst].ch[c]].len==t[lst].len+1){
return t[lst].ch[c];
}
int p=lst,np=++tot,fl=0;
t[np].len=t[p].len+1;
while (p && !t[p].ch[c]){
t[p].ch[c]=np;
p=t[p].fa;
}
if (!p){
t[np].fa=1;
return np;
}
int q=t[p].ch[c];
if (t[q].len==t[p].len+1){
t[np].fa=q;
return np;
}
if (p==lst){
fl=1;
np=0;
tot--;
}
int nq=++tot;
t[nq]=t[q];
t[nq].len=t[p].len+1;
t[q].fa=t[np].fa=nq;
while (p && t[p].ch[c]==q){
t[p].ch[c]=nq;
p=t[p].fa;
}
return fl?nq:np;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int n;
cin>>n;
for (int i=1; i<=n; i++){
string s;
cin>>s;
s=" "+s;
pos[0]=1;
for (int j=1; j<s.size(); j++){
pos[j]=ins(s[j]-'a',pos[j-1]);
}
}
ll ans=0;
for (int i=2; i<=tot; i++){
ans+=t[i].len-t[t[i].fa].len;
}
cout<<ans<<"\n"<<tot<<"\n";
return 0;
}