小红的树上路径查询(hard)
题目描述
本题和 $hard$ 难度的区别是,询问的次数有多次!
小红拿到了一棵树,她有多次询问,每次询问输入一条简单路径 $x,y$,她想知道树上所有节点到该路径的最短路之和是多少,你能帮帮她吗?
定义节点到路径的最短路为:节点到路径上所有点的最短路中,值最小的那个。特殊的,如果节点在路径上,则最短路为 $0$。
简单路径:从树上的一个节点出发,沿着树的边走,不重复地经过树上的节点,到达另一个节点的路径。
输入描述:
第一行输入两个正整数 $n,q$,代表节点数量和询问次数。
接下来的 $n−1$ 行,每行输入两个正整数 $u,v$,代表节点 $u$ 和节点 $v$ 有一条边连接。
接下来的 $q$ 行,每行输入两个正整数 $x,y$,代表一次询问。
$1 \leq n,q \leq 10^5$
$1 \leq u,v,x,y \leq n$
输出描述:
输出 $q$ 行,每行输出一个整数,代表询问的答案。
示例1
输入
4 2
1 2
1 3
1 4
2 3
2 1
输出
1
2
解题思路
不会,直接参考的题解。
对于 $x$ 与 $y$ 构成的 $xy$ 路径外一点 $v$,假设 $v$ 到 $xy$ 路径上最近的点是 $u$,那么 $v$ 到 $x$ 或 $v$ 到 $y$ 的路径必定包含点 $u$(假设以 $v$ 为根进行 bfs,那么第一次遍历到 $xy$ 路径上的点就是 $u$,继而从 $u$ 扩展到 $xy$ 路径上的其他点)。这点其实还是很难想到的。
然后求 $v$ 分别到 $x$ 和 $y$ 距离之和。有
\begin{align*}
&d(v,x) + d(v,y) \\
= &d(v,u) + d(u,x) + d(v,u) + d(u,y) \\
= &2d(v,u) + d(u,x) + d(u,y) \\
= &2d(v,u) + d(x,y)
\end{align*}
即有 $d(v,x) + d(v,y) = 2d(v,u) + d(x,y) \Rightarrow d(v,u) = \frac{d(v,x) + d(v,y) - d(x,y)}{2}$。这条式子有什么用呢?实际上我们关心的是所有点到 $u$(即该点到 $xy$ 路径最近的点)的距离的和,并不需要求出具体的 $u$。同时可以发现如果 $v$ 也是 $xy$ 路径上的点上式同样成立。因此所有点关于 $d(v,u)$ 的和就是
\begin{align*}
&\sum\limits_{v=1}^{n}{d(v,u)} \\
=&\frac{1}{2}\sum\limits_{v=1}^{n}{d(v,x) + d(v,y) - d(x,y)} \\
=&\frac{1}{2}\left(\sum\limits_{v=1}^{n}{d(v,x)} + \sum\limits_{v=1}^{n}{d(v,y)} - n \cdot d(x,y)\right) \\
\end{align*}
其中 $\sum\limits_{v=1}^{n}{d(v,x)}$ 和 $\sum\limits_{v=1}^{n}{d(v,y)}$ 分别是 $x$ 和 $y$ 到所有点的距离总和,这个可以用换根 dp 求得。$d(x,y)$ 可以分别求出 $x$ 和 $y$ 到最近公共祖先的距离再求和。
下面简单讲一下如何求所有点到 $u$ 的距离总和。固定 $1$ 为根,定义 $f(u)$ 表示子树 $u$ 中所有点到 $u$ 的距离总和,$g(u)$ 表示从 $u$ 往上走的所有点(其实就是除了子树 $u$ 外的点)到 $u$ 的距离总和。那么状态转移方程就是
$$f(u) = \sum\limits_{v \in \text{son}(u)}{f(v) + \text{sz}_v}$$
$$g(u) = g(p_u) + f(p_u) - (f(u) + \text{sz}_u) + (n - \text{sz}_u)$$
其中 $\text{sz}_u$ 表示子树 $u$ 的大小,$p_u$ 表示 $u$ 的父节点。那么所有点到 $u$ 的距离总和就是 $f(u)+g(u)$。
AC 代码如下,时间复杂度为 $O((n+q)\log{n})$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 5, M = N * 2;
int n, m;
int h[N], e[M], ne[M], idx;
LL sz[N], f[N], g[N];
int fa[N][17], d[N];
void add(int u, int v) {
e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}
void dfs1(int u, int p) {
sz[u] = 1;
d[u] = d[p] + 1;
fa[u][0] = p;
for (int i = 1; i <= 16; i++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
}
for (int i = h[u]; i != -1; i = ne[i]) {
int v = e[i];
if (v == p) continue;
dfs1(v, u);
sz[u] += sz[v];
f[u] += f[v] + sz[v];
}
}
void dfs2(int u, int p) {
for (int i = h[u]; i != -1; i = ne[i]) {
int v = e[i];
if (v == p) continue;
g[v] = g[u] + f[u] - (f[v] + sz[v]) + n - sz[v];
dfs2(v, u);
}
}
int lca(int a, int b) {
if (d[a] < d[b]) swap(a, b);
for (int i = 16; i >= 0; i--) {
if (d[fa[a][i]] >= d[b]) a = fa[a][i];
}
if (a == b) return a;
for (int i = 16; i >= 0; i--) {
if (fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
}
return fa[a][0];
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m;
memset(h, -1, sizeof(h));
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
add(u, v), add(v, u);
}
dfs1(1, 0);
dfs2(1, 0);
while (m--) {
int x, y;
cin >> x >> y;
cout << (f[x] + g[x] + f[y] + g[y] - n * (d[x] + d[y] - 2 * d[lca(x, y)])) / 2 << '\n';
}
return 0;
}
参考资料
牛客周赛64题解:https://ac.nowcoder.com/discuss/1421788
标签:路径,limits,int,sum,hard,小红,查询,fa,节点 From: https://www.cnblogs.com/onlyblues/p/18498311