算法竞赛进阶指南 P374 解法3(解法2为P1099 树网的核),7FA4.3.2.5.3, LuoguP2491 SDOI2011
- 二分答案 mid 在树的直径上找离两端最远且距离小于 mid 的点, 判断其他点是否到这个点的距离均小于等于mid
点击查看代码
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <utility>
#include <array>
#include <vector>
using namespace std;
const int N = 300005, M = N << 1;
int n, m;
int h[N], e[M], w[M], nxt[M], idx;
int vertax[N], edge[N], tot; // 直径的端点, 边, 点的总数
int dist1[N], dist2[N], *dist; // 求直径
int sum[N]; // 直径权值的前缀和
bool vis[N]; // 是否遍历
int father[N], to_father[N]; // 父节点, 到父节点的边
void add(int a, int b, int c) {
e[++ idx] = b, w[idx] = c, nxt[idx] = h[a], h[a] = idx;
}
int ans; // 最远距离
void dfs(int u) {
vis[u] = true;
for(int i = h[u]; i; i = nxt[i]) {
int v = e[i];
if(vis[v]) continue;
father[v] = u;
to_father[v] = i;
dist[v] = dist[u] + w[i];
ans = max(ans, dist[v]);
dfs(v);
}
}
int furthest(int u, int *d) {
dist = d, dist[u] = 0, father[u] = to_father[u] = -1;
memset(vis, false, sizeof(vis)), dfs(u);
return max_element(dist + 1, dist + n + 1) - dist;
}
bool check(int mid) {
int p = 1, q = tot;
while(p + 1 <= tot && sum[p + 1] <= mid) p ++; // 左边的最右
while(q - 1 >= 1 && sum[tot] - sum[q - 1] <= mid) q --; // 右边的最左
if(p > q) return true;
if(sum[q] - sum[p] > m) return false;
memset(dist, 0, sizeof(dist1));
memset(vis, false, sizeof(vis));
for(int i = 1; i <= tot; i ++) vis[vertax[i]] = true;
// O(n) 判断: 这条链将这个树分为若干个连通块
for(int i = p; i <= q; i ++) {
int &v = vertax[i];
ans = 0, dfs(v);
if(ans > mid) return false;
}
return true;
}
int main() {
scanf("%d%d", &n, &m);
int sss = 0;
for(int i = 1, a, b, c; i < n; i ++) {
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
sss += c;
}
int x = furthest(1, dist1), y = furthest(x, dist1);
x = furthest(y, dist2); // x, y 为端点
int posx = x;
tot = 1;
while(posx != y) vertax[tot] = posx, edge[++ tot] = to_father[posx], posx = father[posx];
vertax[tot] = y; // 以上为找直径
for(int i = 1; i <= tot; i ++) {
sum[i] = sum[i - 1] + w[edge[i]];
}
int l = 0, r = sss;
while(l < r) {
int mid = ((long long)l + r) >> 1;
if(check(mid)) r = mid;
else l = mid + 1;
}
printf("%d\n", l);
return 0;
}