虚树 学习笔记
引入
我们在解决树上问题时,往往都是对整棵树进行处理,或者每次询问都对一个点、点对进行处理,这类题型一般都可以通过 dp、树剖解决;然而,有一类问题要求我们每次对树上一些关键点进行处理。这类问题的特点就是询问次数多,而询问的点的总数不多。可如果我们每次都把整棵树都 dfs 一遍,时间复杂度就是 \(n^2\) 级别的。我们发现,每次 dfs 的时候,有用的只有关键点,我们所关注的也只有这些关键点之间的关系,那我们是不是可以考虑把整棵树抽象起来,变成一颗只有关键点和某些辅助点的树呢?
虚树
我们把关键点和用于体现关键点之间关系的辅助点连接起来,就形成了虚树。这些辅助点一般都是 LCA。由于至少需要一个关键点才会出现一个辅助点(比如某个点 是 “两个关键点的 LCA” 与 “另一个关键点” 的 LCA),所以最后建出来所有的虚树并遍历的总代价是 \(2n\) 级别的。
建树过程
我们肯定希望辅助点越少越好,但同时还得保证信息正确,所以我们考虑按照 dfs 序建树,因为 dfs 序越相近,两个点在树上的关系越近。我们先把关键点按 dfn 排序,然后用一个栈来维护一条虚树上的链,每次都询问栈顶是否为新点和栈顶的 LCA,如果不是,说明要开一条新的链,就弹栈并加边。最后一定要把栈内剩余元素加边。
参考代码:
void build(){
sort(p+1, p+K+1, cmp);
top = 0;
stk[++top] = 1;
G2.head[1] = 0;//注意不能全部清空,在加边的过程中动态清空即可。
for(int i = 1; i<=K; ++i){
if(p[i] == 1) continue;
int lca = th.LCA(stk[top], p[i]);
if(lca != stk[top]){
while(dfn[lca] < dfn[stk[top-1]]){
G2.add(stk[top-1], stk[top]);
--top;
}
if(dfn[lca] > dfn[stk[top-1]]){
G2.head[lca] = 0;
G2.add(lca, stk[top]), stk[top] = lca;
} else{
G2.add(lca, stk[top]);
--top;
}
}
G2.head[p[i]] = 0;
stk[++top] = p[i];
}
for(int i = 1; i<top; ++i){
G2.add(stk[i], stk[i+1]);
}
}
例题
消耗战
首先 dp 式子很明显,我们分类讨论。如果子节点 \(v\) 是关键点,那么 \(f_u+=w(u, v)\);如果不是,那么就是 \(f_u+= \min{f_v, w(u, v)}\)。我们可以预处理出来根节点到每个节点的路径上所经过的最小边权,把它作为虚树上的边权即可。
代码:
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2.5e5+10;
inline int read(){
int x = 0; char ch = getchar();
while(ch<'0' || ch>'9') ch = getchar();
while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar();
return x;
}
struct node{
int nxt, to, w;
};
struct Graph{
int head[N], tot;
int num;
node edge[N<<1];
void add(int u, int v, int w){
edge[++tot].nxt = head[u];
edge[tot].to = v;
edge[tot].w = w;
head[u] = tot;
}
}G1, G2;//原树,虚树。
int dfn[N];
struct HPD{//重链剖分,heavy path decomposition
int siz[N], totd, top[N], son[N], dep[N], fa[N];
void dfs1(int u, int fath){
dep[u] = dep[fath]+1;
siz[u] = 1;
fa[u] = fath;
for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
int v = G1.edge[i].to;
if(v == fath) continue;
dfs1(v, u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u] = v;
}
}
void dfs2(int u, int Top){
top[u] = Top;
dfn[u] = ++totd;
if(!son[u]) return;
dfs2(son[u], Top);
for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
int v = G1.edge[i].to;
if(!dfn[v]) dfs2(v, v);
}
}
int LCA(int x, int y){
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
return x;
}
}th;
int n;
int m, K;
int dst[N], p[N];
void dfsG1(int u, int fath){
for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
int v = G1.edge[i].to;
if(v == fath) continue;
dst[v] = min(dst[u], G1.edge[i].w);
dfsG1(v, u);
}
}
bool cmp(int a, int b){
return dfn[a] < dfn[b];
}
int stk[N], tp;
bool is_tar[N];
void build(){
sort(p+1, p+K+1, cmp);
tp = 0;
stk[++tp] = 1, G2.head[1] = 0;
for(int i = 1, l; i<=K; ++i){
if(p[i] == 1) continue;
l = th.LCA(p[i], stk[tp]);
if(l != stk[tp]){
while(dfn[l] < dfn[stk[tp-1]]){
G2.add(stk[tp-1], stk[tp], dst[stk[tp]]);
--tp;
}
if(dfn[l] > dfn[stk[tp-1]]){
G2.head[l] = 0;
G2.add(l, stk[tp], dst[stk[tp]]), stk[tp] = l;
} else{
G2.add(l, stk[tp], dst[stk[tp]]);
--tp;
}
}
G2.head[p[i]] = 0;
stk[++tp] = p[i];
}
for(int i = 1; i<tp; ++i){
G2.add(stk[i], stk[i+1], dst[stk[i+1]]);
}
}
ll f[N];
void dfs_ans(int u, int fath){
f[u] = 0;
for(int i = G2.head[u]; i; i = G2.edge[i].nxt){
int v = G2.edge[i].to;
if(v == fath) continue;
dfs_ans(v, u);
if(is_tar[v]){
f[u]+=G2.edge[i].w;
} else{
f[u]+= min(f[v], 1ll*G2.edge[i].w);
}
}
}
int main(){
n = read();
dst[1] = 0x3f3f3f3f;
for(int i = 1; i<n; ++i){
int u = read(), v = read(), w = read();
G1.add(u, v, w);
G1.add(v, u, w);
}
th.dfs1(1, 0);
th.dfs2(1, 1);
dfsG1(1, 0);
m = read();
while(m--){
K = read();
G2.tot = 0;//一定注意要清空!
for(int i = 1; i<=K; ++i){
p[i] = read();
is_tar[p[i]] = 1;
}
build();
dfs_ans(1, 0);
printf("%lld\n", f[1]);
for(int i = 1; i<=K; ++i){
is_tar[p[i]] = 0;
}
}
return 0;
}
大工程
也是考虑建好虚树后怎么做。最大值和最小值都可以通过拼接求得,每次找出最大/最小和次大/次小值拼接即可。至于路径权值和,我们考虑每条边的贡献,发现就等于这条边所连接的两棵子树中关键点数量的乘积。至于建好虚树后新边的边权,因为是单位边权,所以直接通过两点的深度做差即可求得。
代码:
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e6+100;
const int INF = 0x3f3f3f3f;
inline int read(){
int x = 0; char ch = getchar();
while(ch<'0' || ch>'9') ch = getchar();
while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar();
return x;
}
struct node{
int nxt, to;
};
struct Graph{
int tot, head[N];
node edge[N<<1];
void add(int u, int v){
edge[++tot].nxt = head[u];
edge[tot].to = v;
head[u] = tot;
}
}G1, G2;
int dep[N], dfn[N], totd;
struct HPD{
private:
int fa[N], top[N], son[N], siz[N];
public:
void dfs1(int u, int fath){
dep[u] = dep[fath]+1;
fa[u] = fath;
siz[u] = 1;
for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
int v = G1.edge[i].to;
if(v == fath) continue;
dfs1(v, u);
siz[u]+=siz[v];
if(siz[son[u]] < siz[v]) son[u] = v;
}
}
void dfs2(int u, int Top){
dfn[u] = ++totd;
top[u] = Top;
if(!son[u]) return;
dfs2(son[u], Top);
for(int i = G1.head[u]; i; i = G1.edge[i].nxt){
int v = G1.edge[i].to;
if(!dfn[v]) dfs2(v, v);
}
}
inline int LCA(int x, int y){
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
return x;
}
}th;
int K;
int stk[N], top;
int p[N];
bool is_tar[N];
bool cmp(int x, int y){
return dfn[x] < dfn[y];
}
void build(){
sort(p+1, p+K+1, cmp);
top = 0;
stk[++top] = 1;
G2.head[1] = 0;
for(int i = 1; i<=K; ++i){
if(p[i] == 1) continue;
int lca = th.LCA(stk[top], p[i]);
if(lca != stk[top]){
while(dfn[lca] < dfn[stk[top-1]]){
G2.add(stk[top-1], stk[top]);
--top;
}
if(dfn[lca] > dfn[stk[top-1]]){
G2.head[lca] = 0;
G2.add(lca, stk[top]), stk[top] = lca;
} else{
G2.add(lca, stk[top]);
--top;
}
}
G2.head[p[i]] = 0;
stk[++top] = p[i];
}
for(int i = 1; i<top; ++i){
G2.add(stk[i], stk[i+1]);
}
}
int fmn[N], fmx[N]; long long fsum[N];
int mn, mx;
long long sum;
void dfs_ans(int u, int fath){
int firmn = INF, secmn = INF;
fmn[u] = INF, fmx[u] = 0;
int firmx = 0, secmx = 0;
fsum[u] = 0;
if(is_tar[u]){
fsum[u] = 1;
}
for(int i = G2.head[u]; i; i = G2.edge[i].nxt){
int v = G2.edge[i].to;
if(v == fath) continue;
dfs_ans(v, u);
if(is_tar[v]){
fmn[u] = min(fmn[u], dep[v]-dep[u]);
if(dep[v]-dep[u] < firmn){
secmn = firmn;
firmn = dep[v]-dep[u];
} else if(dep[v]-dep[u]<secmn){
secmn = dep[v]-dep[u];
}
} else{
fmn[u] = min(fmn[v]+dep[v]-dep[u], fmn[u]);
if(fmn[v]+dep[v]-dep[u] < firmn){
secmn = firmn;
firmn = fmn[v]+dep[v]-dep[u];
} else if(fmn[v]+dep[v]-dep[u]<secmn){
secmn = fmn[v]+dep[v]-dep[u];
}
}
fmx[u] = max(fmx[u], fmx[v]+dep[v]-dep[u]);
if(fmx[v]+dep[v]-dep[u] > firmx){
secmx = firmx;
firmx = fmx[v]+dep[v]-dep[u];
} else if(fmx[v]+dep[v]-dep[u] > secmx){
secmx = fmx[v]+dep[v]-dep[u];
}
fsum[u]+=fsum[v];
sum+=(fsum[v]*(K-fsum[v])*(dep[v]-dep[u]));
}
if(is_tar[u]){
mn = min(mn, fmn[u]);
} else{
mn = min(mn, secmn+firmn);
}
if(secmx){
mx = max(mx, firmx+secmx);
} else if(is_tar[u]){
mx = max(fmx[u], mx);
}
}
int n, Q;
int main(){
n = read();
for(int i = 1; i<n; ++i){
int u = read(), v = read();
G1.add(u, v);
G1.add(v, u);
}
th.dfs1(1, 0);
th.dfs2(1, 1);
Q = read();
while(Q--){
K = read();
G2.tot = 0;
for(int i = 1; i<=K; ++i){
p[i] = read();
is_tar[p[i]] = 1;
}
build();
sum = mx = 0, mn = INF;
dfs_ans(1, 0);
printf("%lld %d %d\n", sum, mn, mx);
for(int i = 1; i<=K; ++i){
is_tar[p[i]] = 0;
}
}
return 0;
}