题目大意
给你一棵 \(n\)(\(n\le50\))个点的树,可以进行不超过 \(k\) 次操作,每次断掉一条边,再连上一条边,要求树一直是树,求一共有多少种树的形态。
思路
把题意转换为对于一个 \(n\) 个点的完全图,是树边的话权值是 \(1\),否则是 \(x\)。
跑一遍矩阵树定理,矩阵树定理求的是一个图所有生成树的权值和,在这里一棵生成树的权值被定义为边权的乘积。跑出权值和以后,\(x^i\) 项的系数就是选了 \(k\) 条非树边的生成树个数。
由于操作数不超过 \(k\) 个,所以非树边的个数小于等于 \(k\),只需要把指数小于等于 \(k\) 的项的系数加起来就是答案。
但是,这样子看起来相当不好求,因此考虑拉格朗日插值。
由于最后矩阵树出来的多项式最多有 \(n-1\) 项,所以只需要把 \(x\) 从 \(1\) 到 \(n\) 代入跑 \(n\) 次矩阵树,然后拉格朗日插值还原出矩阵树的多项式。
相关知识点
矩阵树定理
(P6178 【模板】Matrix-Tree 定理 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn))
定义两个矩阵:度数矩阵 \(D_{i,i}=点 i 的度数\),邻接矩阵 \(A_{i,j}=(i,j)的权值\)。
求出 \(\det (D-A)\) 就是答案了。
矩阵树定理及其无向图形式证明 - 洛谷专栏 (luogu.com.cn)
求 \(\det (D-A)\) 的话,只需要选出任意一个 \(k\),删掉第 \(k\) 行和第 \(k\) 列,然后利用交换两行行列式取反,一行减去另一行的倍数行列式不变,变成上三角矩阵,就是只有 \(i\le j\) 时 \(a_{i,j}\) 有值,利用上三角矩阵的行列式是对角线的乘积求解即可。
拉格朗日插值还原多项式
拉格朗日插值公式:
\[f(x)=\sum ^n_{i=1} y_i \prod _{i\neq j} \frac{x-x_j}{x_i-x_j} \]最暴力模拟的话可以 \(O(n^3)\),利用多项式快速插值可以做到 \(O(n\log^2 n)\)。
代码
点击开 D
const int N=59,moder=998244353;
int n,K,fa[N]={},D[N][N]={},f[N]={},h[N]={};
bool is[N][N]={};
int add(int x,int y) { return x+y>=moder?x+y-moder:x+y; }
int sub(int x,int y) { return x<y?x-y+moder:x-y; }
int mul(int x,int y) { return (ll)x*y%moder; }
int Add(int &x,int y) { return x=x+y>=moder?x+y-moder:x+y; }
int Sub(int &x,int y) { return x=x<y?x-y+moder:x-y; }
int Mul(int &x,int y) { return x=(ll)x*y%moder; }
int kuai(int a,int b) { ll rey=1,temp=a; for(;b;b>>=1) { if(b&1) rey=rey*temp%moder; temp=temp*temp%moder; } return rey; }
int get(int x) {
int i,j,k,rey=1,op=0,inv,muler;
memset(D,0,sizeof(D));
for(i=0;i<n;++i)
for(j=0;j<n;++j)
if(i!=j)
D[i][j]=is[i][j]?1:x,
Add(D[i][i],D[i][j]),
D[i][j]=sub(0,D[i][j]);
for(i=1;i<n;++i) {
if(!D[i][i]) {
for(j=i+1;j<n;++j)
if(D[j][i]) {
for(k=1;k<n;++k)
swap(D[i][k],D[j][k]);
op^=1; break;
}
}
Mul(rey,D[i][i]);
inv=kuai(D[i][i],moder-2);
for(j=i+1;j<n;++j) {
muler=mul(D[j][i],inv);
for(k=1;k<n;++k)
Sub(D[j][k],mul(D[i][k],muler));
}
}
if(op) rey=sub(0,rey);
return rey;
}
void lag() {
int i,j,k,tot,muler,g[N]={},oldg[N]={};
for(i=1;i<=n;++i) {
memset(g,0,sizeof(g)),g[0]=1;
for(j=1,tot=0;j<=n;++j)
if(i!=j) {
memcpy(oldg,g,sizeof(oldg));
++tot;
muler=sub(i,j),muler=kuai(muler,moder-2);
for(k=0;k<=tot;++k)
g[k]=mul(sub(k?oldg[k-1]:0,mul(oldg[k],j)),muler);
}
for(j=0;j<=tot;++j)
Add(h[j],mul(g[j],f[i]));
}
return ;
}
int main()
{
usefile("b");
int i,ans=0;
read(n,K);
for(i=1;i<n;++i)
read(fa[i]),
is[i][fa[i]]=is[fa[i]][i]=true;
for(i=1;i<=n;++i)
f[i]=get(i);
lag();
for(i=0;i<=K;++i)
Add(ans,h[i]);
printf("%d\n",ans);
return 0;
}