思路
根据题意,极远点对实际上指的是树的直径,而树的直径有一条重要的性质是树的所有直径中点重合(OI Wiki)。
因此先求出树的直径中点 \(root\) ,并固定根为 \(root\),此时树的直径的端点均是叶子,且树的直径的两个端点在根的不同子树上。
设 \(u\) 为要求解的节点, \(fa\) 为 \(u\) 的祖先且 \(fa\) 为 \(root\) 的儿子, \(sum_i\) 为以 \(i\) 为根的子树内直径的端点数,\(f_i\) 为经过 \(u\) 的直径数。
有
\[f_u = (sum_{root} - sum_{fa}) \times sum_{u} \]具体地,根据树的直径的长度的奇偶性来分别讨论。
当直径的长度是奇数时,直接按照上文的公式计算即可。
当直径的长度是偶数时,树的直径又分为两类,一类是到根的距离为奇数的端点,一类是到根的距离为偶数的端点,记 \(total0_i\) 为以 \(i\) 为根的子树内到根的距离为奇数的端点数 \(total1_i\) 为以 \(i\) 为根的子树内到根的距离为偶数的端点数.
有
\[f_u = (total0_{root} - total0_{fa}) \times total1_u + (total1_{root} - total1_{fa}) \times total0_u \]其中根结点要特殊计算,简单容斥一下就好。
代码
#include <bits/stdc++.h>
using namespace std;
inline int read(){
static int bo, x; bo = x = 0;
static char c; c = getchar();
while(c < '0' || c > '9') {if(c == '-')bo = 1; c = getchar();}
while(c >= '0' && c <= '9'){x = (x<<3) + (x<<1) + c - '0'; c = getchar();}
return bo ? -x : x;
}
const int N = 5e6+11;
const long long mod = 998244353;
int n, k, h[N], nt[N<<1], to[N<<1], cnt;
long long maxdist1, st, ed;long long mid;
void link(int u, int v){
nt[++cnt] = h[u], h[u] = cnt, to[cnt] = v;
swap(u, v);
nt[++cnt] = h[u], h[u] = cnt, to[cnt] = v;
}
void dfs1(int u, int fa, long long dep){
if(maxdist1 < dep){
maxdist1 = dep;
st = u;
}
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
dfs1(v, u, dep+1);
}
}
void dfs2(int u, int fa, long long dep){
if(fa == 0) maxdist1 = 0;
if(maxdist1 < dep){
maxdist1 = dep;
ed = u;
}
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
dfs2(v, u, dep+1);
}
}
long long dfs3(int u, int fa, long long dep){
long long md = 0;
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
md = max(dfs3(v, u, dep+1)+1, md);
}
if(abs(md-dep) <= 1) mid = u;
return md;
}
long long num[N];
namespace subtack1{ // maxdist1 % 2 == 0
long long siz[N], tot;
void Dfs1(int u, int fa, long long dep){
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
Dfs1(v, u, dep+1);
siz[u] += siz[v];
siz[u] %= mod;
}
if(dep == maxdist1/2) siz[u] = 1ll;
}
void Dfs2(int u, int fa, long long oth){
num[u] = (long long) oth * siz[u] % mod;
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
Dfs2(v, u, oth);
}
}
void dp(){
Dfs1(mid, 0, 0);
num[mid] = 0;
long long tot = 0;
for(int i = h[mid]; i; i = nt[i]){
int v = to[i];
Dfs2(v, mid, siz[mid] - siz[v]);
num[mid] += (long long)siz[v] * tot % mod;
num[mid] %= mod;
tot += siz[v];
tot %= mod;
}
}
}
namespace subtack2{
long long siz[N], tot1[N], tot0[N];
void Dfs1(int u, int fa, long long dep){
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
Dfs1(v, u, dep+1);
tot0[u] += tot0[v];
tot1[u] += tot1[v];
tot0[u] %= mod;
tot1[u] %= mod;
}
if(dep == maxdist1/2) tot0[u]++;
if(dep == maxdist1/2+1) tot1[u]++;
}
void Dfs2(int u, int fa, long long oth0, long long oth1){
num[u] = (long long)((long long)oth0 * tot1[u]%mod + (long long)oth1 * tot0[u]%mod) % mod;
for(int i = h[u]; i; i = nt[i]){
int v = to[i];
if(v == fa) continue;
Dfs2(v, u, oth0, oth1);
}
}
void dp(){
Dfs1(mid, 0, 0);
long long t0 = 0, t1 = 0;
for(int i = h[mid]; i; i = nt[i]){
int v = to[i];
Dfs2(v, mid, tot0[mid] - tot0[v], tot1[mid] - tot1[v]);
num[mid] += (long long)((long long)t0 * tot1[v] % mod + (long long)t1 * tot0[v] % mod) % mod;
t0 += tot0[v];
t1 += tot1[v];
t0 %= mod;
t1 %= mod;
}
}
}
int main(){
n = read();
k = read();
for(int i = 1; i < n; i ++){
static int u, v;
u = read(), v = read();
link(u, v);
}
dfs1(1, 0, 0), dfs2(st, 0, 0), dfs3(st, 0, 0);
if(maxdist1%2 == 0) subtack1::dp();
else subtack2::dp();
long long ans = 0;
for(int i = 1; i <= n; i ++){
ans = (ans + (long long)(k == 1 ? (long long)num[i] : (long long)num[i] * num[i]%mod) % mod) % mod;
}
printf("%lld\n", ans);
return 0;
}
标签:P8981,sum,total0,fa,DROI,端点,直径,root,Round
From: https://www.cnblogs.com/dadidididi/p/17087163.html