首先这个 \(k\) 的限制不是很好入手,考虑先从如果选取了区间 \([l, r]\) 来入手。
那么此时连通块最小的 \(siz\) 就是把这些点拎出来建虚树对应在原树上的所有点。
那么这个有个结论,考虑按 \(\operatorname{dfn}\) 序排序后的点为 \(p_0\sim p_{k - 1}\),那么对应的最小 \(sz\) 就是 \(\frac{\sum\limits_{i = 0}^{k - 1} \operatorname{dis}(p_i, p_{(i + 1)\bmod k})}{2} + 1\)。
考虑证明,首先考虑边的数量,那么一个边如果会在虚树中,肯定满足子树内有关键点子树外也有关键点,那么子树内和子树外的 \(\operatorname{dis}\) 都不会经过这条边,只有是这里面 \(\operatorname{dfn} = \min \operatorname{or} \max\) 时会走出去,那么就会被算上 \(2\) 次,\(\times \frac{1}{2}\) 即可。
那么点数就 \(=\) 边数 \(+ 1\)。
同时对于 \([l, r]\) 的答案,可以知道一定 \(\le\) \([l, r + 1]\) 的答案。
于是可以考虑用双指针扫过去。
中途 \(\operatorname{dfn}\) 排序的维护可以用 set
处理。
时间复杂度 \(\mathcal{O}(n\log n)\)。
代码
#include<bits/stdc++.h>
const int maxn = 1e5 + 10;
std::vector<int> son[maxn];
int fa[maxn], dep[maxn], siz[maxn];
void dfs1(int u) {
dep[u] = dep[fa[u]] + 1;
siz[u] = 1;
for (int &v : son[u]) {
fa[v] = u, son[v].erase(std::find(son[v].begin(), son[v].end(), u));
dfs1(v);
siz[u] += siz[v], siz[v] > siz[son[u][0]] && (std::swap(son[u][0], v), 1);
}
}
int dfn[maxn], dfp[maxn], top[maxn], dt;
void dfs2(int u) {
dfp[dfn[u] = ++dt] = u;
for (int v : son[u]) {
top[v] = v == son[u][0] ? top[u] : v;
dfs2(v);
}
}
inline int lca(int x, int y) {
while (top[x] != top[y]) {
if (dfn[x] < dfn[y]) std::swap(x, y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
inline int dis(int x, int y) {
return dep[x] + dep[y] - 2 * dep[lca(x, y)];
}
std::set<int> s;
int tot;
inline void add(int x, int v) {
if (v == 1) s.emplace(dfn[x]);
std::set<int>::iterator it = s.find(dfn[x]), it1 = it, it2 = it;
it2++;
int y, z;
if (it1 == s.begin()) y = dfp[*s.rbegin()];
else y = dfp[*--it1];
if (it2 == s.end()) z = dfp[*s.begin()];
else z = dfp[*it2];
tot += v * (dis(x, y) + dis(x, z) - dis(y, z));
if (v == -1) s.erase(it);
}
int main() {
int n, k;
scanf("%d%d", &n, &k);
for (int i = 1, x, y; i < n; i++) {
scanf("%d%d", &x, &y);
son[x].push_back(y), son[y].push_back(x);
}
dfs1(1), top[1] = 1, dfs2(1);
int ans = 0;
for (int i = 1, j = 1; i <= n; i++) {
add(i, 1);
while (tot / 2 + 1 > k) add(j++, -1);
ans = std::max(ans, i - j + 1);
}
printf("%d\n", ans);
return 0;
}