点分治及其应用
算法:点分治,树的重心。
思想
先说一下点分治的基本思想:选择树上一个点作为分治中心,为了保证复杂度,选择的点有一些特殊的要求。
接下来,把原问题分解成几个相同的子问题,进行递归解决。
一般地,我们假设当前根节点为 \(rt\),所以我们要统计的路径必然满足以下二者之一:
-
经过 \(rt\)。
-
不经过 \(rt\),就是在 \(rt\) 的子树上。
树的重心
上面说到,为了保证时间复杂度,我们选择的点有一些特殊要求。一般地,我们选择的点为树的重心。因为这样剩下的子树的最大大小不超过整棵树大小的一半,所以这样递归层数为 \(O(log n)\) 级别的。
为了帮助读者更好的理解上文,这里放一下查找树的重心的代码:
void get_root(int u,int f){
siz[u]=1;
mx[u]=0;
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==f||st[j])continue;
get_root(j,u);
siz[u]+=siz[j];
mx[u]=max(mx[u],siz[j]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt])rt=u;//根据重心的定义找重心
}
点分治1
这道例题比较简单,但是实测如果每次都跑一次点分治的话时间不能接受,所以我们考虑把询问离线下来,然后做一次点分治处理所有操作。
下面说一下这道题的做法:
考虑路径是否经过根 \(rt\),不经过的我们递归处理,这里考虑经过的怎么做。
比较显然地,如果在一个子树中距离 \(rt\) 长为 \(l\) 的链,并且距离 \(rt\) 长为 \(k-l\) 的链在其他子树中出现,那么长为 \(k\) 的链一定存在。
并且,两者的出现顺序对于我们的答案没有影响,于是我们便可以一棵一棵子树维护,具体分为下面几步:
-
\(dfs\) 处理当前子树上每个点与根 \(rt\) 的距离。
-
在 \(dfs\) 时,记录哪些长度的链出现过。
-
与之前已经统计过的子树的数据相结合,并且进行更新。
-
清空当前子树的边的信息,这里为了保证时间复杂度,不能使用 \(memset\)。
-
最后递归进入子树进行点分治。
有一些小细节和注释,放在了下面的代码中:
#include<bits/stdc++.h>
#define int long long
#define N 100005
#define M 200005
#define K 10000005
#define inf 2e18
using namespace std;
int n,m,h[N],e[M],w[M],ne[M],idx;
int qs[N],res[N],rt,sum,siz[N],mx[N];
int st[N],tf[K],dis[N],d[N],dcnt,q[N],hh,tt;
/*
第一行基础变量和链式前向星存图
qs 为每个询问,res 为询问结果
sum 子树总大小,rt 当前树根
mx 最大子树大小,st 当前点是否被处理过
tf 是否有某一长度的链,dis 距离根的距离
d 当前子树的链的长度,dcnt 当前子树到根的链的个数
q,hh,tt 队列
*/
void add(int a,int b,int c){
e[idx]=b;w[idx]=c;ne[idx]=h[a];h[a]=idx++;
}
void get_root(int u,int f){
siz[u]=1;
mx[u]=0;
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==f||st[j])continue;
get_root(j,u);
siz[u]+=siz[j];
mx[u]=max(mx[u],siz[j]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt])rt=u;//根据重心的定义找重心
}
void get_dist(int u,int f){
d[++dcnt]=dis[u];//当前的距离存入可能的长度
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==f||st[j])continue;
dis[j]=dis[u]+w[i];
get_dist(j,u);
}
}
void solve(int u,int f){
hh=0;tt=-1;
q[++tt]=0;
tf[0]=1;st[u]=1;//当前点已经处理过,且长度为0的链一定存在(即不选)
for(int i=h[u];~i;i=ne[i]){
int v=e[i];
if(v==f||st[v])continue;
dis[v]=w[i];
get_dist(v,u);
for(int k=1;k<=dcnt;k++){//链长度
for(int j=1;j<=m;j++){//询问
if(qs[j]>=d[k]){
res[j]|=tf[qs[j]-d[k]];//如果d[k]和qs[j]-d[k]都出现,就存在
}
}
}
for(int j=1;j<=dcnt;j++){
if(d[j]<10000005){//如果链长度有意义(可能被询问)
q[++tt]=d[j];//就加到队列里
tf[d[j]]=1;//这个长度能被凑出来
}
}
dcnt=0;
}
while(hh<=tt)tf[q[hh++]]=0;//清空之前子树的信息
for(int i=h[u];~i;i=ne[i]){
int j=e[i];//继续找其他子树计算
if(j==f||st[j])continue;
sum=siz[j];rt=0;
mx[rt]=inf;
get_root(j,u);
get_root(rt,-1);
solve(rt,u);
}
}
signed main(){
cin>>n>>m;
memset(h,-1,sizeof h);
for(int i=1;i<n;i++){
int a,b,c;
cin>>a>>b>>c;
add(a,b,c);add(b,a,c);
}
for(int i=1;i<=m;i++){
cin>>qs[i];
}
rt=0;mx[rt]=inf;sum=n;
get_root(1,-1);
get_root(rt,-1);
solve(rt,-1);
for(int i=1;i<=m;i++){
if(res[i])cout<<"AYE\n";
else cout<<"NAY\n";
}
return 0;
}
点分治上面的三个函数比较模板化,最多只需要微调,建议理解性背诵,而 \(solve\) 函数需要根据题目来推,相对不需要过多记忆。