神题!!!!111
考虑如何不重不漏地计数。先考虑全为 1
的情况,令 \(f(u,d)\) 为与 \(u\) 的距离 \(\le d\) 的点集。
首先单独算全集,那么对于不是全集的集合就会有一些比较好的性质。
考虑若有若干个 \(f(u,d)\) 同构,那 只在 \(d\) 最小的时候计数。
那么 \(f(u,d)\) 需要满足不能覆盖全集,且不存在与 \(u\) 相邻的点 \(v\),使得 \(f(u,d) = f(v,d-1)\)(由于 \(d\) 最小的约束)。
考虑若存在后者时发生了什么。把 \(v\) 这棵子树抠掉之后,剩下的点与 \(u\) 距离 \(\le d - 2\)。
令 \(f_u\) 为以 \(u\) 为根的子树最大深度,\(g_u\) 为以 \(u\) 为根的子树次大深度(不存在则为 \(0\)),\(d_u\) 为 \(f(u,d)\) 最大能取到的 \(d\),则等价于 \(d_u < \min(f_u,g_u+2)\)。
换根求出 \(f_u,g_u\) 即可。于是我们就做完了全为 1
的情况。
现在有一些点是 0
。但是我们发现不能完全不考虑它们,因为我们发现有些 1
点的 \(d_u\) 上界过于严苛导致有些情况没有考虑到,那我们将这些情况放到 0
点计算。
发现 0
点的上界仍然可以取到,但是下界并非 \(0\)。设任意 0
点为 \(u\),则未算到的情况满足 1
点所在子树中全被覆盖,并且还可能覆盖了别的子树。设 \(h_u\) 为以 \(u\) 为根的存在 1
点的子树的最大深度,则对于 0
点,\(h_u \le d_u < \min(f_u,g_u+2)\)。
此时的 \(h_u\) 仍然可以换根求出,于是我们就以 \(O(n)\) 的时空复杂度做完了。
code
// Problem: F - Black Radius
// Contest: AtCoder - AtCoder Grand Contest 008
// URL: https://atcoder.jp/contests/agc008/tasks/agc008_f
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 200100;
const int inf = 0x3f3f3f3f;
int n, head[maxn], len, a[maxn], sz[maxn], f[maxn], g[maxn], h[maxn];
ll ans = 1;
struct edge {
int to, next;
} edges[maxn << 1];
void add_edge(int u, int v) {
edges[++len].to = v;
edges[len].next = head[u];
head[u] = len;
}
void dfs(int u, int fa) {
if (a[u]) {
sz[u] = 1;
} else {
h[u] = inf;
}
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to;
if (v == fa) {
continue;
}
dfs(v, u);
sz[u] += sz[v];
int val = f[v] + 1;
if (val > f[u]) {
g[u] = f[u];
f[u] = val;
} else if (val > g[u]) {
g[u] = val;
}
if (sz[v]) {
h[u] = min(h[u], f[v] + 1);
}
}
}
void dfs2(int u, int fa) {
ans += max(0, min(f[u], g[u] + 2) - h[u]);
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to;
if (v == fa) {
continue;
}
int val = (f[u] == f[v] + 1) ? g[u] + 1 : f[u] + 1;
if (val > f[v]) {
g[v] = f[v];
f[v] = val;
} else if (val > g[v]) {
g[v] = val;
}
if (sz[v] != sz[1]) {
h[v] = min(h[v], val);
}
dfs2(v, u);
}
}
void solve() {
scanf("%d", &n);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
add_edge(u, v);
add_edge(v, u);
}
for (int i = 1; i <= n; ++i) {
scanf("%1d", &a[i]);
}
dfs(1, -1);
dfs2(1, -1);
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}