和考试的时候思路差不多。
首先考虑钦定一部分关键点是合法的 根 ,带上容斥系数。
对于一条非树边,要求其在任何一个钦定点作为根的时候都不是横叉边。
具体而言,对于一个钦定点集合,我们建出钦定点集合的虚树,那么符合条件的非树边有如下几类:
不妨先考虑特殊性质 \(B\) ,没有横叉边的情况:
-
完全在某条虚树边上
-
从虚树上某个点向非虚树上的子树方向的边
-
虚树根向上方的边
对于每个点,预处理第二类和第三类的边的个数,记为 \(down_x,up_x\) 。
考试的时候写的是,\(dp_{x,i}\) 表示考虑 \(x\) 子树内的点和边,选出的、覆盖当前点的上端点最浅的非树边深度是 \(i\) 的方案数,转移和 [NOI2020]命运 一模一样。
但是这样需要线段树合并,而且不太好拓展。
我们考虑设 \(dp_x\) 表示 \(x\) 为虚树上的点的方案数,对于每个子树 \(y\),如果这子树里没选点,方案数是 \(2^{down_y}\) ,否则,方案数是枚举最浅的点 \(z\),用 \(dp_y\) 乘上 \(x\to z\) 的方案数,这个可以边转移边维护。
以 \(DFS\) 序为下标开一颗线段树,\(dp_x\) 存在 \(dfn_x\) 位置。
更新就是一个区间乘法,区间求和。
现在考虑一般情况。
注意到横叉边只有跨过虚树根的才合法,且必须被虚树包含且最多涉及到两棵子树。
对于一个二维平面上的点 \((x,y)\) ,我们表示选择的两个子树最浅的点的 \(dfs\) 序分别是 \(x,y\) 的方案数,初始时是 \(dp_x\times dp_y\)。
对于每个横叉边,相当于一个矩形内的方案数乘二,最终求所有位置的和,可以离散化后做一遍扫描线解决。
复杂度 \(O(n\log n)\)
#include<bits/stdc++.h>
using namespace std;
const int N = 5e5+7;
struct edge
{
int y,next;
}e[2*N];
int flink[N],t=0;
void add(int x,int y)
{
e[++t].y=y;
e[t].next=flink[x];
flink[x]=t;
}
int n,m,K;
bool key[N];
int dep[N],st[N],ed[N],tim=0;
int jump[N][21];
int fup[N],fdown[N],tree[N],pw[N];
const int mod = 1e9+7;
vector<int> Up[N],Down[N],G[N],A[N];
int Jump(int x,int y)
{
for(int k=19;k>=0;k--)
if(jump[x][k]&&dep[jump[x][k]]>dep[y])x=jump[x][k];
return x;
}
int lca(int x,int y)
{
if(dep[x]<dep[y])swap(x,y);
for(int k=19;k>=0;k--)if(dep[jump[x][k]]>=dep[y])x=jump[x][k];
if(x==y)return x;
for(int k=19;k>=0;k--)if(jump[x][k]!=jump[y][k])
{
x=jump[x][k];
y=jump[y][k];
}
return jump[x][0];
}
void dfs(int x,int pre)
{
jump[x][0]=pre;
for(int k=1;jump[x][k-1];k++)jump[x][k]=jump[jump[x][k-1]][k-1];
dep[x]=dep[pre]+1;
st[x]=++tim;
for(int i=flink[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre)continue;
dfs(y,x);
}
ed[x]=tim;
}
inline bool Anc(int x,int y){return st[x]<=st[y]&&st[y]<=ed[x];}
int U[N],V[N];
void dfs2(int x,int pre)
{
tree[x]=Down[x].size();
for(int i=flink[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre)continue;
dfs2(y,x);
tree[x]+=tree[y];
fdown[x]+=fdown[y];
}
}
int dis[N];
inline int plu(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int dec(int a,int b){return a-b<0?a-b+mod:a-b;}
struct segment
{
int sum[N*4],tag[N*4];
void build(int k,int l,int r)
{
sum[k]=0;tag[k]=1;
if(l==r)
{
sum[k]=dis[l];
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
sum[k]=plu(sum[k<<1],sum[k<<1|1]);
}
void pushup(int k)
{
sum[k]=plu(sum[k<<1],sum[k<<1|1]);
}
void pushtag(int k,int v)
{
sum[k]=1ll*sum[k]*v%mod;
tag[k]=1ll*tag[k]*v%mod;
}
void pushdown(int k)
{
if(tag[k]^1)
{
pushtag(k<<1,tag[k]);
pushtag(k<<1|1,tag[k]);
tag[k]=1;
}
}
void mul(int k,int l,int r,int L,int R,int v)
{
if(L<=l&&r<=R)
{
pushtag(k,v);
return;
}
pushdown(k);
int mid=(l+r)>>1;
if(L<=mid)mul(k<<1,l,mid,L,R,v);
if(R>mid) mul(k<<1|1,mid+1,r,L,R,v);
pushup(k);
}
int query(int k,int l,int r,int L,int R)
{
if(L>R)return 0;
if(L<=l&&r<=R) return sum[k];
pushdown(k);
int mid=(l+r)>>1;
int res=0;
if(L<=mid)res=plu(res,query(k<<1,l,mid,L,R));
if(R>mid) res=plu(res,query(k<<1|1,mid+1,r,L,R));
return res;
}
void upd(int k,int l,int r,int x,int v)
{
if(l==r)
{
sum[k]=v;
return;
}
pushdown(k);
int mid=(l+r)>>1;
if(x<=mid)upd(k<<1,l,mid,x,v);
else upd(k<<1|1,mid+1,r,x,v);
pushup(k);
}
}T,S;
int dp[N][4],f[4],ans=0;
int ipw[N];
const int inv2=(mod+1)/2;
int dct[N*5],tot=0;
struct Info
{
int x,l,r,v;
}seq[N*5];
bool cmp(Info A,Info B)
{
return A.x<B.x;
}
int cnt=0;
int get(int x)
{
return lower_bound(dct+1,dct+tot+1,x)-dct;
}
int siz[N];
int calc(int x)
{
if(A[x].empty())return dp[x][2];
tot=0;
dct[++tot]=st[x];
dct[++tot]=ed[x]+1;
for(int i:A[x])
{
int u=U[i],v=V[i];
dct[++tot]=st[u];dct[++tot]=ed[u]+1;
dct[++tot]=st[v];dct[++tot]=ed[v]+1;
}
sort(dct+1,dct+tot+1);
tot=unique(dct+1,dct+tot+1)-dct-1;
for(int i=1;i<=tot;i++)dis[i]=T.query(1,1,n,dct[i],dct[i+1]-1);
S.build(1,1,tot);
cnt=0;
seq[++cnt]=(Info){st[x],st[x],ed[x]+1,1};
seq[++cnt]=(Info){ed[x]+1,st[x],ed[x]+1,1};
for(int i:A[x])
{
int u=U[i],v=V[i];
seq[++cnt]=(Info){st[u],st[v],ed[v]+1,2};
seq[++cnt]=(Info){ed[u]+1,st[v],ed[v]+1,inv2};
seq[++cnt]=(Info){st[v],st[u],ed[u]+1,2};
seq[++cnt]=(Info){ed[v]+1,st[u],ed[u]+1,inv2};
}
sort(seq+1,seq+cnt+1,cmp);
int res=0;
for(int i=1;i<=cnt;i++)
{
if(i>1)res=(res+1ll*T.query(1,1,n,seq[i-1].x,seq[i].x-1)*S.sum[1]%mod)%mod;
S.mul(1,1,tot,get(seq[i].l),get(seq[i].r)-1,seq[i].v);
}
for(int i=flink[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==jump[x][0])continue;
int V=T.query(1,1,n,st[y],ed[y]);
res=(res-1ll*V*V%mod+mod)%mod;
}
res=1ll*res*inv2%mod*ipw[tree[x]]%mod;
return res;
}
void dfs3(int x,int pre)
{
fup[x]+=siz[x]-Down[x].size();
for(int i=flink[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre)continue;
fup[y]=fup[x]+tree[x]-fdown[y];
dfs3(y,x);
}
for(auto u:Down[x])T.mul(1,1,n,st[u],ed[u],2);
dp[x][0]=1;
for(int i=flink[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre)continue;
for(int c=0;c<=3;c++)f[c]=dp[x][c],dp[x][c]=0;
for(int c=0;c<=3;c++)
{
dp[x][c]=plu(dp[x][c],1ll*f[c]*pw[fdown[y]]%mod);
dp[x][min(3,c+1)]=plu(dp[x][min(3,c+1)],1ll*f[c]*T.query(1,1,n,st[y],ed[y])%mod);
}
}
for(int i=flink[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre)continue;
T.mul(1,1,n,st[y],ed[y],pw[tree[x]-fdown[y]]);
}
int F=0;
if(key[x])
{
for(int c=0;c<=3;c++)
F=(F+mod-dp[x][c])%mod;
}
F=(F+dp[x][3])%mod;
ans=(ans+1ll*(F+calc(x))%mod*pw[fup[x]]%mod)%mod;
F=(F+dp[x][2])%mod;
T.upd(1,1,n,st[x],F);
}
int main()
{
int tp;
cin>>tp;
cin>>n>>m>>K;
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
for(int i=1;i<=m;i++)
{
int x,y;
scanf("%d %d",&x,&y);
if(st[x]>st[y])swap(x,y);
if(Anc(x,y))
{
Up[y].push_back(x);
Down[x].push_back(y);
fdown[Jump(y,x)]++;
}
else A[lca(x,y)].push_back(i);
siz[x]++;siz[y]++;
U[i]=x;V[i]=y;
}
for(int i=1;i<=K;i++)
{
int x;
scanf("%d",&x);
key[x]=1;
}
pw[0]=1;ipw[0]=1;
for(int i=1;i<=m;i++)
{
pw[i]=1ll*pw[i-1]*2%mod;
ipw[i]=1ll*ipw[i-1]*inv2%mod;
}
dfs2(1,0);
T.build(1,1,n);
dfs3(1,0);
cout<<(mod-ans)%mod;
return 0;
}
标签:return,int,res,jump,dep,NOI2023,mod
From: https://www.cnblogs.com/jesoyizexry/p/17594416.html