考一遍,学一遍,忘一遍
这里是重链剖分。
两个dfs,第一个找重儿子,第二个找重链顶和dfn(注意要优先对重儿子dfs来保证同一条重链上的dfs序连续)
查询和维护时一个一个跳重链顶端,时间复杂度O(nlogn)。常和线段树配套使用。
模板
#include<bits/stdc++.h>
#define ll long long
#define lid (id<<1)
#define rid (id<<1|1)
using namespace std;
const int maxn=100005;
int tot,h[maxn<<1];
struct edge{
int to,nxt;
}e[maxn<<1];
void addedge(int u,int v){
e[++tot].to=v;
e[tot].nxt=h[u];
h[u]=tot;
}
int n,m,num,rt,size[maxn],f[maxn],dep[maxn],son[maxn],pre[maxn],dfn[maxn],top[maxn];
ll mod,a[maxn];
struct node{
int l,r; ll lazy,val;
}t[maxn<<2];
void dfs1(int x,int fa){
size[x]=1;
for(int i=h[x];i;i=e[i].nxt){
int y=e[i].to;
if(y==fa) continue;
dep[y]=dep[x]+1;
f[y]=x;
dfs1(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]]) son[x]=y;
}
}
void dfs2(int x,int tp){
dfn[x]=++num,top[x]=tp;pre[num]=x;
if(son[x]) dfs2(son[x],tp);
for(int i=h[x];i;i=e[i].nxt){
int y=e[i].to;
if(y==f[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void build(int id,int l,int r){
t[id].l=l,t[id].r=r;
if(l==r){
t[id].val=a[pre[l]]%mod;
return ;
}
int mid=(l+r)>>1;
build(lid,l,mid);
build(rid,mid+1,r);
t[id].val=(t[lid].val+t[rid].val)%mod;
}
void pushdown(int id){
if(t[id].lazy&&t[id].l!=t[id].r){
t[lid].lazy=(t[lid].lazy+t[id].lazy)%mod;
t[rid].lazy=(t[rid].lazy+t[id].lazy)%mod;
t[lid].val=(t[lid].val+t[id].lazy*(t[lid].r-t[lid].l+1)%mod)%mod;
t[rid].val=(t[rid].val+t[id].lazy*(t[rid].r-t[rid].l+1)%mod)%mod;
t[id].lazy=0;
}
}
void add(int id,int l,int r,ll val){
if(t[id].l==l&&t[id].r==r){
t[id].lazy=(t[id].lazy+val)%mod;
t[id].val=(t[id].val+val*(t[id].r-t[id].l+1)%mod)%mod;
return;
}
pushdown(id);
int mid=(t[id].l+t[id].r)>>1;
if(r<=mid) add(lid,l,r,val);
else if(l>mid) add(rid,l,r,val);
else add(lid,l,mid,val),add(rid,mid+1,r,val);
t[id].val=(t[lid].val+t[rid].val)%mod;
}
ll query(int id,int l,int r){
if(t[id].l==l&&t[id].r==r){
return t[id].val;
}
pushdown(id);
int mid=(t[id].l+t[id].r)>>1;
if(r<=mid) return query(lid,l,r);
else if(l>mid) return query(rid,l,r);
else return (query(lid,l,mid)+query(rid,mid+1,r))%mod;
}
void qadd(int x,int y,ll k){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);//比的是top,否则可能跳多了……
add(1,dfn[top[x]],dfn[x],k);
x=f[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
add(1,dfn[y],dfn[x],k);
}
ll qsum(int x,int y){
ll ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+query(1,dfn[top[x]],dfn[x]))%mod;
x=f[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
ans=(ans+query(1,dfn[y],dfn[x]))%mod;
return ans;
}
int main(){
scanf("%d%d%d%lld",&n,&m,&rt,&mod);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
dfs1(rt,0);
dfs2(rt,rt);
build(1,1,n);
for(int i=1;i<=m;i++){
int tmp,x,y; ll z;
scanf("%d%d",&tmp,&x);
if(tmp==1){
scanf("%d%lld",&y,&z);
qadd(x,y,z%mod);
}
else if(tmp==2){
scanf("%d",&y);
printf("%lld\n",qsum(x,y));
}
else if(tmp==3){
scanf("%lld",&z);
add(1,dfn[x],dfn[x]+size[x]-1,z%mod);
// cout<<dfn[x]<<" "<<dfn[x]+size[x]-1<<endl;
}
else{
printf("%lld\n",query(1,dfn[x],dfn[x]+size[x]-1));
//cout<<dfn[x]<<" "<<dfn[x]+size[x]-1<<endl;
}
//cout<<i<<" "<<query(1,5,5)<<endl;
}
return 0;
}
性质:树上的每个节点都属于且仅属于一条重链;重链内的dfs序是连续的;一棵子树内的dfs序是连续的;当我们向下经过一条轻边时,所在子树大小至少除以2。
查询时间复杂度证明:重链的两边必是轻边,而每向下经过一条轻边,所在子树大小至少除以二,所以是log级别的。
另外,树剖可以求lca且一般跑不满,会比倍增lca快很多。(需要注意的是,最后top相等的时候,深度小的为lca。)当然欧拉序lca和黑科技dfs序lca的查询确实很快。
最后一点性质的应用:树上启发式合并。
步骤:1.遍历节点u的轻儿子并计算答案,但不保留遍历后它对cnt数组的影响
2.遍历它的重儿子,保留它对cnt数组的影响
3.再次遍历u的轻儿子的子树节点,加入这些节点的贡献,得到u的答案
时间复杂度的证明,根节点到树上任意节点的轻边数不超过logn条,而一个节点被遍历的次数等于它到根节点路径上的轻边数+1,总时间复杂度O(nlogn)
例题:CF741D
码
//正解树上启发式合并+重链剖分:
//计算轻儿子答案不保留对cnt贡献-->计算重儿子答案并保留贡献-->再次遍历轻儿子子树节点
#include<bits/stdc++.h>
using namespace std;
const int maxn=2e6+5;
int n,h[maxn],cnt,to[maxn],nxt[maxn],w[maxn],siz[maxn],son[maxn];
int id[maxn],l[maxn],r[maxn],num,dis[maxn],ans[maxn],t[1<<22],dep[maxn];
void add(int u,int v,int val){
nxt[++cnt]=h[u];
to[cnt]=v;
h[u]=cnt;
w[cnt]=(1<<val);
}
void pre(int x,int fa){
siz[x]=1,id[++num]=x,l[x]=num;
dep[x]=dep[fa]+1;
for(int i=h[x];i;i=nxt[i]){
int y=to[i];
dis[y]=(dis[x]^w[i]);
pre(y,x);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]]) son[x]=y;
} r[x]=num;
}
void dfs(int x,int tg){
for(int i=h[x];i;i=nxt[i]){
int y=to[i];
if(y==son[x]) continue;
dfs(y,0),ans[x]=max(ans[x],ans[y]);
}
if(son[x]) dfs(son[x],1),ans[x]=max(ans[x],ans[son[x]]);
if(t[dis[x]]) ans[x]=max(ans[x],t[dis[x]]-dep[x]);//特判重儿子,否则难以计算到全在这里的情况
for(int i=0;i<=21;i++) ans[x]=max(ans[x],t[dis[x]^(1<<i)]-dep[x]);
t[dis[x]]=max(t[dis[x]],dep[x]); //更新!否则都不能转移,全是0
for(int i=h[x];i;i=nxt[i]){
int y=to[i];
if(y==son[x]) continue;
for(int j=l[y];j<=r[y];j++){
int z=id[j];
if(t[dis[z]]) ans[x]=max(ans[x],t[dis[z]]+dep[z]-2*dep[x]);
for(int k=0;k<=21;k++){
if(t[dis[z]^(1<<k)]) ans[x]=max(ans[x],t[dis[z]^(1<<k)]+dep[z]-2*dep[x]);
}
}
for(int j=l[y];j<=r[y];j++){
int z=id[j];
t[dis[z]]=max(t[dis[z]],dep[z]);
}
}
if(!tg) for(int i=l[x];i<=r[x];i++) t[dis[id[i]]]=0;
}
int main(){
// freopen("T3.in","r",stdin);
// freopen("ans.out","w",stdout);
scanf("%d",&n);
for(int i=2;i<=n;i++){
int x;char val;
scanf("%d",&x);
scanf(" %c",&val);
add(x,i,(int)(val-'a'));
}
pre(1,1);
dfs(1,1);
for(int i=1;i<=n;i++){
printf("%d ",ans[i]);
}
return 0;
}