感觉我的做法比较奇葩(
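容斥,答案等于总路径数减去不经过任何点权为 \(k\) 的点的路径数(单个点也算一条路径)。设点权为 \(k\) 的点数为 \(c_k\),删去这些点后剩下的点构成若干连通块,第 \(i\) 个连通块大小为 \(s_i\)。总路径数为 \(\frac{n(n+1)}{2}\),不经过 \(k\) 的路径数为 \(\sum \frac{s_i(s_i+1)}{2}\),又因为 \(\sum s_i = n - c_k\),所以 \(ans_k = \frac{n(n-1)}{2} - \sum \frac{s_i (s_i - 1)}{2} + c_k\)。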
为了对每个 \(k\) 快速计算 \(\sum \frac{s_i (s_i - 1)}{2}\),以点权 \(k\) 为下标做线段树分治:每条边 \((u,v)\) 当且仅当 \(k \ne a_u \land k \ne a_v\) 时有用,这样的 \(k\) 构成至多三个区间,把这条边插入这些区间在线段树上对应的结点,然后用可撤销并查集(按秩合并、不做路径压缩)维护即可。
复杂度 \(O(n \log^2 n)\)。
code
// Problem: F - path pass i
// Contest: AtCoder - AtCoder Beginner Contest 163
// URL: https://atcoder.jp/contests/abc163/tasks/abc163_f
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
// Note: despite the name, pii is a pair of long longs (used here to store edges).
typedef pair<ll, ll> pii;
// Array bound; presumably chosen to cover N <= 2*10^5 from the problem statement — TODO confirm.
const int maxn = 200100;
// n: number of vertices; a[i]: weight (colour) of vertex i; c[k]: number of vertices with weight k.
// fa / rnk / sz: rollback-able disjoint-set union (parent, union-by-rank rank, component size).
// ans: n(n-1)/2 minus the sum of C(size, 2) over the current DSU components.
// b[k]: value of ans recorded when the divide-and-conquer reaches the leaf for weight k.
// top: index of the last used slot of stk. cnt: not used to compute the printed answer.
ll n, a[maxn], fa[maxn], rnk[maxn], cnt, top, ans, sz[maxn], b[maxn], c[maxn];
// Rollback log of (address, previous value) records. Each successful merge() logs only a
// handful of records and at most n-1 merges can be live at once (each one reduces the
// component count), so 10 * maxn slots leave ample headroom.
pair<ll*, ll> stk[maxn * 10];
// tree[rt]: edges attached to segment-tree node rt (4 * maxn nodes for a range of size <= maxn).
vector<pii> tree[maxn << 2];
// Return the DSU root of x by following parent links. Deliberately no path
// compression: merge() must stay undoable by restoring a single parent pointer.
int find(int x) {
	while (fa[x] != x) {
		x = fa[x];
	}
	return x;
}
// Union the components containing x and y (union by rank, no path compression)
// and keep ans equal to n(n-1)/2 minus the sum of C(size, 2) over components.
// Every global that is overwritten is first logged on stk so undo() can restore it.
// Does nothing (and logs nothing) when x and y are already in the same component.
inline void merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x == y) {
		return;
	}
	// Normalize so that x is the root with the smaller-or-equal rank; x is then
	// hung below y. This reproduces the tie-break "rnk[x] <= rnk[y] => x under y".
	if (rnk[x] > rnk[y]) {
		std::swap(x, y);
	}
	stk[++top] = std::make_pair(&ans, ans);
	// C(a+b, 2) - C(a, 2) - C(b, 2) == a*b: the vertex pairs newly joined together.
	ans -= sz[x] * sz[y];
	stk[++top] = std::make_pair(fa + x, fa[x]);
	fa[x] = y;
	stk[++top] = std::make_pair(sz + y, sz[y]);
	sz[y] += sz[x];
	// Only equal ranks can increase the height of the merged tree.
	if (rnk[x] == rnk[y]) {
		stk[++top] = std::make_pair(rnk + y, rnk[y]);
		++rnk[y];
	}
}
// Pop the newest (address, old value) record off stk and write the old value
// back, reverting exactly one assignment previously logged by merge().
inline void undo() {
	auto &rec = stk[top--];
	*rec.first = rec.second;
}
// Attach edge x to every canonical segment-tree node (node rt covers [l, r])
// whose range lies inside the weight interval [ql, qr]. Empty intervals
// (ql > qr) and intervals that miss [l, r] entirely are ignored.
void update(int rt, int l, int r, int ql, int qr, pii x) {
	// The disjointness checks also stop the recursion from running past the
	// leaves if a caller ever passes a range outside the tree's domain.
	if (ql > qr || qr < l || r < ql) {
		return;
	}
	if (ql <= l && r <= qr) {
		tree[rt].pb(x);
		return;
	}
	const int mid = (l + r) >> 1;
	// Children that do not intersect [ql, qr] return immediately via the guard.
	update(rt << 1, l, mid, ql, qr, x);
	update(rt << 1 | 1, mid + 1, r, ql, qr, x);
}
void dfs(int rt, int l, int r) {
int lsttop = top;
for (pii p : tree[rt]) {
merge(p.fst, p.scd);
}
if (l == r) {
b[l] = ans;
} else {
int mid = (l + r) >> 1;
dfs(rt << 1, l, mid);
dfs(rt << 1 | 1, mid + 1, r);
}
while (top > lsttop) {
undo();
}
}
void solve() {
scanf("%lld", &n);
ans = n * (n - 1) / 2;
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
++c[a[i]];
fa[i] = i;
sz[i] = 1;
rnk[i] = 1;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
if (a[u] > a[v]) {
swap(u, v);
}
pii e = make_pair(u, v);
update(1, 1, n, 1, a[u] - 1, e);
update(1, 1, n, a[u] + 1, a[v] - 1, e);
update(1, 1, n, a[v] + 1, n, e);
}
dfs(1, 1, n);
for (int i = 1; i <= n; ++i) {
printf("%lld\n", b[i] + c[i]);
}
}
// Entry point: a single test case (multi-test input reading is disabled).
int main() {
	int T = 1;
	// scanf("%d", &T);
	for (; T > 0; --T) {
		solve();
	}
	return 0;
}