咕了一年的题。
先点分治。考虑某条经过当前重心 \(rt\) 的合法回文路径:(图摘自 yww 的题解)
其中 \(x\to y\) 是合法回文路径,那么 \(T\) 是一个回文串。
先把 \(rt\) 到每个点的串的 AC 自动机建出来,然后在 fail 树上 dfs,并把当前 dfs 到的点当作 \(x\),那么 \(y\) 必然是 \(x\) 在 fail 树上的祖先。于是暴力的想法是跳 fail 并判断 \(T\) 是否为回文串(判断回文串可以通过判断正串和反串的 Hash 值是否相同),若是则统计答案。
更优的做法需要用到回文前缀的性质。
-
引理 1:对于一个长度为 \(n\) 的串 \(S\) 的两个循环节 \(p,q\)(\(p>q\)),若 \(p+q\leq n\),则 \(\gcd(p,q)\) 也为串 \(S\) 的循环节。
证明:记 \(d=p-q\),考虑证明 \(S_i=S_{i+d}\),而由于 \(p+q\leq n\),所以 “减 \(q\) 再加 \(p\)” 和 “加 \(p\) 再减 \(q\)” 两种证明方法中肯定有一种是合法的(不会跳出范围)。
-
引理 2:一个长度为 \(n\) 的串 \(S\) 的 border 集合 \(b(S)\) 可以被表示成至多 \(\lceil\log_2n\rceil\) 个等差数列。特别地,为了方便,我们把 \(n\) 也算作一个 border。
证明:归纳证明。border 与循环节一一对应。考虑 \(S\) 的最短循环节 \(p>0\),即 \(S[p+1,n]\) 为最长非原串 border。
-
若 \(p>n/2\)。由于 \(b(S)=\{n\}\cup b(S[p+1,n])\),而 \(b(S[p+1,n])\) 可以被表示为 \(\lceil\log_2(n-p)\rceil\leq \lceil\log_2n\rceil-1\) 个等差数列,于是 \(b(S)\) 可以被表示为 \(\lceil\log_2n\rceil\) 个等差数列。
-
若 \(p\leq n/2\)。根据引理 1,对于任意 \(S\) 的循环节 \(q\leq n-p\),都有 \(\gcd(p,q)=p\),那么 \(q\) 必然为 \(p\) 的倍数(否则与 \(p\) 是最短循环节矛盾);另一方面,对于任意 \(0\leq tp\leq n-p\),容易发现 \(tp\) 都是 \(S\) 的循环节。综上,对于任意 \(S\) 的循环节 \(0\leq q\leq n-p\),它们构成一个等差数列。
记 \(>n-p\) 的第一个循环节为 \(r\),那么 \(b(S)=\{n-q:0\leq q\leq n-p\text{且}q\text{为}S\text{的循环节}\}\cup b(S[r+1,n])\)。而 \(b(S[r+1,n])\) 可以被表示为 \(\lceil\log_2(n-r)\rceil\leq \lceil\log_2n\rceil-1\) 个等差数列,于是 \(b(S)\) 可以被表示为 \(\lceil\log_2n\rceil\) 个等差数列。
-
-
推论 3:\(b(S)\) 被表示成 \(\lceil\log_2 n\rceil\) 个等差数列,且第 \(i\) 个等差数列的公差不超过 \(n/2^i\)。
证明:由引理 2 的证明容易看出。
-
引理 4:一个长度为 \(n\) 的串 \(S\) 的所有回文前缀长度构成至多 \(\lceil\log_2n\rceil\) 个等差数列。
证明:考虑证明一个回文串 \(S\) 的所有回文前缀对应 \(S\) 的所有 border 即可。
一方面,由于 \(S\) 是回文串,所以其长度为 \(i\) 的前缀和长度为 \(i\) 的后缀互为反串,那么这个前缀和后缀相等就等价于这个串是回文串。
回到上面的算法,假设我们走到了 \(x\),我们要询问的是 \(x\) 在 fail 树上的祖先 \(y\) 中,那些满足 \(x\) 的长度为 \(\operatorname{len}(x)-\operatorname{len}(y)\) 的前缀是回文前缀的 \(y\) 的总贡献。
由于 \(x\) 的回文前缀长度是 \(O(\log n)\) 个等差数列,这也对应于 \(O(\log n)\) 次对于 \(y\) 的等差数列的查询。
考虑对公差根号分治。对于大的公差,我们直接暴力查,结合推论 3,这部分的时间复杂度为 \(O(\sum_{i\geq 0}\frac{n}{2^i\sqrt n})=O(\sqrt n)\);对于小的公差 \(d\),我们相当于查询 \(\operatorname{len}(y)\) 模 \(d\) 等于某数的那些 \(y\) 中,\(\operatorname{len}(y)\) 在一段区间 \([l,r]\) 内的 \(y\) 的贡献和。这看起来需要维护树状数组之类的多带个 \(\log\),但实际上我们可以通过差分变为问 \(\operatorname{len}(y)\leq lim\) 的,于是找到对应的祖先并把询问挂在上面即可。于是处理所有小的公差总共也是 \(O(n\sqrt n)\) 的。
至于如何找 \(x\) 的那 \(O(\log n)\) 个等差数列,可以在 Trie 上边 dfs 边更新回文前缀。
单次对于一棵大小为 \(n\) 的子树,时间复杂度是 \(O(n\sqrt n)\)。总时间复杂度 \(T(n)=2T(n/2)+O(n\sqrt n)=O(n\sqrt n)\)。
#include<bits/stdc++.h>
#define N 50010
#define ll long long
using namespace std;
namespace modular
{
const int bias=1789,base=23333,mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
inline void Mul(int &x,int y){x=1ll*x*y%mod;}
inline int poww(int a,int b){int ans=1;for(;b;Mul(a,a),b>>=1)if(b&1)Mul(ans,a);return ans;}
}using namespace modular;
int pw[N];
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
const int B=5;
int n;
int cnt,head[N],nxt[N<<1],to[N<<1],c[N<<1];
int nn,rt,maxn,size[N];
bool vis[N];
void adde(int u,int v,int ci)
{
to[++cnt]=v;
c[cnt]=ci;
nxt[cnt]=head[u];
head[u]=cnt;
}
void getsize(int u,int fa)
{
size[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==fa) continue;
getsize(v,u),size[u]+=size[v];
}
}
void findroot(int u,int fa)
{
int nmax=nn-size[u];
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==fa) continue;
findroot(v,u),nmax=max(nmax,size[v]);
}
if(nmax<maxn) maxn=nmax,rt=u;
}
namespace AC
{
int node,ch[N][2],fail[N],val[N],len[N];
vector<int> e[N];
ll ans;
struct data{int l,r,d;};
vector<data> tl[N];
void dfs1(int u,int h1,int h2)
{
if(u&&h1==h2)
{
if(!tl[u].empty()&&(tl[u].back().d==-1||len[u]-tl[u].back().r==tl[u].back().d))
tl[u].back().d=len[u]-tl[u].back().r,tl[u].back().r=len[u];
else tl[u].push_back({len[u],len[u],-1});
}
ans+=1ll*val[u]*(val[u]-1)/2;
for(int i:{0,1})
{
int v=ch[u][i]; if(!v) continue;
tl[v]=tl[u],len[v]=len[u]+1,dfs1(v,add(mul(h1,base),add(i,bias)),add(h2,mul(add(i,bias),pw[len[u]])));
}
}
void build()
{
queue<int> q;
for(int i:{0,1}) if(ch[0][i]) q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front();
q.pop();
for(int i:{0,1})
{
if(ch[u][i])
{
int v=fail[u];
while(v&&!ch[v][i]) v=fail[v];
fail[ch[u][i]]=ch[v][i];
q.push(ch[u][i]);
}
}
}
for(int i=1;i<=node;i++) e[fail[i]].push_back(i);
}
vector<tuple<int,int,int>> qq[N][2];
void dfs2(int u)
{
static int top,sta[N],mp[N];
auto find=[&](int x)
{
int l=1,r=top,ans=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(len[sta[mid]]<=x) ans=mid,l=mid+1;
else r=mid-1;
}
return sta[ans];
};
sta[++top]=u,mp[len[u]]+=val[u];
for(auto q:tl[u])
{
int l=len[u]-q.r,r=len[u]-q.l,d=q.d;
if(d==-1){ans+=1ll*val[u]*mp[l];continue;}
if(d>B) for(int i=l;i<=r;i+=d) ans+=1ll*val[u]*mp[i];
else
{
qq[find(r)][0].emplace_back(d,r%d,val[u]);
if(l) qq[find(l-1)][1].emplace_back(d,l%d,val[u]);
}
}
for(int v:e[u]) dfs2(v);
top--,mp[len[u]]-=val[u];
}
void dfs3(int u)
{
static int c[B+1][B+1];
for(int d=1;d<=B;d++) c[d][len[u]%d]+=val[u];
for(auto q:qq[u][0]) ans+=1ll*get<2>(q)*c[get<0>(q)][get<1>(q)];
for(auto q:qq[u][1]) ans-=1ll*get<2>(q)*c[get<0>(q)][get<1>(q)];
for(int v:e[u]) dfs3(v);
for(int d=1;d<=B;d++) c[d][len[u]%d]-=val[u];
}
ll work()
{
ans=0;
dfs1(0,0,0);
build();
dfs2(0),dfs3(0);
for(int i=0;i<=node;i++)
{
ch[i][0]=ch[i][1]=fail[i]=val[i]=len[i]=0;
e[i].clear(),tl[i].clear(),qq[i][0].clear(),qq[i][1].clear();
}
node=0;
return ans;
}
}
void dfs(int u,int fa,int tu)
{
AC::val[tu]++;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==fa) continue;
if(!AC::ch[tu][c[i]]) AC::ch[tu][c[i]]=++AC::node;
dfs(v,u,AC::ch[tu][c[i]]);
}
}
ll calc(int u,int tu)
{
dfs(u,0,tu);
return AC::work();
}
ll ans;
void solve(int u)
{
vis[u]=1;
ans+=calc(u,0);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]) continue;
AC::ch[0][c[i]]=++AC::node;
ans-=calc(v,AC::ch[0][c[i]]);
getsize(v,u);
nn=size[v],maxn=INT_MAX,findroot(v,u);
solve(rt);
}
}
int main()
{
n=read();
pw[0]=1;
for(int i=1;i<=n;i++) pw[i]=mul(pw[i-1],base);
for(int i=1;i<n;i++)
{
int u=read(),v=read(),c=read();
adde(u,v,c),adde(v,u,c);
}
getsize(1,0);
nn=size[1],maxn=INT_MAX,findroot(1,0);
solve(rt);
printf("%lld\n",ans);
return 0;
}
标签:yww,ch,log,int,len,leq,LOJ6681,XSY3320,回文
From: https://www.cnblogs.com/ez-lcw/p/16912205.html