点分治
树的重心(前置芝士)
如果在树中选择某个节点并删除,这棵树将分为若干棵子树,统计子树节点数并记录最大值。取遍树上所有节点,使此最大值取到最小的节点被称为整个树的重心。
性质
- 树的重心如果不唯一,则至多两个且相邻
- 以树的重心为根时,所有子树的大小都不超过整棵树大小的一半
- 树中所有点到某个点的距离和中,到重心的距离和是最小的;如果有两个重心,那么到它们的距离和一样
- 把两棵树通过一条边相连得到一棵新的树,那么新的树的重心在连接原来两棵树的重心的路径上
- 在一棵树上添加或删除一个叶子,那么它的重心最多只移动一条边的距离
如何求重心
在 DFS 中计算每个子树的大小,记录「向下」的子树的最大大小,利用总点数 - 当前子树(这里的子树指有根树的子树)的大小得到「向上」的子树的大小,然后就可以依据定义找到重心了
int head[N], e[N], ne[N]; // head存储边的起点编号,e存储边的终点编号,ne存储所有终点
int siz[N], // siz[x]表示以x为根的子树大小
mx[N], // mx[x]表示选择x为根,其所有子树大小的最大值
rt; // 树的重心
int vis[N];
void getCentroid(int u, int fa){
siz[u] = 1, mx[u] = 0;
for(int i = head[u]; i; i = ne[i]){
if(e[i] != fa && !vis[x]){ // vis下面会讲
getCentroid(e[i], u);
siz[u] += siz[e[i]];
mx[x] = max(mx[x], siz[e[i]]);
}
}
mx[x] = max(mx[x], n - siz[x]);
if(mx[x] < mx[rt]) rt = x;
}
点分治(正式)
点分治是一种解决树上问题的常用方法,本质思想是选择一点作为分治中心,将原问题划分为几个相同的子树上的问题递归解决
常见题目中给出的都是无根树(维护的信息与根节点无关)
我们选择树的重心作为根节点,其性质2保证了递归层数最少,是\(O(logn)\)的
细节
注意到,每一次递归下去时,选择的根节点总是当前子树的重心,但是新的根不一定是之前的根的儿子,如下图
第一次找到的根为 \(1\),递归下去后两颗子树上的重心分别为 \(4\) 和 \(5\),都不是 \(1\) 的儿子,所以为了防止重复递归,应当每个节点进行点分治后加上标记,之后的递归不再进入已经打过标记的点,这也就是上面函数中 vis[]
的作用。
例题1
思路
离线处理
对于一条长度为 \(k\) 的路径,分为经过 \(rt\) 和不经过 \(rt\) 两种情况,不经过的情况可以递归进入子树处理
对于经过 \(rt\) 的路径,枚举所有子节点 \(ch\) ,以 \(ch\) 为根计算 \(ch\) 子树中所有节点到 \(rt\) 的距离。
假设子树中出现了距离 \(rt\) 为 \(l\) 的链,如果 \(k-l\) 在其他的子树中出现,那么 \(k\) 就会出现。注意到两者出现顺序无影响,所以可以依次递归子树
- \(DFS\) 处理当前子树每个点与 \(rt\) 的距离
- 在 \(DFS\) 过程中同时统计有哪些长度出现
- 与之前子树中出现的长度结合更新答案
- 将当前子树的信息并入
最后要清空记录的之前子树出现的长度,不能使用memset,要使用队列保证时间复杂度正确
AC代码
#include <iostream>
#include <queue>
using namespace std;
const int INF = 2e9;
const int N = 1e4 + 10;
int n, m, a, b, v, k, Q[N];
int head[N], e[N << 1], w[N << 1], ne[N << 1], idx; // w 表示边权
int siz[N], mx[N], sum, rt;
int dist[N], dd[N], cnt; // dist[x]存储 x 到 rt 的距离, dd[x] 记录当前子树拥有的链的长度, cnt 记录当前子树到 rt 的链的个数
bool tf[10000010], vis[N], ans[N]; // tf[x]存储是否有长度为 x 的链
queue<int> tag;
void add(int a, int b, int v) {
e[++idx] = b, w[idx] = v, ne[idx] = head[a], head[a] = idx;
}
void calcsiz(int x, int fa) {
siz[x] = 1, mx[x] = 0; //初始化
for (int i = head[x]; i; i = ne[i]) {
if (e[i] == fa || vis[e[i]]) continue;
calcsiz(e[i], x);
siz[x] += siz[e[i]];
mx[x] = max(mx[x], siz[e[i]]);
}
mx[x] = max(mx[x], sum - siz[x]);
if (mx[x] < mx[rt]) rt = x;
}
void calcdist(int x, int fa){
dd[++cnt] = dist[x];
for(int i = head[x]; i; i = ne[i])
if(e[i] != fa && !vis[e[i]]) dist[e[i]] = dist[x] + w[i], calcdist(e[i], x);
}
void dfz(int x, int fa){
tf[0] = true, tag.push(0), vis[x] = true;
// 枚举所有子节点
for(int i = head[x]; i; i = ne[i]){
if(e[i] == fa || vis[e[i]]) continue;
dist[e[i]] = w[i], calcdist(e[i], x);
// 与之前子树中出现的长度结合更新答案
for(int i = 1; i <= cnt; i++)
for(int j = 1; j <= m; j++)
if(Q[j] >= dd[i]) ans[j] |= tf[Q[j] - dd[i]]; // 看有没有 k-l 的边
// 将当前子树信息并入
for(int i = 1; i <= cnt; i++)
// 观察题目数据范围询问不会超过1e7
if(dd[i] < 1e7 + 10) tag.push(dd[i]), tf[dd[i]] = true;
cnt = 0;
}
// 至此,经过 rt 的情况被解决,清空队列
while(!tag.empty()) tf[tag.front()] = false, tag.pop();
// 递归进入子树
for(int i = head[x]; i; i = ne[i]){
if(e[i] == fa || vis[e[i]]) continue;
rt = 0;
mx[rt] = INF, sum = siz[e[i]];
calcsiz(e[i], x), calcsiz(rt, -1), dfz(rt, x);
}
}
int main() {
cin >> n >> m;
for (int i = 1; i < n; i++) {
cin >> a >> b >> v;
add(a, b, v), add(b, a, v);
}
for(int i = 1; i <= m; i++) cin >> Q[i];
rt = 0;
mx[rt] = INF, sum = n;
calcsiz(1, -1), calcsiz(rt, -1), dfz(rt, -1);
for(int i = 1; i <= m; i++){
if(ans[i]) cout << "AYE\n";
else cout << "NAY\n";
}
return 0;
}
Tips
代码中调用两次 calcsiz
是因为,从上一层分治中传下来的子树大小在这一层是不适用的,我们要求出以 \(rt\) 为根时各子树大小,虽然这么求解也是对的,但是不推荐这么做
这是 证明
例题2
思路
我们还是先选择根节点,则路径分为三种
- 两个点在同一个子树,递归进入子树求解
- 两个点在不同子树,直接求解
- 某个端点是根节点,在 \(DFS\) 过程求解
我们使用 \(p\) 存所有子树所有节点的距离,\(q\) 存当前子树所有节点的距离,然后我们在 \(p\) 中任选两点计算符合的路径数,但是同一子树中两点到根节点的距离和 \(<=k\) 的路径也会被计算,所以根据容斥原理,我们每次求解 \(q\) 的时候减去这一部分
针对 get
函数,我们先让数组有序然后双指针求解
- \(i\) ,\(j\) 未相遇时,我们针对每一个 \(j\) 找到最大的 \(i\) 然后求和
- 相遇后,针对每一个\(j\),\(i\)的取值为\(j-1\)(防止重复计算)
AC代码
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 1e4 + 10;
int n, m, u, v, l, idx;
int head[N];
bool st[N];
int p[N], q[N]; //
struct node{
int to, nxt, w;
}e[N << 1];
void add(int u, int v, int w){
e[idx].to = v, e[idx].w = w, e[idx].nxt = head[u], head[u] = idx++;
}
// 求子树大小
int getsiz(int u, int fa){
if(st[u]) return 0;
int res = 1;
for(int i = head[u]; ~i; i = e[i].nxt)
if(e[i].to != fa) res += getsiz(e[i].to, u);
return res;
}
// 得到一个所有子树点数 <= n/2 的点(不一定是重心, 只要满足这个条件就好了)
int get_wc(int u, int fa, int tot, int &wc){
if(st[u]) return 0;
int sum = 1, mx = 0;
for(int i = head[u]; ~i; i = e[i].nxt){
if(e[i].to == fa) continue;
int t = get_wc(e[i].to, u, tot, wc);
mx = max(mx, t);
sum += t;
}
mx = max(mx, tot - sum);
if(mx <= tot / 2) wc = u;
return sum;
}
// 得到每个点与根节点的距离
void getdist(int u, int fa, int dist, int &qt){
if(st[u]) return;
q[qt++] = dist;
for(int i = head[u]; ~i; i = e[i].nxt)
if(e[i].to != fa) getdist(e[i].to, u, dist+e[i].w, qt);
}
// 在所有点中任意选择两个点, 返回两点距离 <=k 的路径数
int get(int a[], int k){
sort(a, a + k);
int res = 0;
for(int j = k - 1, i = -1; j >= 0; j--){
while(i + 1 < j && a[i+1] + a[j] <= m) i++;
i = min(i, j - 1);
res += i + 1;
}
return res;
}
int calc(int u){
if(st[u]) return 0;
int res = 0; // 当前子树内有多少个满足的数对
get_wc(u, -1, getsiz(u, -1), u);
st[u] = true; // 删除重心
int pt = 0;
for(int i = head[u]; ~i; i = e[i].nxt){
int qt = 0;
getdist(e[i].to, -1, e[i].w, qt);
res -= get(q, qt); // 减去多算的
for(int k = 0; k < qt; k++) {
if(q[k] <= m) res++; // 第三种情况
p[pt++] = q[k];
}
}
res += get(p, pt);
for(int i = head[u]; ~i; i = e[i].nxt) res += calc(e[i].to);
return res;
}
int main(){
while(cin >> n >> m){
if(!n && !m) break;
memset(st, 0, sizeof(st));
memset(head, -1, sizeof(head));
idx = 0;
for(int i = 1; i < n; i++){
cin >> u >> v >> l;
add(u, v, l), add(v, u, l);
}
cout << calc(0) << "\n";
}
return 0;
}