题意
给定一棵 \(n\) 个节点的有根树和 \(m\) 条祖先到后代的链。问有多少种把边权设置为 \(0\) 或 \(1\) 的方案使得每条链上至少有一条边是 \(1\)。
答案对 \(998244353\) 取模。
\(1 \leq n,m \leq 5 \times 10^5\)
题解
我们将链的下端称为限制的起点。容易发现,对于同一个起点,终点越深,限制越强,于是不妨只考虑起点。
设 \(f(u,x)\) 表示限制的起点在 \(u\) 的子树之内,不满足条件的限制最深深度为 \(x\) 的方案数。最终显然有 \(f(1,0)\) 为答案。
考虑树形 DP,不断合并子树来求解 \(f\)。首先遍历 \(u\) 为起点的限制,求出最深的深度 \(d_{max}\),则初始子树中只有 \(u\) 一个点,\(f(u,d_{max})=1\)。
考虑合并子树 \(v\),根据 \((u,v)\) 这条边是 \(0\) 还是 \(1\) 转移:
\[f(u,d) \leftarrow \sum \limits_{i=0}^{d} f(v,i) f(u,d) +\sum \limits_{i=0}^{d-1}f(u,i)f(v,d) + \sum \limits_{i=0}^{dep_u} f(v,i)f(u,d) \]考虑第二个和式的上界是 \(d-1\),因为 \(f(u,d)f(v,d)\) 会出现两次。
记前缀和 \(g(u,d) = \sum \limits_{i=0}^d f(u,d)\),
\[f(u,d) \leftarrow g(v,d)f(u,d)+g(u,d-1)f(v,d)+g(v,dep_u)f(u,d) \\ f(u,d) \leftarrow f(u,d)(g(v,d)+g(v,dep_u))+g(u,d-1)f(v,d) \]考虑线段树合并转移。这里讲一下转移过程:\(g(v,dep_u)\) 在整个合并过程中都是常量,可以提前查询得到。合并到 \([l,r]\) 时维护 \(g(u,l-1)\) 和 \(g(v,l-1)\),因为我们是从左到右合并,于是这是容易维护的。
如果 \(u,v\) 的树上都有 \([l,r]\) 的节点,直接递归合并。到叶子了就按照上面的式子直接转移。重点讲一下 \(u,v\) 有一个为空的情况:
- \(u\) 为空,那么 \(f(u,l\cdots r)=0\),于是上面式子里面 \(f(u,d)\) 的项全是 \(0\),且区间内的 \(g(u,d-1) = g(u,l-1)\)。直接给 \(f(v,l\cdots r)\) 乘上 \(g(u,l-1)\) 即可。
- \(v\) 为空,那么 \(f(v,l\cdots r) =0\),于是上面式子里面 \(f(v,d)\) 的项全是 \(0\),且区间内的 \(g(v,d) = g(v,l-1)\)。直接给 \(f(u,l\cdots r)\) 乘上 \(g(v,l-1)+g(v,dep_u)\) 即可。
# include <bits/stdc++.h>
const int N=500010,mod=998244353;
int n,m;
int dep[N],rt[N];
std::vector <int> G[N];
std::vector <int> lim[N];
struct Node{
int sum,lc,rc,tag;
Node(){
tag=1;
return;
}
}tr[N*30];
int cnt;
inline int read(void){
int res,f=1;
char c;
while((c=getchar())<'0'||c>'9')
if(c=='-') f=-1;
res=c-48;
while((c=getchar())>='0'&&c<='9')
res=res*10+c-48;
return res*f;
}
inline void add(int &x,int v){
x+=v;
if(x>=mod) x-=mod;
return;
}
inline int adc(int a,int b){
return (a+b<mod)?(a+b):(a+b-mod);
}
inline int mul(int a,int b){
return 1ll*a*b%mod;
}
inline int& lc(int x){
return tr[x].lc;
}
inline int& rc(int x){
return tr[x].rc;
}
inline void pushup(int x){
tr[x].sum=adc(tr[lc(x)].sum,tr[rc(x)].sum);
return;
}
inline void mule(int x,int v){
if(!x) return;
tr[x].sum=mul(tr[x].sum,v);
tr[x].tag=mul(tr[x].tag,v);
return;
}
inline void pushdown(int x){
if(tr[x].tag!=1)
mule(lc(x),tr[x].tag),mule(rc(x),tr[x].tag),tr[x].tag=1;
return;
}
void change(int &k,int l,int r,int x,int v){
if(!k) k=++cnt;
// printf("qwq = %d\n",tr[x].tag);
if(l==r) return tr[k].sum=v,void();
int mid=(l+r)>>1;
if(x<=mid) change(lc(k),l,mid,x,v);
else change(rc(k),mid+1,r,x,v);
pushup(k);
return;
}
void merge(int &k,int x,int y,int l,int r,int &gu,int &gv){ // gu[i-1] gv[i-1]
// printf("exe\n");
if(!x&&!y) return k=0,void();
if(!x){
add(gv,tr[y].sum),mule(y,gu),k=y;
return;
}
if(!y){
add(gu,tr[x].sum),mule(x,gv),k=x;
return;
}
if(l==r){
int fu=tr[x].sum,fv=tr[y].sum;
add(gv,fv),
tr[x].sum=adc(1ll*tr[x].sum*gv%mod,1ll*gu*fv%mod),add(gu,fu);
k=x;
return;
}
int mid=(l+r)>>1;
pushdown(x),pushdown(y);
merge(lc(k),lc(x),lc(y),l,mid,gu,gv);
merge(rc(k),rc(x),rc(y),mid+1,r,gu,gv);
pushup(k);
return;
}
int query(int k,int l,int r,int L,int R){
if(!k) return 0;
if(L<=l&&r<=R) return tr[k].sum;
pushdown(k);
int mid=(l+r)>>1,res=0;
if(L<=mid) add(res,query(lc(k),l,mid,L,R));
if(mid<R) add(res,query(rc(k),mid+1,r,L,R));
return res;
}
void dfs(int i,int fa){
int md=0,su,sv;
dep[i]=dep[fa]+1;
for(auto v:lim[i]) md=std::max(md,dep[v]);
change(rt[i],0,n,md,1);
for(auto v:G[i]){
if(v==fa) continue;
dfs(v,i),su=0,sv=query(rt[v],0,n,0,dep[i]);
merge(rt[i],rt[i],rt[v],0,n,su,sv);
}
return;
}
int main(void){
n=read();
for(int i=1;i<n;++i){
int u=read(),v=read();
G[u].push_back(v),G[v].push_back(u);
}
m=read();
while(m--){
int u=read(),v=read();
lim[v].push_back(u);
}
dfs(1,0);
printf("%d",query(rt[1],0,n,0,0));
return 0;
}
标签:return,命运,NOI2020,int,题解,sum,dep,cdots,rc
From: https://www.cnblogs.com/liuzongxin/p/17596116.html