Solution
做法似乎和其他的题解不太一样。
我们考虑统计对于每一个点 \(u\) 对于它的颜色 \(c_u\) 的贡献。
我们将所有经过点 \(u\) 的简单路径进行分类:
-
该路径只包含点 \(u\)。
-
该路径的一个端点为点 \(u\),另外一个点是点 \(u\) 子树中的一个节点。
-
该路径的两个端点分别在 \(u\) 的两个不同的子树中。
-
该路径的一个端点在 \(u\) 以及其子树中,而另外一个端点在它祖先以及其子树中,但又不是 \(u\),且不在 \(u\) 的子树中。
我们接下来对这几种路径分别进行求解。
我们假设当前点为 \(u\),它的子节点为 \(v\),求当前点对 \(c_u\) 的贡献。
-
该路径只包含点 \(u\),很显然只有一种。
-
由于一个端点是 \(u\),另一个是在他的子树中,答案显然为 \(\sum size_v\)。
-
由于该路径的两个端点分别在 \(u\) 的两个不同的子树中,对于每一个 \(v\) 的贡献为 $size_v \times (size_u - size_v - 1) $,由于每一条路径会被统计两次,总的贡献为 \(\frac{\sum size_v \times (size_u - size_v - 1)}{2}\)。
Code
#include <iostream>
#include <algorithm>
#define ll long long
#define PLL pair<ll, ll>
#define inl inline
using namespace std;
const int N = 2e5;
ll n, k, ans[N + 5];
ll c[N + 5], siz[N + 5], mp[N + 5], pre[N + 5], cnt[N + 5], d[N + 5], s[N + 5];
int head[N + 5], ver[N * 2 + 5], nxt[N * 2 + 5], tot;
inl void add(int u, int v) { nxt[++tot] = head[u], head[u] = tot, ver[tot] = v; }
inl ll calc(ll x, ll y) { return !y ? s[x] : d[y]; }
inl void upd(ll x, ll y, ll w) { !y ? s[x] += w : d[y] += w; }
template <class _Tp>
inl void read(_Tp &x) {
x = 0;
char ch = getchar();
while (ch < '0' || ch > '9') ch = getchar();
while ('0' <= ch && ch <= '9') x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
}
inl void print(ll x) {
if (x > 9) print(x / 10);
putchar(x % 10 + '0');
}
inl void dfs1(int u, int fa) {
ll cur = mp[c[u]];
siz[u] = 1, pre[u] = mp[c[u]];
for (int i = head[u]; i; i = nxt[i]) {
int v = ver[i];
if (v != fa) mp[c[u]] = v, dfs1(v, u), siz[u] += siz[v];
}
mp[c[u]] = cur;
}
inl void dfs2(int u, int fa) {
ll res = 0;
for (int i = head[u]; i; i = nxt[i]) {
int v = ver[i];
if (v != fa) dfs2(v, u), ans[c[u]] += siz[v], res += siz[v] * (siz[u] - siz[v] - 1);
}
ans[c[u]] += res / 2, ans[c[u]] += (siz[pre[u]] - siz[u] - calc(c[u], pre[u])) * siz[u] + 1, upd(c[u], pre[u], siz[u]);
}
int main() {
read(n);
for (int i = 1; i <= n; i++) read(c[i]);
for (int i = 1, u, v; i < n; i++) read(u), read(v), add(u, v), add(v, u);
dfs1(1, 0), siz[0] = n, dfs2(1, 0);
for (int i = 1; i <= n; i++, putchar('\n')) print(ans[i]);
return 0;
}
标签:int,siz,ll,路径,inl,ff,size
From: https://www.cnblogs.com/zhouziyi/p/17205999.html