换根 dp 一般不会指定根节点,并且根节点的变化会对一些值进行改变。因此我们需要转移根。
换根 dp一般需要预处理一下一个节点的值,然后对于任意节点开始树上dp转移。
所以我们常用两次 dfs,第一次 dfs预处理,第二次 dfs为树上 dp。
一般比较套路。
接下来会给出一个典型例题。
典例1:Luogu P3478
题目链接:Luogu P3478
Solution
我们发现本题没有给定 root,而且 root 之间的转移会影响每个节点到根的简单路径上的边的数量。
那么这种变化之间有什么关联呢?
我们发现对于一条边 \(u-v\) ,其中 \(v\) 是儿子。如果从 \(u\) 到 \(v\),那么
\(v\) 及 \(v\) 的儿子深度都会 \(-1\),反之 \(v\) 上面的节点深度都会 \(+1\)。
这就是转移式!借鉴先前求树的重心经验,对于 \(v\) 上面的部分,用 \(n-num_v-1\) 即可。
我们需要预处理每个节点的深度,以及每个节点下面有多少个儿子。然后转移即可。
形式化地,设 \(num_k\) 表示以 \(k\) 为父节点,其下面的儿子个数。\(dep_k\) 表示 \(k\) 的深度。则有:
\(f_k=f_{now}-num_v+(n-num_v)=f_{now}-2\times num_v+n\)(\(now-v\) 表示一条边)
换根 dp 一般转移直接转移,因为显然当 root 确定时答案是显然确定的!最后求以哪个节点作为 root 最大即可。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;
const int N = 1000100;
vector <int> Edge[N];
int n;
int dep[N],sz[N];
int f[N];
void dfs1(int now,int fa)
{
int u = now;
sz[u] = 1;
dep[u] = dep[fa] + 1;
for(int i=0;i<Edge[now].size();i++)
{
int v = Edge[now][i];
if(v == fa) continue;
dfs1(v,now);
sz[now] += sz[v];
}
}
void dfs2(int now,int fa)
{
for(int i=0;i<Edge[now].size();i++)
{
int v = Edge[now][i];
if(v == fa) continue;
f[v] = f[now]-2*sz[v]+n;
dfs2(v,now);
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
Edge[u].push_back(v);
Edge[v].push_back(u);
}
dfs1(1,-1);
dfs2(1,-1);
int maxn = -1,maxn_num;
for(int i=1;i<=n;i++)
{
if(f[i] > maxn)
{
maxn = f[i];
maxn_num = i;
}
}
cout<<maxn_num<<endl;
return 0;
}
典例2:求树的中心
原题链接https://www.luogu.com.cn/problem/U262945
显然我们可以直接求直径
这里主要介绍换根 dp 的思路,个人认为换根 dp 思考非常自然。
设 \(f_i\) 表示以 \(i\) 为树的中心的答案。我们想到对于 \(i\),她要找距离其他节点最远的距离,是不是可以向下找和向上找?向下找非常容易,dfs 预处理即可。向上找好像无法直接搜...
我们可以转化,对于一条边 \(u-v\) ,其中 \(u\) 是父亲。是不是可以用 \(u\) 向下所能走到的最远距离来更新?
需要格外注意的是,由于我们都是简单路径,不能走回头路。所以如果 \(u\) 向下最大能走到 \(v\) 。那我们显然不能先走到 \(u\),再走回 \(v\)。所以用次大值更新。
形式化地,状态转移如下:
\(up_{v}=max(up_u,d1_u)+w(d1_u \ne v)\)
\(up_{v}=max(up_u,d2_u)+w(d1_u = v)\)
实现
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 100010;
const int INF = 0x3f3f3f3f;
typedef pair<int,int> PAIR;
int n;
vector <PAIR> Edge[N];
int d1[N],d2[N],s1[N],s2[N];
int up[N];
void dfs1(int now,int fa)
{
for(int i=0;i<Edge[now].size();i++)
{
int v = Edge[now][i].first,w = Edge[now][i].second;
if(v == fa) continue;
dfs1(v,now);
if(d1[v] + w >= d1[now])
{
d2[now] = d1[now];
s2[now] = s1[now];
d1[now] = d1[v] + w;
s1[now] = v;
}
else if(d1[v] + w > d2[now])
{
d2[now] = d1[v] + w;
s2[now] = v;
}
}
}
void dfs2(int now,int fa)
{
for(int i=0;i<Edge[now].size();i++)
{
int v = Edge[now][i].first,w = Edge[now][i].second;
if(v == fa) continue;
int u = now;
if(s1[now] == v)
{
up[v] = w + max(d2[u],up[u]);
}
else up[v] = w + max(d1[u],up[u]);
dfs2(v,now);
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n;
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
Edge[a].push_back(PAIR(b,c));
Edge[b].push_back(PAIR(a,c));
}
dfs1(1,-1);
dfs2(1,-1);
int res = INF;
for(int i=1;i<=n;i++) res = min(res,max(up[i],d1[i]));
cout<<res<<endl;
return 0;
}