比较谔谔,为什么题解区都在群魔乱舞。不是有个很简单的点分树做法吗。
考虑建出点分树,由点分树的性质可得任意两点在点分树上的 LCA 一定在它们的路径上。然后每次暴力跳父亲,每个分治中心维护一个 \(f_i\) 表示距离 \(i\) 最近的红色点的距离即可。
若使用 dfn 序 st 表求 lca,时间复杂度为 \(O((n + m) \log n)\)。
有操作分块的做法,大概懂了。就是每 \(O(\sqrt{m})\) 次操作分块,遍历到块左端点时先预处理出所有点到红色点的最短距离,然后块内只会加入 \(O(\sqrt{m})\) 个红色点,这部分可以暴力枚举。总时间复杂度 \(O(n \sqrt{m})\)。
code
// Problem: E. Xenia and Tree
// Contest: Codeforces - Codeforces Round 199 (Div. 2)
// URL: https://codeforces.com/problemset/problem/342/E
// Memory Limit: 256 MB
// Time Limit: 5000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 100100;
const int logn = 20;
int n, m, fa[maxn], sz[maxn], f[maxn], rt;
int dep[maxn], dfn[maxn], tim, st[logn][maxn];
vector<int> G[maxn];
bool vis[maxn];
inline int get(int i, int j) {
return dfn[i] < dfn[j] ? i : j;
}
inline int qlca(int x, int y) {
if (x == y) {
return x;
}
x = dfn[x];
y = dfn[y];
if (x > y) {
swap(x, y);
}
++x;
int k = __lg(y - x + 1);
return get(st[k][x], st[k][y - (1 << k) + 1]);
}
inline int qdis(int x, int y) {
return dep[x] + dep[y] - dep[qlca(x, y)] * 2;
}
void dfs2(int u, int fa, int t) {
f[u] = 0;
sz[u] = 1;
for (int v : G[u]) {
if (v == fa || vis[v]) {
continue;
}
dfs2(v, u, t);
sz[u] += sz[v];
f[u] = max(f[u], sz[v]);
}
f[u] = max(f[u], t - sz[u]);
if (!rt || f[u] < f[rt]) {
rt = u;
}
}
void dfs(int u) {
vis[u] = 1;
for (int v : G[u]) {
if (vis[v]) {
continue;
}
rt = 0;
dfs2(v, -1, sz[v]);
dfs2(rt, -1, sz[v]);
fa[rt] = u;
dfs(rt);
}
}
void dfs3(int u, int fa) {
dfn[u] = ++tim;
st[0][tim] = fa;
dep[u] = dep[fa] + 1;
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs3(v, u);
}
}
void solve() {
scanf("%d%d", &n, &m);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
rt = 0;
dfs2(1, -1, n);
dfs2(rt, -1, n);
dfs(rt);
dfs3(1, 0);
mems(f, 0x3f);
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
st[j][i] = get(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
}
}
for (int i = 1; i; i = fa[i]) {
f[i] = min(f[i], qdis(1, i));
}
while (m--) {
int op, x;
scanf("%d%d", &op, &x);
if (op == 1) {
for (int i = x; i; i = fa[i]) {
f[i] = min(f[i], qdis(x, i));
}
} else {
int ans = 1e9;
for (int i = x; i; i = fa[i]) {
ans = min(ans, f[i] + qdis(x, i));
}
printf("%d\n", ans);
}
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}