这是我的第一道黑题!
言归正传,题意是,给定一棵 \(n\) 个节点的树,现有有一张完全图,两点 \(x\),\(y\) 之间的边长为 \(w_x+w_y+dis_{x,y}\),其中 \(dis_{x,y}\) 表示 \(x\) 和 \(y\) 在树上的距离,求完全图的最小生成树。
常规的求最小生成树的算法有 \(kruskal\)、\(prim\)。但是这里这张图是完全图,两个算法都会超时的。
所以一个为这个问题量身定做的算法出现了!它就是 \(boruvka\)!
算法流程
每轮为当前每个连通块找到与其最近的连通块,并连边,直到只有一个连通块。
正确性
最后的最小生成树上的每个点,显然都会保留它连出的最短的边。
否则断掉现在它连出的一条边,再连最短的边一定更优。
那么每轮过后,把一个连通块缩成一个点,按照上面的结论一直做下去就是对的。
轮数
每次连通块个数至少除以二,因为最坏情况下的连边就是1-2,3-4,5-6,... \(n-1\) -\(n\)。
所以最多有 \(\log n\) 轮。
找最短边
做法挺巧妙的。
我们先算出每个点和不同连通块的点的最短边长,以及最近点在哪个块,然后合并到这个点的连通块上去。
设 \(sum_i\) 表示点 \(i\) 到根节点的树上距离,\(a_i\) 表示点 \(i\) 的权值。
考虑树形 \(dp\),先 \(dfs\) 一遍找出子树内最近的点(其实就是子树里 \(sum_v+a_v\) 最小的点,因为 \(u\) 到 \(v\) 的边长是 \(a_u+a_v+sum_v-sum_u\))。
但是!这个点可能是在同一块里的,怎么办呢?
我们可以不只维护最近的点,另外再维护一个和当前最近点不是同一块的次近点。
这样的话,如果最近点在同一块里,次近点就是答案了。
第二遍 \(dfs\),我们求出子树外最近的点,子树外的点到 \(u\) 的边长是 \(a_u+a_v+sum_v+sum_u-2\times sum_{lca}\)。当 \(lca\) 固定的时候,这个 \(v\) 一定是上一次 \(dfs\) 里算出的 \(lca\) 的子树中, \(sum_v+a_v\) 最小的点、和最小点不在一个块的次小点。那么维护从根节点到 \(u\),\(a_v+sum_v-2\times sum_{lca}\) 最小的点,和不在同一块的次小点就好了。
代码:
#include<bits/stdc++.h>
#define int long long
#define mkp make_pair
#define fi first
#define se second
using namespace std;
const int N=2e5+10;
int n,cnt,res,a[N],f[N],s[N];
int idx,hd[N],to[N<<1],nxt[N<<1],len[N<<1];
pair<int,int>p,s1[N],s2[N],ans[N];
int find(int x)
{
if(f[x]==x)return x;
return f[x]=find(f[x]);
}
void add(int u,int v,int w)
{
++idx,to[idx]=v,nxt[idx]=hd[u],len[idx]=w,hd[u]=idx;
return;
}
pair<int,int>cmx(pair<int,int>a,pair<int,int>b,pair<int,int>c,int f)
{
p=mkp(1e18,-1);
if(a.se!=f)p=min(p,a);
if(b.se!=f)p=min(p,b);
if(c.se!=f)p=min(p,c);
return p;
}
void dfs1(int u,int fa)
{
s1[u]=mkp(s[u]+a[u],f[u]);
s2[u]=mkp(1e18,-1);
for(int i=hd[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)continue;
s[v]=s[u]+len[i];
dfs1(v,u);
if(s1[u]>s1[v])s2[u]=cmx(s1[u],s2[u],s2[v],s1[v].se),s1[u]=s1[v];
else s2[u]=cmx(s2[u],s1[v],s2[v],s1[u].se);
}
return;
}
void dfs2(int u,int fa)
{
s1[u].fi-=2*s[u],s2[u].fi-=2*s[u];
if(fa)if(s1[u]>s1[fa])s2[u]=cmx(s2[fa],s1[u],s2[u],s1[fa].se),s1[u]=s1[fa];
else s2[u]=cmx(s2[u],s1[fa],s2[fa],s1[u].se);
if(s1[u].se!=f[u])ans[f[u]]=min(ans[f[u]],mkp(s1[u].fi+s[u]+a[u],s1[u].se));
else if(s2[u].se!=f[u])ans[f[u]]=min(ans[f[u]],mkp(s2[u].fi+s[u]+a[u],s2[u].se));
for(int i=hd[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)continue;
dfs2(v,u);
}
return;
}
signed main()
{
scanf("%lld",&n);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1,u,v,w;i<n;i++)scanf("%lld%lld%lld",&u,&v,&w),add(u,v,w),add(v,u,w);
for(int i=1;i<=n;i++)f[i]=i;
cnt=n;
while(cnt>1)
{
dfs1(1,0);
for(int i=1;i<=n;i++)ans[i]=mkp(1e18,-1);
dfs2(1,0);
for(int i=1;i<=n;i++)
{
if(f[i]==i&&ans[i].se>0&&find(i)!=find(ans[i].se))f[i]=ans[i].se,res+=ans[i].fi;
}
cnt=0;
for(int i=1;i<=n;i++)cnt+=(find(i)==i);
}
printf("%lld",res);
return 0;
}
标签:Atcoder,int,s2,s1,MST,Tree,fa,sum,se
From: https://www.cnblogs.com/DLYdly1105/p/18629216