dsu on tree
前言
在我认为,这个并不能说单独列出来成为一个算法,更恰当的说,是一种思想、技巧。反正挺简单的,也很有趣(谁会拒绝一个优美的暴力呢),所以写篇笔记记录一手。
dsu 是什么
dsu 一般指“disjoint set union”,即并查集。那么 dsu on tree 也就是指树上的合并和查询操作。
但是 dsu on tree 的实现却跟普通并查集没有太大联系。共同点仅在于功能上都能合并集合、查询。
dsu on tree 有什么用
dsu on tree 可称为树上启发式合并,是一种优美的暴力,合并子树的时候,把轻儿子合并到重儿子上去。
由于保存合并结果的是一个全局数组。所以每次计算新的字数时,都需要清空。我们可以先计算轻儿子,把重儿子留到最后,重儿子可以不用清零,直接把重儿子的信息拿去计算父亲。
这样一来,在暴力的基础上,将重儿子留到了最后,少算了一次重儿子,时间可以来到优秀的 \(O(n\log n)\)(暴力是纯粹的 \(O(n^2)\))。
从题目出发
题目:给一棵根为1的树,每次询问子树颜色种类数。
纯暴力
void update(int x, int f, int flg){
cnt[col[x]] += flg;
if(cnt[col[x]] == 0 && flg == -1) cols--;
if(cnt[col[x]] == 1 && flg == 1) cols++;
for(int i = fir[x]; i; i = es[i].nxt){
int tv = es[i].v;
if(tv != f) update(tv, x, flg);
}
}
void dfs(int r, int fa){// 求子树r中的信息, fa为r的父亲
for(int i = fir[r]; i; i = es[i].nxt){ // 遍历r的邻接点
int tv = es[i].v;
if(tv != fa) dfs(tv, r);
}
update(r, fa, 1);
ans[r] = cols;
update(r, fa, -1);
}
直接 \(O(n^2)\) T 飞。
当然你也可以把 dfs
写成这样(更接近 dsu 的打法):
void dfs(int r, int fa){// 求子树r中的信息, fa为r的父亲
for(int i = fir[r]; i; i = es[i].nxt){ // 遍历r的邻接点
int tv = es[i].v;
if(tv != fa){
dfs(tv, r);
update(tv, r, -1);
}
}
update(r, fa, 1);
ans[r] = cols;
}
dsu
现在我们来优化一手暴力。先预处理出轻、重儿子,然后 dfs
轻儿子、再 dfs
重儿子。
void dfs(int x, int f){
for(int i = fir[x]; i; i = es[i].nxt){
int tv = es[i].v;
if(tv != son[x] && tv != f){
dfs(tv, x);
update(tv, x, -1);
}
}
if(son[x]) dfs(son[x], x);
cnt[col[x]]++;
if(cnt[col[x]] == 1) cols++;
for(int i = fir[x]; i; i = es[i].nxt){
int tv = es[i].v;
if(tv != son[x] && tv != f){
update(tv, x, 1);
}
}
ans[x] = cols;
}
时间复杂度来到优秀的 \(O(n\log n)\) !!
时间复杂度分析
但是为什么呢?
因为根据轻重链划分的思想,任何一条到根的路径上,轻边不会超过 \(\log n\) 条,重链是被轻边分隔的,数量也不会超过 \(\log n\) 条。
每棵子树到父亲的边为轻边,做一次 update
,最多做 \(\log n\) 次。
一次 update
可以看作是轻儿子想重儿子的合并操作。
每个节点最多合并 \(\log n\) 次,总的时间复杂度为 \(O(n\log n)\) 次。
完整代码
code
#include <bits/stdc++.h>
using namespace std;
#define MAXN 100005
int n, m, ecnt, cols, col[MAXN], fir[MAXN], sz[MAXN];
int son[MAXN], cnt[MAXN], ans[MAXN];
struct edge
{
int v, nxt;
} es[MAXN << 1];
void adde(int a, int b)
{
es[++ecnt].v = b, es[ecnt].nxt = fir[a], fir[a] = ecnt;
es[++ecnt].v = a, es[ecnt].nxt = fir[b], fir[b] = ecnt;
}
void dfs1(int x, int f)
{
sz[x]++;
int maxz = 0;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
{
dfs1(tv, x);
if (sz[tv] > maxz)
maxz = sz[tv], son[x] = tv;
sz[x] += sz[tv];
}
}
}
void update(int x, int f, int flg)
{
cnt[col[x]] += flg;
if (cnt[col[x]] == 0 && flg == -1)
cols--;
if (cnt[col[x]] == 1 && flg == 1)
cols++;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
update(tv, x, flg);
}
}
void dfs2(int x, int f)
{
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
dfs2(tv, x), update(tv, x, -1);
}
if (son[x])
dfs2(son[x], x);
cnt[col[x]]++;
if (cnt[col[x]] == 1)
cols++;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
update(tv, x, 1);
}
ans[x] = cols;
}
int main()
{
int a, b;
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
scanf("%d %d", &a, &b);
adde(a, b);
}
for (int i = 1; i <= n; i++)
scanf("%d", &col[i]);
dfs1(1, 0);
dfs2(1, 0);
scanf("%d", &m);
for (int i = 1; i <= m; i++)
{
scanf("%d", &a);
printf("ans: %d\n", ans[a]);
}
return 0;
}
Lomsat gelral
这个就是板题了,也可以线段树合并去做。但是 dsu on tree 明显更短,更好打。
code
#include <bits/stdc++.h>
using namespace std;
#define MAXN 100025
#define LL long long int
LL sum, ans[MAXN];
int n, m, ecnt, maxcnt, col[MAXN], fir[MAXN], sz[MAXN], son[MAXN], cnt[MAXN];
struct edge
{
int v, nxt;
} es[MAXN << 1];
void adde(int a, int b)
{
es[++ecnt].v = b, es[ecnt].nxt = fir[a], fir[a] = ecnt;
es[++ecnt].v = a, es[ecnt].nxt = fir[b], fir[b] = ecnt;
}
void dfs1(int x, int f)
{
sz[x]++;
int maxz = 0;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
{
dfs1(tv, x);
if (sz[tv] > maxz)
maxz = sz[tv], son[x] = tv;
sz[x] += sz[tv];
}
}
}
void update(int x, int f, int flg)
{
cnt[col[x]] += flg;
if (flg == 1)
{
if (cnt[col[x]] > maxcnt)
maxcnt = cnt[col[x]], sum = col[x];
else if (cnt[col[x]] == maxcnt)
sum += col[x];
}
else
maxcnt = 0, sum = 0;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
{
update(tv, x, flg);
}
}
}
void dfs2(int x, int f)
{
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
dfs2(tv, x), update(tv, x, -1);
}
if (son[x])
dfs2(son[x], x);
cnt[col[x]]++;
if (cnt[col[x]] > maxcnt)
maxcnt = cnt[col[x]], sum = col[x];
else if (cnt[col[x]] == maxcnt)
sum += col[x];
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
update(tv, x, 1);
}
ans[x] = sum;
}
int main()
{
int a, b;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &col[i]);
for (int i = 1; i < n; i++)
scanf("%d %d", &a, &b), adde(a, b);
dfs1(1, 0);
dfs2(1, 0);
for (int i = 1; i <= n; i++)
printf("%lld ", ans[i]);
return 0;
}
标签:cnt,int,dsu,tree,update,tv,启发式,col,es
From: https://www.cnblogs.com/wang-holmes/p/17983950