设两棵树的边集分别为 \(E_1,E_2\),那么两棵树不同当且仅当它们对应的边集不同。
转化一下可以发现,染色方案等于 \(y^{n-|E_1\cap E_2|}\),即由边集 \(E_1\cap E_2\) 构成的图的连通块数量。
\(k=1\)
\[ans=\sum_{E_2}y^{n-|E_1\cap E_2|} \]考虑枚举 \(S=E_1\cap E_2\),那么我们需要知道有多少种不同的 \(E_2\)。
先不考虑算重的问题,我们考虑在 \(S\) 的基础上添加若干边使得图构成一棵树的方案数。设边集 \(S\) 构成了若干个大小分别为 \(s_1,s_2,\cdots,s_k\) 的连通块(显然它们都是树),根据扩展 Cayley 定理可知方案数为 \(n^{k-2}\prod_{i=1}^k s_i\),而且发现 \(k=n-|S|\)。
这样显然会算重,需要容斥,大胆假设容斥系数为 \(f(|S|)\):
\[ans=\sum_{S\subseteq E_1} f(|S|)n^{n-|S|-2}\prod_{i=1}^{n-|S|}s_i \]考虑一种 \(E_2\) 在原式和容斥式中的贡献:(令 \(S=E_1\cap E_2\),\(y'=y^{-1}\))
\[\begin{aligned} y^{n-|S|}&=\sum_{T\subseteq S}f(|T|)\\ y'^{|S|}&=\sum_{|T|=0}^{|S|}\binom{|S|}{|T|}\dfrac{f(|T|)}{y^n} \end{aligned} \]令 \(\dfrac{f(|T|)}{y^n}=(y'-1)^{|T|}\) 即可,解得 \(f(|T|)=(1-y)^{|T|}y^{n-|T|}\)。
于是:
\[\begin{aligned} ans&=\sum_{S\subseteq E_1}(1-y)^{|S|}y^{n-|S|}n^{n-|S|-2}\prod_{i=1}^{n-|S|}s_i\\ &=y^nn^{n-2}\sum_{S\subseteq E_1}(1-y)^{|S|}y^{-|S|}n^{-|S|}\prod_{i=1}^{n-|S|}s_i\\ \end{aligned} \]这是树上连通块 DP 问题,可以设 \(f_{u,s}\) 表示考虑完 \(u\) 子树内的连通块情况,其中以 \(u\) 为根的连通块大小为 \(s\),所有方案的贡献和。直接转移是 \(O(n^2)\) 的。
神奇的是我们可以将 \(\prod\limits_{i=1}^{n-|S|}s_i\) 看成其组合意义:每个连通块内各选一个点的方案数。于是我们可以设 \(f_{u,0/1}\) 表示考虑完 \(u\) 子树内的连通块情况,其中以 \(u\) 为根的连通块是否已经选了点,所有方案的贡献和(此时的两种方案不同当且仅当连通块情况不同或连通块内选点情况不同)。那么转移就是 \(O(n)\) 的了。
\(k=2\)
\[ans=\sum_{E_1}\sum_{E_2}y^{n-|E_1\cap E_2|} \]同样地考虑枚举 \(S=E_1\cap E_2\),设 \(g(S)\) 为在 \(S\) 的基础上添加若干边使得图构成一棵树的方案数,\(k=1\) 时已经算过了为 \(g(S)=n^{n-|S|-2}\prod\limits_{i=1}^{n-|S|}s_i\),其中边集 \(S\) 构成了若干个大小分别为 \(s_1,s_2,\cdots,s_{n-|S|}\) 的连通块。
同样大胆假设容斥系数 \(f(|S|)\):
\[ans=\sum_{S}f(|S|)g(S)^2 \]考虑一种 \(E_1,E_2\) 在原式和容斥式中的贡献:(设 \(S=E_1\cap E_2\))
\[y^{n-|S|}=\sum_{T\subseteq S}f(|T|) \]同样可以得到 \(f(|T|)=(1-y)^{|T|}y^{n-|T|}\)。于是:
\[\begin{aligned} ans&=\sum_{S}(1-y)^{|S|}y^{n-|S|}g(S)^2\\ &=y^nn^{2n-4}\sum_{S}(1-y)^{|S|}y^{-|S|}n^{-2|S|}\prod\limits_{i=1}^{n-|S|}s_i^2 \end{aligned} \]\(S\) 不太好枚举,不如直接转为枚举 \(s_i\),即把 \(\{1,\cdots,n\}\) 分为若干个集合,每个大小为 \(s\) 的集合为一个连通块(且是树),有 \(s^{s-2}\) 种生成树方案,每种方案的贡献都是 \((1-y)^{s-1}y^{-(s-1)}n^{-2(s-1)}s^2\),于是这个连通块总贡献为 \((1-y)^{s-1}y^{-(s-1)}n^{-2(s-1)}s^s\)。
考虑其指数生成函数:
\[\begin{aligned} F(x)&=\sum\limits_{s\geq 1}\frac{1}{s!}(1-y)^{s-1}y^{-(s-1)}n^{-2(s-1)}s^sx^s\\ &=\sum\limits_{s\geq 0}\frac{1}{s!}(1-y)^{s}y^{-s}n^{-2s}(s+1)^sx^{s+1}\\ \end{aligned} \]于是 \(ans=y^nn^{2n-4}[x^n]n!e^{F(x)}\),多项式 exp 即可,时间复杂度 \(O(n\log n)\)。
#include<bits/stdc++.h>
#define N 100010
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
inline void Mul(int &x,int y){x=1ll*x*y%mod;}
}using namespace modular;
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
int n,y;
namespace sub0
{
#define pii pair<int,int>
#define mk(a,b) make_pair(a,b)
set<pii>e;
void main()
{
for(int i=1;i<n;i++)
{
int u=read(),v=read();
if(u>v) swap(u,v);
e.insert(mk(u,v));
}
int same=0;
for(int i=1;i<n;i++)
{
int u=read(),v=read();
if(u>v) swap(u,v);
if(e.find(mk(u,v))!=e.end()) same++;
}
printf("%d\n",poww(y,n-same));
}
#undef pii
#undef mk
}
namespace sub1
{
int invy,invn,coef;
int cnt,head[N],nxt[N<<1],to[N<<1];
int f[N][2];
void adde(int u,int v)
{
to[++cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt;
}
void dfs(int u,int fa)
{
static int g[2];
f[u][0]=f[u][1]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa) continue;
dfs(v,u);
g[0]=mul(f[u][0],f[v][1]);
Add(g[0],mul(f[u][0],mul(f[v][0],coef)));
g[1]=mul(f[u][1],f[v][1]);
Add(g[1],mul(f[u][0],mul(f[v][1],coef)));
Add(g[1],mul(f[u][1],mul(f[v][0],coef)));
f[u][0]=g[0];
f[u][1]=g[1];
}
}
void main()
{
invy=poww(y,mod-2),invn=poww(n,mod-2);
coef=mul(dec(1,y),mul(invy,invn));
for(int i=1;i<n;i++)
{
int u=read(),v=read();
adde(u,v),adde(v,u);
}
dfs(1,0);
printf("%d\n",mul(mul(poww(y,n),poww(n,n-2)),f[1][1]));
}
}
namespace sub2
{
#define LN 20
int fac[N<<3],ifac[N<<3];
vector<int>w[LN][2];
void init(int limit)
{
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
int gn=poww(3,(mod-1)/len);
int ign=poww(gn,mod-2);
int g=1,ig=1;
for(int j=0;j<mid;j++,g=mul(g,gn),ig=mul(ig,ign))
w[bit][0].push_back(g),w[bit][1].push_back(ig);
}
}
void NTT(int *a,int limit,int opt)
{
static int rev[N<<3];
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
int x=a[i+j],y=mul(w[bit][opt][j],a[i+mid+j]);
a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
}
}
}
if(opt)
{
int tmp=poww(limit,mod-2);
for(int i=0;i<limit;i++)
a[i]=mul(a[i],tmp);
}
}
void getinv(int *f,int *g,int n)
{
static int ff[N<<3];
g[0]=poww(f[0],mod-2);
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
for(int i=0;i<now;i++) ff[i]=f[i];
NTT(ff,limit,1),NTT(g,limit,1);
for(int i=0;i<limit;i++)
g[i]=mul(dec(2,mul(ff[i],g[i])),g[i]);
NTT(g,limit,-1);
for(int i=now;i<limit;i++) g[i]=0;
}
for(int i=n;i<now;i++) g[i]=0;
for(int i=0;i<now;i++) ff[i]=0;
}
void getder(int *f,int *g,int n)
{
for(int i=1;i<n;i++) g[i-1]=mul(i,f[i]);
g[n-1]=0;
}
void getint(int *f,int *g,int n)
{
for(int i=n-1;i>=1;i--) g[i]=mul(mul(ifac[i],fac[i-1]),f[i-1]);
g[0]=0;
}
void getln(int *f,int *g,int n)
{
static int derf[N<<3],invf[N<<3];
getder(f,derf,n);
getinv(f,invf,n);
int limit=1;
while(limit<(n<<1)) limit<<=1;
NTT(derf,limit,1),NTT(invf,limit,1);
for(int i=0;i<limit;i++)
derf[i]=mul(derf[i],invf[i]);
NTT(derf,limit,-1);
getint(derf,g,n);
for(int i=0;i<limit;i++) derf[i]=invf[i]=0;
}
void getexp(int *f,int *g,int n)
{
static int ff[N<<3],lng[N<<3];
assert(!f[0]);
g[0]=1;
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
getln(g,lng,now);
for(int i=0;i<now;i++) ff[i]=f[i];
for(int i=0;i<now;i++) Dec(ff[i],lng[i]);
Add(ff[0],1);
NTT(g,limit,1),NTT(ff,limit,1);
for(int i=0;i<limit;i++)
g[i]=mul(g[i],ff[i]);
NTT(g,limit,-1);
for(int i=now;i<limit;i++) g[i]=0;
}
for(int i=n;i<now;i++) g[i]=0;
for(int i=0;i<now;i++) lng[i]=ff[i]=0;
}
void main()
{
int limit=1;
while(limit<=(n<<2)) limit<<=1;
init(limit);
fac[0]=1;
for(int i=1;i<=(n<<3);i++) fac[i]=mul(fac[i-1],i);
ifac[n<<3]=poww(fac[n<<3],mod-2);
for(int i=(n<<3);i>=1;i--) ifac[i-1]=mul(ifac[i],i);
static int f[N<<3],g[N<<3];
int coef=1,tmp=mul(mul(dec(1,y),poww(y,mod-2)),poww(mul(n,n),mod-2));
for(int i=0;i<n;i++)
{
f[i+1]=mul(mul(ifac[i],coef),poww(i+1,i));
Mul(coef,tmp);
}
getexp(f,g,n+1);
printf("%d\n",mul(mul(poww(y,n),poww(n,2*n-4)),mul(g[n],fac[n])));
}
#undef LN
}
int main()
{
n=read(),y=read();int opt=read();
if(!opt) sub0::main();
else if(opt==1) sub1::main();
else sub2::main();
return 0;
}
标签:连通,数树,int,sum,cap,exp,ans,aligned,WC2019
From: https://www.cnblogs.com/ez-lcw/p/16840589.html