题目描述
九条可怜是一个喜欢算法的女孩子,在众多算法中她尤其喜欢深度优先搜索(DFS)。
有一天,可怜得到了一棵有根树,树根为 \(\mathit{root}\),树上每个节点 \(x\) 有一个权值 \(a_x\)。
在一棵树上从 \(x\) 出发,寻找 \(y\) 节点,如果使用深度优先搜索,则可描述为以下演算过程:
- 将递归栈设置为空。
- 首先将节点 \(x\) 放入递归栈中。
- 从递归栈中取出栈顶节点,如果该节点为 \(y\),则结束演算过程;否则,如果存在未访问的直接子节点,则以均等概率随机选择一个子节点加入递归栈中。
- 重复步骤 3,直到不存在未访问的直接子节点。
- 将上一级节点加入递归栈中,重复步骤 3。
- 重复步骤 5,直至当前一级节点为 \(x\),演算过程结束。
我们定义 \(f(x, y)\) 合法当且仅当 \(y\) 在 \(x\) 的子树中。它的值为从 \(x\) 出发,对 \(x\) 的子树进行深度优先搜索寻找 \(y\) 期间访问过的所有节点(包括 \(x\) 和 \(y\))权值最小值的期望。
九条可怜想知道对于所有合法的点对 \((x, y)\),\(\sum f(x, y)\) 的值。你只需要输出答案对 \(998244353\) 取模的结果。具体地,如果答案的最简分数表示为 \(\frac{a}{b}\),输出 \(a \times b^{-1} \bmod 998244353\)。
提示
对于所有测试点,满足 \(1 \le T \le 100\),\(\sum n \le 8 \times {10}^5\),\(1 \le n \le 4 \times {10}^5\),\(1 \le \mathit{root}, u, v \le n\),\(1 \le a_i \le {10}^9\)。
每个测试点的具体限制见下表:
测试点编号 | \(\sum n \le\) | \(n \le\) | 特殊限制 |
---|---|---|---|
\(1\) | \(50\) | \(10\) | 无 |
\(2 \sim 4\) | \(40000\) | \(5000\) | 无 |
\(5 \sim 10\) | \(4 \times {10}^5\) | \({10}^5\) | 无 |
\(11\) | \(8 \times {10}^5\) | \(4 \times {10}^5\) | 树的生成方式随机 |
\(12\) | \(8 \times {10}^5\) | \(4 \times {10}^5\) | 树是一条链 |
\(13\) | \(8 \times {10}^5\) | \(4 \times {10}^5\) | 根的度数为 \(n - 1\) |
\(14 \sim 20\) | \(8 \times {10}^5\) | \(4 \times {10}^5\) | 无 |
对于测试点 \(11\),树的生成方式为:以 \(1\) 为根,对于节点 \(i \in [2, n]\),从 \([1, i - 1]\) 中等概率随机选择一个点作为父亲。之后将编号随机重排。
题解
默认根的父亲为 \(0\)。
算法1
对每个权值 \(val\) 计算 \(P_{val}\) 表示权值最小值 \(>=val\) 对答案的贡献,最后对 \(P\) 差分,答案即为 \(\sum a_xP_{a_x}\)。
首先离散化权值,从小到大枚举权值 \(min\_val\) 并计算 \(P_{min\_val}\)。
以下是计算 \(P_{min\_val}\)的过程。
对于点 \(i\),若其权值 \(a_i>=min\_val\),则其为黑点, \(color_i=1\),若其权值 \(a_i<min\_val\),则其为白点,\(color_i=0\)。
设 \(f_x\) 表示从点 \(x\) 开始深搜的答案,则 \(P_{min\_val}=\sum f_x\),因此我们再设 \(g_x=\sum\limits_{y\in subtree(x)}f_y\) , \(P_{min\_val}=g_{root}\)。
\(g\) 的转移是容易的,若我们能计算出 \(f\),则 \(g_x=f_x+\sum\limits_{y \in son(x)}g_y\)。
\(f\) 如何转移?
考虑 \(dfs\) 的过程,不难发现若路径合法,则路径上的点全是黑点,转移分类讨论即可。
具体地说,我们记 \(tag_x=[以x为根的子树全为黑点],cnt_x=\sum\limits_{y\in son(x)}(1-tag_y)\),\(cnt_x\) 的定义是 \(x\) 的儿子中,有几个儿子,以它们为根的子树内的点非全黑。
每次随机一个子树走入,若我们在全黑子树内停止,相当于这个全黑子树被走入的时间要排在所有非全黑子树前,由于概率均等,所以概率为 \(\frac{cnt_x!}{(1+cnt_x)|}=\frac{1}{1+cnt_x}\)。
同理,若在非全黑子树内停止,相当于这个非全黑子树是最早被走入的非全黑子树,概率为 \(\frac{(cnt_x-1)!}{cnt_x!}=\frac{1}{cnt_x}\)。
那么 \(f\) 的转移就有了,对于\(f_x\),先判断其是不是黑点,若是白点,值为 \(0\) ,否则先加上 \(f(x,x)=1\) ,然后转移。
写成式子就是 \(f_x=color_x*(1+\sum\limits_{y\in son(x)}\frac{1}{cnt_x+tag_y}f_y)\)。
然后每次 \(min\_val\) 改变的时候,我们修改 \(color\) 后重新计算 \(tag,cnt,f,g\) 即可做到 \(O(n^2)\),可以获得 \(20\) 分。
代码
#include<bits/stdc++.h>
#define For(i,l,r) for(int i=(l);i<=(r);++i)
typedef long long ll;
const int mod=998244353;
const int N=400010;
using namespace std;
int n,root,tot;
int ver[N<<1],nxt[N<<1],head[N],b[N],cnt[N];
vector<int> id[N];
bool color[N],tag[N];
ll ans;
ll a[N],P[N],f[N],g[N],inv[N];
template<typename T1,typename T2>
void Add(T1 &a,T2 b){a+=b;if(a>=mod)a-=mod;return;}
template<typename T1,typename T2>
void Sub(T1 &a,T2 b){a-=b;if(a<0)a+=mod;return;}
void add(int x,int y)
{
ver[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}
void calc_tag(int x,int fa)
{
tag[x]=color[x];
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa)
{
calc_tag(y,x);
tag[x]&=tag[y];
}
}
}
void calc_cnt(int x,int fa)
{
cnt[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa)
{
calc_cnt(y,x);
cnt[x]+=(1-tag[y]);
}
}
}
void calc(int x,int fa)
{
f[x]=1;
g[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa)
{
calc(y,x);
ll tmp=inv[cnt[x]+tag[y]];
(tmp*=f[y])%=mod;
Add(f[x],tmp);
Add(g[x],g[y]);
}
}
(f[x]*=color[x])%=mod;
Add(g[x],f[x]);
}
void solve()
{
tot=0;
scanf("%d%d",&n,&root);
{
For(i,1,n)
head[i]=0;
For(i,1,n)
id[i].clear();
};
For(i,1,n)
{
scanf("%lld",&a[i]);
b[i]=a[i];
}
{
sort(a+1,a+n+1);
For(i,1,n)
b[i]=(lower_bound(a+1,a+n+1,b[i])-a);
For(i,1,n)
id[b[i]].push_back(i);
};
For(i,1,(n-1))
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
{
For(i,1,n)
color[i]=1;
};
{
For(min_val,1,n)
{
if(id[min_val].empty())
continue;
calc_tag(root,-1);
calc_cnt(root,-1);
calc(root,-1);
P[min_val]=g[root];
for(auto j:id[min_val])
color[j]=0;
}
};
{
For(i,1,(n-1))
Sub(P[i],P[i+1]);
};
{
ans=0;
For(i,1,n)
{
ll E_i=P[i];
(E_i*=a[i])%=mod;
Add(ans,E_i);
}
printf("%lld\n",ans);
};
return;
}
int main()
{
{
inv[1]=1;
For(i,2,(N-1))
{
inv[i]=mod;
Sub(inv[i],(mod/i));
(inv[i]*=inv[mod%i])%=mod;
}
};
int T;
scanf("%d",&T);
while(T--)
solve();
return 0;
}
算法2
我们注意到每次修改一个点的 \(color\) ,这个量的总变化次数是点数级别的,暴力去做即可。
可能影响的信息还有这个点的 \(tag\) 值,若这个点有父亲,可能会影响其父亲的 \(cnt\) 值,不可能每次重新计算。
但是我们注意到\(tag_x=1\) 的时间是一段前缀,具体地,\(tag_x\) 在以 \(x\) 为根的子树内的最小值对应的点变为白色后从 \(1\) 变成 \(0\) ,因此在最开始预处理子树最小值后暴力修改 \(tag_x\) 及其父亲的 \(cnt\) 值即可。
但是可能会影响的 \(f,g\) 值,都是这个点到根的一条链。
使用动态 \(dp\) 的套路,轻重链剖分,对每个点维护轻儿子相关信息和,对于每条重链的链顶额外维护子树信息和即可,具体地说,下文的 \(f\_light,g\_light\) 表示轻儿子相关信息和,\(f,g\) 定义不变,就是子树信息和,\(sum\_f\) 是辅助转移的数组。
设 \(sum\_f[x][0]=\sum\limits_{y是x轻儿子} f[y]*[tag[y]==0],sum\_f[x][1]=\sum\limits_{y是x轻儿子} f[y]*[tag[y]==1]\)。
以下 \(y\) 表示 \(x\) 的重儿子(若其存在。)。
那么 \(f\) 的转移就可以写成 \(f_x=color[x]*(1+\frac{1}{cnt[x]}sum\_f[x][0]+\frac{1}{cnt[x]+1}sum\_f[x][1]+\frac{1}{cnt[x]+tag[y]}f[y])\)。
设 \(f\_light[x]=color[x]*(1+\frac{1}{cnt[x]}sum\_f[x][0]+\frac{1}{cnt[x]+1}sum\_f[x][1])\),\(g\_light[x]=f\_light[x]+\sum\limits_{z是x轻儿子}g[z]\)。
则 \(f,g\) 的转移可以写成 \(f[x]=f\_light[x]+\frac{1}{cnt[x]+tag[y]}f[y],g[x]=g\_light[x]+\frac{1}{cnt[x]+tag[y]}f[y]\) 。
转移写成矩阵乘法形式后用线段树维护即可。
每次需要将点由黑变白,以及修改一些 \(tag_x,cnt_{fa_x}\),以及若 \(x\) 不是 \(fa_x\) 的重儿子,记得修改 \(sum\_f[fa_x][0],sum\_f[fa_x][1]\)。
时间复杂度 \(O(n\ log^2\ n)\),至少可以获得 \(55\) 分。
代码
#include<bits/stdc++.h>
#define For(i,l,r) for(int i=(l);i<=(r);++i)
typedef long long ll;
const int mod=998244353;
const int N=400010;
using namespace std;
int n,root,tot;
int a[N],ver[N<<1],nxt[N<<1],head[N],b[N],cnt[N];
vector<int> id[N],tim_e[N];
bool color[N],tag[N];
ll ans;
ll P[N],f[N],g[N],f_light[N],g_light[N],inv[N],sum_f[N][2];
template<typename T1,typename T2>
void Add(T1 &a,T2 b){a+=b;if(a>=mod)a-=mod;return;}
template<typename T1,typename T2>
void Sub(T1 &a,T2 b){a-=b;if(a<0)a+=mod;return;}
void add(int x,int y)
{
ver[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}
int subtree_min[N];
void calc_subtree_min(int x,int fa)
{
subtree_min[x]=b[x];
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa)
{
calc_subtree_min(y,x);
subtree_min[x]=min(subtree_min[x],subtree_min[y]);
}
}
}
int dep[N],fa[N],siz_e[N],L_size[N],son[N];
void dfs1(int x,int Fa)
{
fa[x]=Fa;
dep[x]=(dep[Fa]+1);
siz_e[x]=1;
son[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=Fa)
{
dfs1(y,x);
siz_e[x]+=siz_e[y];
if((son[x]==0) || (siz_e[y]>siz_e[son[x]]))
son[x]=y;
}
}
L_size[x]=siz_e[x];
if(son[x])
L_size[x]-=siz_e[son[x]];
}
int cnt_;
int dfn[N],rnk[N],top[N],en_d[N];
void dfs2(int x,int Top)
{
top[x]=Top;
en_d[Top]=x;
++cnt_;
dfn[x]=cnt_;
rnk[cnt_]=x;
if(son[x])
dfs2(son[x],Top);
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if((y!=fa[x]) && (y!=son[x]))
dfs2(y,y);
}
}
void get_f_light(int x)
{
f_light[x]=1;
ll tmp0=sum_f[x][0],tmp1=sum_f[x][1];
(tmp0*=inv[cnt[x]])%=mod;
(tmp1*=inv[cnt[x]+1])%=mod;
Add(f_light[x],tmp0);
Add(f_light[x],tmp1);
(f_light[x]*=color[x])%=mod;
return;
}
struct Matrix{ll a[3][3];}I;
Matrix operator * (Matrix mat1,Matrix mat2)
{
Matrix res;
For(i,0,2)
{
For(j,0,2)
res.a[i][j]=0;
}
For(i,0,2)
{
For(j,0,2)
{
For(k,0,2)
{
ll tmp=mat1.a[i][k];
(tmp*=mat2.a[k][j])%=mod;
Add(res.a[i][j],tmp);
}
}
}
return res;
}
struct node{Matrix prod;}tree[N<<2];
Matrix get_mat(int x)
{
Matrix res;
For(i,0,2)
{
For(j,0,2)
res.a[i][j]=0;
}
if(color[x])
{
res.a[0][0]=inv[cnt[x]+tag[son[x]]];
res.a[1][0]=inv[cnt[x]+tag[son[x]]];
}
get_f_light(x);
res.a[0][2]=f_light[x];
res.a[1][2]=g_light[x];
res.a[1][1]=1;
res.a[2][2]=1;
return res;
}
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
void pushup(int x){tree[x].prod=(tree[lson(x)].prod*tree[rson(x)].prod);return;}
void build(int x,int l,int r)
{
if(l==r)
{
tree[x].prod=get_mat(rnk[l]);
return;
}
int mid=((l+r)>>1);
build(lson(x),l,mid);
build(rson(x),(mid+1),r);
pushup(x);
}
Matrix query(int x,int l,int r,int L,int R)
{
if(L<=l && r<=R)
return tree[x].prod;
Matrix res=I;
int mid=((l+r)>>1);
if(L<=mid)
res=(res*query(lson(x),l,mid,L,R));
if((mid+1)<=R)
res=(res*query(rson(x),(mid+1),r,L,R));
return res;
}
void modify(int x,int l,int r,int pos)
{
if(l==r)
{
tree[x].prod=get_mat(rnk[pos]);
return;
}
int mid=((l+r)>>1);
if(pos<=mid)
modify(lson(x),l,mid,pos);
if((mid+1)<=pos)
modify(rson(x),(mid+1),r,pos);
pushup(x);
}
void calc_subtree(int x)
{
Matrix res=query(1,1,n,dfn[x],dfn[en_d[x]]);
f[x]=res.a[0][2];
g[x]=res.a[1][2];
return;
}
void modify(int x)
{
while(x)
{
int top_x=top[x],fa_top_x=fa[top_x];
if(fa_top_x)
{
Sub(g_light[fa_top_x],g[top_x]);
Sub(g_light[fa_top_x],f_light[fa_top_x]);
Sub(sum_f[fa_top_x][tag[top_x]],f[top_x]);
}
modify(1,1,n,dfn[x]);
calc_subtree(top_x);
if(fa_top_x)
{
Add(sum_f[fa_top_x][tag[top_x]],f[top_x]);
get_f_light(fa_top_x);
Add(g_light[fa_top_x],f_light[fa_top_x]);
Add(g_light[fa_top_x],g[top_x]);
}
x=fa[top[x]];
}
return;
}
void change_color(int x)
{
Sub(g_light[x],f_light[x]);
color[x]=0;
f_light[x]=0;
modify(x);
return;
}
void change(int x)
{
tag[x]=0;
int fa_x=fa[x];
if(fa_x==0)
return;
if(son[fa_x]!=x)
{
calc_subtree(x);
Sub(sum_f[fa_x][1],f[x]);
Add(sum_f[fa_x][0],f[x]);
}
Sub(g_light[fa_x],f_light[fa_x]);
++cnt[fa_x];
get_f_light(fa_x);
Add(g_light[fa_x],f_light[fa_x]);
modify(fa_x);
return;
}
void calc_init(int x)
{
f[x]=0;
g[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa[x])
{
calc_init(y);
ll tmp=inv[cnt[x]+tag[y]];
(tmp*=f[y])%=mod;
Add(f[x],tmp);
Add(g[x],g[y]);
if(y!=son[x])
{
ll tmp0=f[y],tmp1=f[y];
(tmp0*=(1-tag[y]))%=mod;
(tmp1*=tag[y])%=mod;
Add(sum_f[x][0],tmp0);
Add(sum_f[x][1],tmp1);
}
}
}
Add(f[x],1);
(f[x]*=color[x])%=mod;
Add(g[x],f[x]);
}
void calc_light()
{
For(x,1,n)
{
get_f_light(x);
g_light[x]=g[x];
if(son[x])
{
ll tmp=inv[cnt[x]+tag[son[x]]];
(tmp*=f[son[x]])%=mod;
Sub(g_light[x],tmp);
Sub(g_light[x],g[son[x]]);
}
}
return;
}
void calc()
{
calc_init(root);
calc_light();
return;
}
void solve()
{
scanf("%d%d",&n,&root);
{
tot=0;
cnt_=0;
For(i,1,n)
head[i]=0;
For(i,1,n)
id[i].clear();
For(i,1,n)
tim_e[i].clear();
For(i,1,n)
{
f[i]=0;
g[i]=0;
sum_f[i][0]=0;
sum_f[i][1]=0;
f_light[i]=0;
g_light[i]=0;
}
};
For(i,1,n)
{
scanf("%d",&a[i]);
b[i]=a[i];
}
{
sort(a+1,a+n+1);
For(i,1,n)
b[i]=(lower_bound(a+1,a+n+1,b[i])-a);
For(i,1,n)
id[b[i]].push_back(i);
};
For(i,1,(n-1))
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
{
dfs1(root,0);
dfs2(root,root);
calc_subtree_min(root,0);
For(i,1,n)
color[i]=1;
For(i,1,n)
tag[i]=1;
For(i,1,n)
cnt[i]=0;
For(i,1,n)
tim_e[subtree_min[i]].push_back(i);
};
{
calc();
build(1,1,n);
For(min_val,1,n)
{
if(id[min_val].empty())
continue;
P[min_val]=g[root];
for(auto j:id[min_val])
change_color(j);
for(auto j:tim_e[min_val])
change(j);
}
};
{
For(i,1,(n-1))
Sub(P[i],P[i+1]);
};
{
ans=0;
For(i,1,n)
{
ll E_i=P[i];
(E_i*=a[i])%=mod;
Add(ans,E_i);
}
printf("%lld\n",ans);
};
return;
}
int main()
{
For(i,0,2)
{
For(j,0,2)
I.a[i][j]=0;
}
For(i,0,2)
I.a[i][i]=1;
{
inv[1]=1;
For(i,2,(N-1))
{
inv[i]=mod;
Sub(inv[i],(mod/i));
(inv[i]*=inv[mod%i])%=mod;
}
};
int T;
scanf("%d",&T);
while(T--)
solve();
return 0;
}
算法3
将算法 \(2\) 的线段树改为全局平衡二叉树,即可做到 \(O(n\ log\ n)\),需要略微精细实现,比如矩阵乘法有个 \(27\) 的常数,但是我们注意到矩阵中只有 \(4\) 个位置的值会变,其他都是固定的,那么我们只维护这 \(4\) 个位置即可,将矩阵乘法手动计算后,这 \(4\) 个位置的值都可以快速求出,这是一个很大的优化。
可以获得 \(100\) 分。
代码
#include<bits/stdc++.h>
#define For(i,l,r) for(int i=(l);i<=(r);++i)
typedef long long ll;
const int mod=998244353;
const int N=400010;
using namespace std;
int n,root,tot;
int a[N],ver[N<<1],nxt[N<<1],head[N],b[N],cnt[N],rt[N];
vector<int> id[N],tim_e[N];
bool color[N],tag[N];
ll ans;
ll P[N],f[N],g[N],f_light[N],g_light[N],inv[N],sum_f[N][2];
template<typename T1,typename T2>
void Add(T1 &a,T2 b){a+=b;if(a>=mod)a-=mod;return;}
template<typename T1,typename T2>
void Sub(T1 &a,T2 b){a-=b;if(a<0)a+=mod;return;}
void add(int x,int y)
{
ver[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}
int subtree_min[N];
void calc_subtree_min(int x,int fa)
{
subtree_min[x]=b[x];
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa)
{
calc_subtree_min(y,x);
subtree_min[x]=min(subtree_min[x],subtree_min[y]);
}
}
}
int dep[N],fa[N],siz_e[N],L_size[N],son[N];
void dfs1(int x,int Fa)
{
fa[x]=Fa;
dep[x]=(dep[Fa]+1);
siz_e[x]=1;
son[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=Fa)
{
dfs1(y,x);
siz_e[x]+=siz_e[y];
if((son[x]==0) || (siz_e[y]>siz_e[son[x]]))
son[x]=y;
}
}
L_size[x]=siz_e[x];
if(son[x])
L_size[x]-=siz_e[son[x]];
}
int cnt_;
int dfn[N],rnk[N],top[N],en_d[N];
void dfs2(int x,int Top)
{
top[x]=Top;
en_d[Top]=x;
++cnt_;
dfn[x]=cnt_;
rnk[cnt_]=x;
if(son[x])
dfs2(son[x],Top);
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if((y!=fa[x]) && (y!=son[x]))
dfs2(y,y);
}
}
void get_f_light(int x)
{
f_light[x]=1;
ll tmp0=sum_f[x][0],tmp1=sum_f[x][1];
(tmp0*=inv[cnt[x]])%=mod;
(tmp1*=inv[cnt[x]+1])%=mod;
Add(f_light[x],tmp0);
Add(f_light[x],tmp1);
(f_light[x]*=color[x])%=mod;
return;
}
struct Matrix{ll a[3][3];}I;
Matrix operator * (Matrix mat1,Matrix mat2)
{
Matrix res;
res.a[0][0]=mat1.a[0][0];
(res.a[0][0]*=mat2.a[0][0])%=mod;
res.a[0][2]=mat1.a[0][0];
(res.a[0][2]*=mat2.a[0][2])%=mod;
Add(res.a[0][2],mat1.a[0][2]);
res.a[1][0]=mat1.a[1][0];
(res.a[1][0]*=mat2.a[0][0])%=mod;
Add(res.a[1][0],mat2.a[1][0]);
res.a[1][2]=mat1.a[1][0];
(res.a[1][2]*=mat2.a[0][2])%=mod;
Add(res.a[1][2],mat1.a[1][2]);
Add(res.a[1][2],mat2.a[1][2]);
return res;
}
Matrix get_mat(int x)
{
Matrix res;
if(color[x])
{
res.a[0][0]=inv[cnt[x]+tag[son[x]]];
res.a[1][0]=inv[cnt[x]+tag[son[x]]];
}
else
{
res.a[0][0]=0;
res.a[1][0]=0;
}
get_f_light(x);
res.a[0][2]=f_light[x];
res.a[1][2]=g_light[x];
return res;
}
int seq[N],weight[N];
#define lson(x) (tree[x].lson)
#define rson(x) (tree[x].rson)
struct node{int lson,rson,anc;Matrix prod,mat;}tree[N];
void pushup(int x){tree[x].prod=(tree[lson(x)].prod*tree[x].mat*tree[rson(x)].prod);return;}
int build_heavy_chain(int L,int R)
{
if(L>R)
return 0;
ll sum=0,sum_now=0;
For(i,L,R)
sum+=(1ll*L_size[i]);
For(i,L,R)
{
sum_now+=(1ll*L_size[i]);
if((sum_now*1ll*2)>sum)
{
int root=seq[i];
tree[root].mat=get_mat(root);
tree[root].lson=build_heavy_chain(L,(i-1));
tree[lson(root)].anc=root;
tree[root].rson=build_heavy_chain((i+1),R);
tree[rson(root)].anc=root;
pushup(root);
return root;
}
}
return 0;
}
void build(int Top)
{
for(int x=Top;x;x=son[x])
{
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if((y!=fa[x]) && (y!=son[x]))
build(y);
}
}
int num=0;
for(int x=Top;x;x=son[x])
{
++num;
seq[num]=x;
weight[num]=L_size[x];
}
rt[Top]=build_heavy_chain(1,num);
tree[rt[Top]].anc=0;
}
void update(int x)
{
tree[x].mat=get_mat(x);
for(;x;x=tree[x].anc)
pushup(x);
return;
}
void modify(int x)
{
while(x)
{
int top_x=top[x],fa_top_x=fa[top_x];
if(fa_top_x)
{
Sub(g_light[fa_top_x],tree[rt[top_x]].prod.a[1][2]);
Sub(g_light[fa_top_x],f_light[fa_top_x]);
Sub(sum_f[fa_top_x][tag[top_x]],tree[rt[top_x]].prod.a[0][2]);
}
update(x);
if(fa_top_x)
{
Add(sum_f[fa_top_x][tag[top_x]],tree[rt[top_x]].prod.a[0][2]);
get_f_light(fa_top_x);
Add(g_light[fa_top_x],f_light[fa_top_x]);
Add(g_light[fa_top_x],tree[rt[top_x]].prod.a[1][2]);
}
x=fa[top[x]];
}
return;
}
void change_color(int x)
{
Sub(g_light[x],f_light[x]);
color[x]=0;
f_light[x]=0;
modify(x);
return;
}
void change(int x)
{
tag[x]=0;
int fa_x=fa[x];
if(fa_x==0)
return;
if(son[fa_x]!=x)
{
Sub(sum_f[fa_x][1],tree[rt[x]].prod.a[0][2]);
Add(sum_f[fa_x][0],tree[rt[x]].prod.a[0][2]);
}
Sub(g_light[fa_x],f_light[fa_x]);
++cnt[fa_x];
get_f_light(fa_x);
Add(g_light[fa_x],f_light[fa_x]);
modify(fa_x);
return;
}
void calc_init(int x)
{
f[x]=0;
g[x]=0;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y!=fa[x])
{
calc_init(y);
ll tmp=inv[cnt[x]+tag[y]];
(tmp*=f[y])%=mod;
Add(f[x],tmp);
Add(g[x],g[y]);
if(y!=son[x])
{
ll tmp0=f[y],tmp1=f[y];
(tmp0*=(1-tag[y]))%=mod;
(tmp1*=tag[y])%=mod;
Add(sum_f[x][0],tmp0);
Add(sum_f[x][1],tmp1);
}
}
}
Add(f[x],1);
(f[x]*=color[x])%=mod;
Add(g[x],f[x]);
}
void calc_light()
{
For(x,1,n)
{
get_f_light(x);
g_light[x]=g[x];
if(son[x])
{
ll tmp=inv[cnt[x]+tag[son[x]]];
(tmp*=f[son[x]])%=mod;
Sub(g_light[x],tmp);
Sub(g_light[x],g[son[x]]);
}
}
return;
}
void calc(int x)
{
calc_init(x);
calc_light();
return;
}
int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0' || ch>'9')
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0' && ch<='9')
{
x=((x*10)+ch-'0');
ch=getchar();
}
return f*x;
}
void solve()
{
int root;
scanf("%d%d",&n,&root);
{
tot=0;
cnt_=0;
For(i,1,n)
head[i]=0;
For(i,1,n)
id[i].clear();
For(i,1,n)
tim_e[i].clear();
};
For(i,1,n)
{
scanf("%d",&a[i]);
b[i]=a[i];
}
{
sort(a+1,a+n+1);
For(i,1,n)
b[i]=(lower_bound(a+1,a+n+1,b[i])-a);
For(i,1,n)
id[b[i]].push_back(i);
};
For(i,1,(n-1))
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
{
dfs1(root,0);
dfs2(root,root);
calc_subtree_min(root,0);
For(i,1,n)
color[i]=1;
For(i,1,n)
tag[i]=1;
For(i,1,n)
cnt[i]=0;
For(i,1,n)
tim_e[subtree_min[i]].push_back(i);
};
{
calc(root);
build(root);
For(min_val,1,n)
{
P[min_val]=tree[rt[root]].prod.a[1][2];
for(auto j:id[min_val])
change_color(j);
for(auto j:tim_e[min_val])
change(j);
}
};
{
For(i,1,(n-1))
Sub(P[i],P[i+1]);
};
{
ans=0;
For(i,1,n)
{
ll E_i=P[i];
(E_i*=a[i])%=mod;
Add(ans,E_i);
}
printf("%lld\n",ans);
};
return;
}
int main()
{
For(i,0,2)
{
For(j,0,2)
I.a[i][j]=0;
}
For(i,0,2)
I.a[i][i]=1;
tree[0].mat=I;
tree[0].prod=I;
{
inv[1]=1;
For(i,2,(N-1))
{
inv[i]=mod;
Sub(inv[i],(mod/i));
(inv[i]*=inv[mod%i])%=mod;
}
};
int T;
scanf("%d",&T);
while(T--)
solve();
return 0;
}
标签:cnt,int,题解,sum,son,fa,light,ZJOI2022
From: https://www.cnblogs.com/llzer/p/17483886.html