题意:给定一颗 \(n\) 个点的树,点 \(i\) 有权值 \(a_{i}\),边有边权。现在有另外一个完全图,两点之间的边权为树上两点之间的距离加上树上两点的点权,求这张完全图的最小生成树。
首先有一个很显然的暴力,把完全图中每两点之间的边权算出来,然后跑一边最小生成树,时间复杂度 \(O(n^{2} \log (n^{2}))\)。
考虑如何优化。发现有很多路径是不必要的,因为它们一定劣于其它路径,这些路径我们就不用加到完全图中去了。那么可以用点分治来筛选路径。
假设当前重心为 \(u\),我们可以把路径分为两种:
-
一个端点是 \(u\) 的路径。
-
经过 \(u\) 但是端点不在 \(u\) 的路径。
对于第一种路径,我们可以直接将它加入边集,因为总边数不超过 \(O(n \log n)\) 条。对于第二种,考虑如何选出最优的。假设两个点为 \(x\) 和 \(y\),那么可以把边权分为两个部分:\(x \to u\),\(u \to y\),即 \((a_{x}+dis(u,x))+(a_{y}+dis(u,y))\)。发现这个式子的前一半和后一半的形式是一样的,所以要让边权最小,只需要选一个 \((a_{x}+dis(u,x))\) 最小的 \(x\) 点,再连向其它所有的 \(y\) 点即可。
总时间复杂度 \(O(n \log^{2} n)\)。
本题的 Trick:求最小生成树遇到边数很多时,可以先把边权小的边拿出来,删除一些没用的边,然后再做最小生成树。
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define inf 0x3f
#define inf_db 127
#define ls id << 1
#define rs id << 1 | 1
#define re register
#define endl '\n'
typedef pair <int,int> pii;
const int MAXN = 8e6 + 10;
int n,a[MAXN],head[MAXN],x,y,z,fa[MAXN],mn,id;
int pt,tot = 0,root,cnt,ans = 0;
bool vis[MAXN];
struct Node{int u,v,w,nxt;}e[MAXN << 1];
struct Edge{int u,v,w;}E[MAXN << 1];
struct F{int u,dis,idx;}p[MAXN << 1];
inline void Add(int u,int v,int w){e[++cnt] = {u,v,w,head[u]};head[u] = cnt;}
inline int Get_size(int u,int father)
{
if(vis[u] == true) return 0;
int sum = 1;
for(int i = head[u]; ~ i;i = e[i].nxt)
if(e[i].v != father) sum += Get_size(e[i].v,u);
return sum;
}
inline int Get_wc(int u,int father,int tot,int &wc)
{
if(vis[u] == true) return 0;
int sum = 1,mx = 0;
for(int i = head[u]; ~ i;i = e[i].nxt)
{
int now = e[i].v;
if(now == father) continue;
int tmp = Get_wc(now,u,tot,wc);
sum += tmp,mx = max(mx,tmp);
}
if(max(mx,tot - sum) <= tot / 2) wc = u;
return sum;
}
inline void dfs(int u,int father,int dist,int r)
{
if(vis[u] == true) return;
int val = (a[u] + dist);
p[++pt] = F{u,dist,r};
if(val < mn) mn = val,id = u,root = r;
for(int i = head[u]; ~ i;i = e[i].nxt)
{
int now = e[i].v;
if(now == father) continue;
dfs(now,u,dist + e[i].w,r);
}
}
inline void solve(int u)
{
if(vis[u] == true) return;
Get_wc(u,0,Get_size(u,0),u),vis[u] = true;
pt = 0,mn = 1e18,root = 0;
for(int i = head[u]; ~ i;i = e[i].nxt)
{
int now = e[i].v;
dfs(now,u,e[i].w,now);
}
for(int i = 1;i <= pt;i++)
if(p[i].idx != root) E[++tot] = {id,p[i].u,mn + p[i].dis + a[p[i].u]};
for(int i = 1;i <= pt;i++) E[++tot] = {u,p[i].u,p[i].dis + a[u] + a[p[i].u]};
for(int i = head[u]; ~ i;i = e[i].nxt) solve(e[i].v);
}
inline bool cmp(Edge x,Edge y){return x.w < y.w;}
inline int Find(int x)
{
if(x == fa[x]) return x;
return fa[x] = Find(fa[x]);
}
signed main()
{
memset(head,-1,sizeof head);
cin >> n;
for(int i = 1;i <= n;i++) scanf("%lld",&a[i]);
for(int i = 1;i < n;i++) scanf("%lld%lld%lld",&x,&y,&z),Add(x,y,z),Add(y,x,z);
solve(1);
sort(E + 1,E + tot + 1,cmp);
for(int i = 1;i <= n;i++) fa[i] = i;
for(int i = 1;i <= tot;i++)
if(Find(E[i].u) != Find(E[i].v))
{
fa[Find(E[i].u)] = Find(E[i].v);
ans += E[i].w;
}
cout << ans;
return 0;
}
标签:路径,int,题解,边权,MST,Tree,最小,MAXN,define
From: https://www.cnblogs.com/Creeperl/p/17892750.html