先写静态点分治,带修改的还没学,咕咕咕
点分治是用于处理树上简单路径统计的一种算法,利用分治的思想,对每一课子树统计答案,最后累加(看起来就很暴力)
所以我们要对其进行优化,将每一棵树按重心进行分割,再逐个处理子树,整体复杂度在 \(O(nlog_n)\) 左右
求重心
需要 \(dfs\) 一遍,对每一个节点开一个变量记录子树中最大的子树的 \(size\) ,让最大的 \(size\) 最小即可
点击查看代码
void get(int x,int fa)
{
size[x]=1,wt[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
get(y,x);
size[x]+=size[y],wt[x]=max(wt[x],size[y]);
}
wt[x]=max(wt[x],siz-size[x]);
if(wt[root]>wt[x])root=x;
}
分治过程
分治的方法大概有两种,一是求整棵树对答案的贡献,再把子树中不合法的去了,类似容斥,二是一个一个子树合并来统计答案
相比而言代码量差不多,但第二个更泛用一些。
板子什么的我就随便一放,毕竟题和题的代码不是完全一样的。。。
方法一,摘自《聪聪可可》
void lsx(int x,int d,int fa)
{
arr[++cnt]=d;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
lsx(y,d+val[i],x);
}
}
int calc(int x,int d){
cnt=0; lsx(x,d,0); int l=1,r=cnt,sum=0;
sort(arr+1,arr+cnt+1);
for(;;++l){
while(r&&arr[l]+arr[r]>k) --r;
if(r<l) break;
sum+=r-l+1;
}
return sum;
}
void solve(int x){
ans+=calc(x,0); vis[x]=1;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y])continue;
ans-=calc(y,val[i]);
root=0, siz=size[y],get(y,0);
solve(root);
}
}
方法二,摘自《Race》
void lsx(int x,int d,int fa,int deep)
{
if(d>k)return ;
arr[++cnt]=d;
c[d]=min(c[d],deep);
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
lsx(y,d+val[i],x,deep+1);
}
}
void solve(int x,int fa)
{
vis[x]=1;
b[0]=0,q[0]++,a[++sum]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y]||y==fa)continue;
lsx(y,val[i],x,1);
for(int j=1;j<=cnt;j++)
{
if(q[k-arr[j]]) ans=min(ans,(long long)c[arr[j]]+b[k-arr[j]]);
}
for(int j=1;j<=cnt;j++)
{
b[arr[j]]=min(b[arr[j]],c[arr[j]]);
c[arr[j]]=0x7f7f7f;
}
for(int j=1;j<=cnt;j++) q[arr[j]]++,a[++sum]=arr[j];
cnt=0;
}
for(int i=1;i<=sum;i++) b[a[i]]=0x7f7f7f,q[a[i]]--;
sum=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y]||y==fa)continue;
root=0,siz=size[y],get(y,0);
solve(root,0);
}
}
题目
鉴于洛谷被封ip了,就不放链接了。。。
- 1:《聪聪可可》
开桶记录一下模3后为0,1,2,的边的个数,直接算即可
点击查看代码
#include<bits/stdc++.h>
const int maxn=1e5+10;
using namespace std;
int n,k,ans,root,size[maxn],siz,wt[maxn],arr[maxn],cnt;
int head[maxn],nxt[maxn<<1],to[maxn<<1],val[maxn<<1],tot;
int f[3];
bool vis[maxn];
void add(int x,int y,int z)
{
to[++tot]=y;
val[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
void addm(int x,int y,int z)
{
add(x,y,z);add(y,x,z);
}
void get(int x,int fa)
{
size[x]=1,wt[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
get(y,x);
size[x]+=size[y],wt[x]=max(wt[x],size[y]);
}
wt[x]=max(wt[x],siz-size[x]);
if(wt[root]>wt[x])root=x;
}
void lsx(int x,int d,int fa)
{
f[d%3]++;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
lsx(y,d+val[i],x);
}
}
int calc(int x,int d)
{
memset(f,0,sizeof f);
lsx(x,d,0);
return f[0]*(f[0]-1)/2+f[1]*f[2];
}
void solve(int x)
{
ans+=calc(x,0);vis[x]=1;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y])continue;
ans-=calc(y,val[i]);
root=0,siz=size[y],get(y,0);
solve(root);
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
addm(x,y,z);
}
wt[root=0]=0x7f7f7f;
siz=n;
get(1,0);
solve(root);
int a=ans*2+n,b=n*n,p=__gcd(a,b);
cout<<a/p<<"/"<<b/p<<'\n';
return 0;
}
- 2: 《Race》
记录一个每个边权是否出现,所用的最小边数,这里方法一不太适用,所用只能按子树合并,直接把已合并的子树和要合并
的子树的贡献统计即可,记得清空
点击查看代码
#include<bits/stdc++.h>
const int maxn=2e5+10;
using namespace std;
int n,k,root,size[maxn],siz,wt[maxn],arr[maxn],cnt,b[1000005],c[1000005];
int head[maxn],nxt[maxn<<1],to[maxn<<1],val[maxn<<1],tot,sum,a[1000005],q[1000005];
long long ans;
bool vis[maxn];
void add(int x,int y,int z)
{
to[++tot]=y;
val[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
void addm(int x,int y,int z)
{
add(x,y,z);add(y,x,z);
}
void get(int x,int fa)
{
size[x]=1,wt[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
get(y,x);
size[x]+=size[y],wt[x]=max(wt[x],size[y]);
}
wt[x]=max(wt[x],siz-size[x]);
if(wt[root]>wt[x])root=x;
}
void lsx(int x,int d,int fa,int deep)
{
if(d>k)return ;
arr[++cnt]=d;
c[d]=min(c[d],deep);
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
lsx(y,d+val[i],x,deep+1);
}
}
void solve(int x,int fa)
{
vis[x]=1;
b[0]=0,q[0]++,a[++sum]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y]||y==fa)continue;
lsx(y,val[i],x,1);
for(int j=1;j<=cnt;j++)
{
// cout<<arr[j]<<"! "<<q[k-arr[j]]<<endl;
if(q[k-arr[j]])
{
// cout<<ans<<"!";
ans=min(ans,(long long)c[arr[j]]+b[k-arr[j]]);
}
}
for(int j=1;j<=cnt;j++)
{
// cout<<arr[j]<<"!"<<endl;
b[arr[j]]=min(b[arr[j]],c[arr[j]]);
c[arr[j]]=0x7f7f7f;
}
for(int j=1;j<=cnt;j++)q[arr[j]]++,a[++sum]=arr[j];
cnt=0;
}
for(int i=1;i<=sum;i++) b[a[i]]=0x7f7f7f,q[a[i]]--;
sum=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y]||y==fa)continue;
root=0,siz=size[y],get(y,0);
// cout<<root<<"!"<<'\n';
solve(root,0);
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n>>k;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
x++,y++;
addm(x,y,z);
if(z==k)
{
cout<<1<<'\n';
return 0;
}
}
wt[root=0]=0x7f7f7f;
memset(b,0x7f,sizeof b);
memset(c,0x7f,sizeof c);
siz=n;
ans=1e17;
get(1,0);
// cout<<root<<"!"<<'\n';
solve(root,0);
cout<<(ans>=n?-1:ans)<<'\n';
return 0;
}
/*
4 3
0 1 1
1 2 2
2 3 4
*/
- 3:《tree》
对答案贡献的只有过根的路径,把到子树根的距离都统计,双指针统计即可
点击查看代码
#include<bits/stdc++.h>
const int maxn=4e4+10;
using namespace std;
int n,k,ans,root,size[maxn],siz,wt[maxn],arr[10001],cnt;
int head[maxn],nxt[maxn<<1],to[maxn<<1],val[maxn<<1],tot;
bool vis[maxn];
void add(int x,int y,int z)
{
to[++tot]=y;
val[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
void get(int x,int fa)
{
size[x]=1;wt[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
get(y,x);
size[x]+=size[y],wt[x]=max(wt[x],size[y]);
}
wt[x]=max(wt[x],siz-size[x]);
if(wt[root>wt[x]])root=x;
}
void dfs(int x,int d,int fa)
{
arr[++cnt]=d;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(y==fa||vis[y])continue;
dfs(y,d+val[i],x);
}
}
int calc(int x,int d){
cnt=0; dfs(x,d,0); int l=1,r=cnt,sum=0;
sort(arr+1,arr+cnt+1);
for(;;++l){
while(r&&arr[l]+arr[r]>k) --r;
if(r<l) break;
sum+=r-l+1;
}
return sum;
}
void solve(int x){
ans+=calc(x,0); vis[x]=1;
for(int i=head[x];i;i=nxt[i])
{
int y=to[i];
if(vis[y])continue;
ans-=calc(y,val[i]);
root=0, siz=size[y],get(y,0);
solve(root);
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
add(x,y,z);add(y,x,z);
}
cin>>k;
wt[root=0]=0x7f7f7f;
siz=n;
get(1,0);
solve(root);
cout<<ans-n;
return 0;
}
/*
7
1 6 13
6 3 9
3 5 7
4 1 3
2 4 20
4 7 2
10
*/