与其说树上启发式合并是一种算法,不如说是一种思想。它在于通过”小的并入大的“保证复杂度,从而解决很多看似无法做的问题。
论纯用树上启发式合并的题很少,但是很多题却可以用树上启发式合并去解决。
模板
求解的问题往往具有如下性质:
- 每颗子树都有要记录的信息,信息的数量和子树大小有关。
- 一个父亲的信息包含它儿子的信息。
(若觉得抽象,不妨先看例题,再回来看模板)。
这种方法和轻重链剖分一样是找出重儿子,然后把其它儿子的信息逐个合并到重儿子上。
重儿子:子树大小最大的儿子(集合大小往往和子树大小有关)。
模板 $1$
/*
dep[u]:u的深度
son[u]:u的重儿子
e[u]: u的所有儿子
*/
void Union(int u, int v){
//把v的信息合并到u上
}
void solve(int u){
//处理一些信息
for(int v: e[u]){
dep[v] = dep[u] + 1, dfs(v);
Union(u, v);
}
//对询问进行处理
}
模板 $2$
/*
sz[u]:u的子树大小
dep[u]:u的深度
son[u]:u的重儿子
e[u]: u的所有儿子
*/
void Union(int u, int v){
//把v的信息合并到u上
}
void dfs(int u){//进行预处理,包括重儿子、子树大小、深度等等
sz[u] = 1;
//额外的一些处理
for(int v: e[u]){
dep[v] = dep[u] + 1, dfs(v);//遍历所有儿子
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void solve(int u){
for(int v: e[u]){
if(v == son[u]) continue;
dfs(v);
}
if(son[u]) dfs(son[u]);
//Union+对询问进行处理
}
//上述的"对询问进行处理"是一种离线的方法,即用一个vector存每个点(子树)相关的询问。之所以要这样,是因为树上启发式合并的空间和一些线段树合并一样,对一个节点处理完之后是释放掉的。
哪种比较方便要看具体的题目。
时间复杂度
考虑共有 \(n\) 个元素,一个元素 \(i\) 所在集合被合并到另一集合时才会产生时间花销。而我们每次把小的集合合并到大的,\(i\) 所在集合大小至少\(\times 2\)。而\(\times 2\) 必定不超过 \(\log n\) 次。故时间复杂度为 \(O(n\log n)\)。
例题
CF208E Blood Cousins
这道题可以通过倍增访问到 \(p\) 级祖先,然后 \(p\) 即表亲就等于 \(p\) 级祖先的 \(p\) 级子孙个数 \(-1\) 。
显然求一个点 \(u\) 的 \(p\) 级子孙个数等价于求 \(u\) 子树里深度为 \(dep_u+p\) 那一层的节点数。
这个东西用树上启发式合并求就行了。对于每颗子树维护两个东西:深度的集合以及每种深度出现次数(用 \(\text{STL}\) 的动态开点哈希 unordered_map
)。
点击查看代码
#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
const int N = 1e5 + 10;
struct Q{int d, i, s;};
int n, m, tot, rt[N], dep[N], fa[N][20], ans[N];
vector<int> root, e[N], s[N];
vector<Q> q[N];
unordered_map<int, int> mp[N];
void Union(int u, int v){
if(s[u].size() < s[v].size())
swap(s[u], s[v]), swap(mp[u], mp[v]);
for(int &x: s[v]){
if(!mp[u][x]) s[u].push_back(x);
mp[u][x] += mp[v][x]; mp[v].erase(x);
}
s[v].clear();
}
void dfs_lca(int u, int root){
dep[u] = dep[fa[u][0]] + 1, rt[u] = root;
FL(i, 1, 16) fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(int &v: e[u]) fa[v][0] = u, dfs_lca(v, root);
}
void solve(int u){
for(int &v: e[u]) solve(v), Union(u, v);
s[u].push_back(dep[u]), mp[u][dep[u]]++;
for(Q &t: q[u]) ans[t.i] = mp[u][t.d] - 1;
}
int main(){
scanf("%d", &n);
FL(i, 1, n){
int r; scanf("%d", &r);
if(!r) root.push_back(i);
else e[r].push_back(i);
}
for(int &u: root) dfs_lca(u, u);
scanf("%d", &m);
FL(i, 1, m){
int v, p, r;
scanf("%d%d", &v, &p), r = p - 1;
FR(j, 16, 0) if((1 << j) <= r)
r -= (1 << j), v = fa[v][j];
if(!fa[v][0]) continue;
q[fa[v][0]].emplace_back((Q){dep[fa[v][0]] + p, i, v});
}
for(int &u: root) solve(u);
FL(i, 1, m) printf("%d ", ans[i]);
return 0;
}
CF570D Tree Requests
能重组成回文串仅当只存在至多一种字符的出现次数为奇数。
\(\text{Solution 1}\)
这道题其实是在上一题的启发式合并的基础上,哈希加了一维字母。
查询时遍历所有字母。
时间复杂度 \(O(n\log n + 26n)\),前者为启发式合并的复杂度,后者为查询的复杂度。
\(\text{Solution 2}\)
用一个二进制状态来表示每种字母出现是奇数次还是偶数次。
把上题哈希中存的出现次数换成这个二进制状态就行了。
正睿OI 908
这题是启发式合并与动规计数的结合应用。
动规计数:对于所有点分别算出作为 \(o_1,o_2\) 凑成 \(7\) 的方案数,然后再求出答案即可。
但是求 \(o_1,o_2\) 的过程中需要求 \(x\) 子树里与 \(x\) 距离为 \(d\) 的点的个数。
这个等价于求 \(x\) 子树里深度为 \(dep_x+d\) 那一层有多少个节点。直接树上启发式合并就行了。
点击查看代码
#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
typedef long long ll;
const int N = 1e5 + 10, mod = 1e9 + 7;
int n, A, B, C, D, tot, id[N], id2[N], dep[N];
ll ans, sum, f[N][2];
vector<int> e[N], s[N];
unordered_map<int, int> mp[N];
void Union(int u, int v){
if(s[u].size() < s[v].size()) swap(s[u], s[v]), swap(mp[u], mp[v]);
for(int &x: s[v])
s[u].push_back(x), mp[u][x] += mp[v][x], mp[v].erase(x);
s[v].clear();
}
void dfs(int u){
id[u] = ++tot;
for(int &v: e[u]){
dep[v] = dep[u] + 1, dfs(v);
(f[u][0] += 1ll * mp[u][dep[u] + A] * mp[v][dep[u] + B]) %= mod;
(f[u][1] += 1ll * mp[u][dep[u] + C] * mp[v][dep[u] + D]) %= mod;
Union(u, v);
}
s[u].push_back(dep[u]), mp[u][dep[u]]++;
}
int main(){
scanf("%d", &n), dep[1] = 1;
FL(i, 2, n){
int u, v; scanf("%d%d", &u, &v);
e[u].push_back(v);
}
scanf("%d%d%d%d", &A, &B, &C, &D);
dfs(1);
FL(i, 1, n) id2[id[i]] = i;
FL(i, 1, n) (ans += sum * f[id2[i]][1]) %= mod, (sum += f[id2[i]][0]) %= mod;
printf("%lld\n", ans);
return 0;
}
CF600E Lomsat gelral、CF1009F Dominant Indices
两题做法基本一致。用维护深度的办法在维护颜色的同时,记录一下最大值以及编号和就行了。
这里用 CF600E的代码举例:
点击查看代码
#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
struct Edge{int v, nxt;} e[N << 1];
int n, tot, son[N], c[N], sz[N], head[N], t[N];
ll cnt, sum, ans[N];
void init(){
tot = 0, memset(head, -1, sizeof(head));
}
void Adde(int u,int v){
e[++tot] = {v, head[u]}, head[u] = tot;
}
void dfs(int u,int p){
sz[u] = 1;
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].v; if(v == p) continue;
dfs(v, u), sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void Add(int u, int p){
t[c[u]]++;
if(t[c[u]]>cnt) cnt = t[sum = c[u]];
else if(t[c[u]] == cnt) sum += c[u];
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].v;
if(v != p) Add(v, u);
}
}
void Sub(int u, int p){
t[c[u]]--;
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].v;
if(v != p) Sub(v, u);
}
}
void dfs2(int u,int p){
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].v;
if(v != p && v != son[u]){
dfs2(v, u), Sub(v, u), cnt = sum = 0;
}
}
if(son[u]) dfs2(son[u], u);
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].v;
if(v != p && v != son[u]) Add(v, u);
}
t[c[u]]++;
if(t[c[u]] > cnt) cnt = t[sum = c[u]];
else if(t[c[u]] == cnt) sum+=c[u];
ans[u] = sum;
}
signed main(){
scanf("%d", &n);
memset(head, -1, sizeof(head));
FL(i, 1, n) scanf("%lld",&c[i]);
FL(i, 2, n){
int u, v; scanf("%d%d", &u, &v);
Adde(u, v), Adde(v, u);
}
dfs(1, 0), dfs2(1, 0);
FL(i, 1, n) printf("%lld ",ans[i]);
return 0;
}
CF246E Blood Cousins Return
把 CF208E 中 unordered_map
里存的东西替换为深度对应的不同子串个数。
合并时每往大的集合里加一个元素,就看看是否出现过(显然再开一个哈希就行了)。
询问离线到每个节点处理。
CF375D Tree and Queries
树上启发式合并维护颜色数的同时:
记 \(cnt_i\) 为颜色 \(i\) 的出现次数,\(sum_i\) 为出现次数大于等于 \(i\) 的颜色数。
和莫队类似的修改函数:
void add(int u){sum[++cnt[c[u]]]++;}
void del(int u){sum[cnt[c[u]]--]++;}
巧妙在 \(cnt\) 从 \(0\) 开始加起,所以 \(\sum_i^{\le cnt_{c_u}} sum_i\) 均加了 \(1\),也正好与 \(sum\) 的定义相呼应。
CF741D Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths
判断字符集能否重构成回文串的方法同上
能重组成回文串仅当只存在至多一种字符的出现次数为奇数。
我们令 \(a_u\) 表示 \(1\to u\) 路径上的字符集的二进制状态。具体的,从右往左数第 \(1\) 位表示字符 \(a\) 的出现次数是否为奇数;从右往左第 \(2\) 位表示字符 \(b\) 的出现次数是否为奇数……以此类推。
我们发现,祖先 \(p\) 到 \(u\) 路径上的二进制状态等价于 \(a_p\bigoplus a_u\)。也就是任意点对 \((u,v)\) 路径上的二进制状态等价于 \((a_u\bigoplus a_{lca})\bigoplus (a_v\bigoplus a_{lca})=a_u\bigoplus a_v\)。
这时我们就有方法统计答案的最大值了。点 \(u\) 的答案等价于经过 \(u\) 的最长合法路径的长度,以及其子节点的答案的最大值。维护经过 \(u\) 的最长合法路径只需要维护所有的 \(a_i\),之后直接树上启发式合并即可。
这题的启发式合并过程中,先遍历轻儿子,最后重儿子。轻儿子的信息清空,重儿子的不清空。对于一颗子树先查询再修改。查询:由于允许至多一种字符出现次数为奇数,所以就枚举哪种字符出现次数为奇数(或者没有字符出现次数为奇数)。
点击查看代码
#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
const int N = 5e5 + 10, INF = 1e9;
int n, a[N], sz[N], son[N], ans[N], dep[N], cnt[1 << 22];
vector<pair<int, int> > e[N];
void dfs(int u){
sz[u] = 1;
for(auto &p: e[u]){
int v = p.first, w = p.second;
a[p.first] = a[u] ^ (1 << w);
dep[v] = dep[u] + 1, dfs(v), sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void Add(int u){
cnt[a[u]] = max(cnt[a[u]], dep[u]);
for(auto &p: e[u]) Add(p.first);
}
void Del(int u){
cnt[a[u]] = -INF;
for(auto &p: e[u]) Del(p.first);
}
int calc(int u, int rt){
int ret = max(0, dep[u] + cnt[a[u]]);
FL(i, 0, 21) ret = max(ret, dep[u] + cnt[a[u] ^ (1 << i)]);
if(u == rt) cnt[a[u]] = max(cnt[a[u]], dep[u]);
for(auto &p: e[u]) if(p.first != son[rt]){
ret = max(ret, calc(p.first, rt));
if(u == rt) Add(p.first);
}
return ret;
}
void solve(int u, int h){
for(auto &p: e[u])
if(p.first != son[u]) solve(p.first, 0);
if(son[u]) solve(son[u], 1);
ans[u] = calc(u, u), ans[u] = max(0, ans[u] - dep[u] * 2);
for(auto &p: e[u]) ans[u] = max(ans[u], ans[p.first]);
if(!h) Del(u);
}
int main(){
scanf("%d", &n);
FL(i, 0, (1 << 22) - 1) cnt[i] = -INF;
FL(i, 2, n){
int p; char c;
scanf("%d %c", &p, &c);
e[p].push_back({i, c - 'a'});
}
dfs(1), solve(1, 0);
FL(i, 1, n) printf("%d ", ans[i]);
return 0;
}
此外,这里给出一份树上启发式合并模板 \(1\) 的写法(被卡常了):
点击查看代码
#include <bits/stdc++.h>
#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#define FL(i, a, b) for(register int i = (a); i <= (b); i++)
#define FR(i, a, b) for(register int i = (a); i >= (b); i--)
using namespace std;
const int N = 5e5 + 10;
int n, son[N], sz[N], sc[N], ans[N], id[N];
unordered_map<int, int> mp[N];
vector<pair<int, int> > e[N], s[N];
inline void write(int x){
if(x < 10) putchar(x + '0');
else write(x / 10), putchar(x % 10 + '0');
}
inline void Union(int &u, int &v, int &x, int &y){
for(register auto &i: s[id[v]]) FL(b, -1, 21){
int d = b < 0? 0 : (1 << b), c = (i.first ^ d) == x? max(y, mp[id[u]][i.first ^ d]) : mp[id[u]][i.first ^ d];
if(c) ans[u] = max(ans[u], c + i.second + 1 - 2 * y);
}
for(register auto &i: s[id[v]]){
s[id[u]].emplace_back(make_pair(i.first, i.second + 1));
mp[id[u]][i.first] = max(mp[id[u]][i.first], mp[id[v]][i.first] + 1);
}
vector<pair<int, int> >().swap(s[id[v]]);
unordered_map<int, int> ().swap(mp[id[v]]);
}
inline void Dfs(int u){
sz[u] = 1, id[u] = u;
for(register pair<int, int> &p: e[u]){
Dfs(p.first), sz[u] += sz[p.first];
if(sz[p.first] > sz[son[u]]) son[u] = p.first, sc[u] = p.second;
}
}
inline void solve(int u, int x, int y){
if(son[u]){
solve(son[u], x ^ (1 << sc[u]), y + 1);
id[u] = id[son[u]]; ans[u] = max(ans[u], ans[son[u]]);
FL(b, -1, 21){
int d = b < 0? 0 : (1 << b);
if(mp[id[u]][(x ^ d)]) ans[u] = max(ans[u], mp[id[u]][(x ^ d)] - y);
}
}
for(register pair<int, int> &p: e[u]){
if(p.first == son[u]) continue;
solve(p.first, x ^ (1 << p.second), y);
Union(u, p.first, x, y), ans[u] = max(ans[u], ans[p.first]);
}
s[id[u]].emplace_back(make_pair(x, y));
mp[id[u]][x] = max(mp[id[u]][x], y);
vector<pair<int, int> >().swap(e[u]);
}
int main(){
scanf("%d", &n);
FL(i, 2, n){
int p; char c; scanf("%d %c", &p, &c);
e[p].emplace_back(make_pair(i, c - 'a'));
}
Dfs(1), solve(1, 0, 1);
FL(i, 1, n) write(ans[i]), putchar(' ');
return 0;
}