- 注意可能出现dpx+1在模意义下为0的情况,此时需要额外维护0的个数而不能求逆元
- 记f[x]表示x子树内包含x的连通子图的个数,g[x]表示全树包含x的连通子图的个数,由于子树的限制,所有fx互斥 【子树互斥模型】
- 求出f[x]后换根DP求出g[x]。答案即为u-LCA(u,v)上f的和+g[LCA(u,v)]+v-LCA(u,v)上f的和,注意整个过程都和s[x]无关,显然一个节点并不能和一种划分一一对应
- 【关于计数类的题目的迷思】朴素算法很难实现,样例很弱,手造样例也很麻烦。这种情况下或许只能更多地把它当做一道数学题做
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
vector<int>a[100005];
int d[100005],h[100005],z[100005],s[100005],fa[100005],zero[100005];
long long f[100005],g[100005],val[100005],sum[100005],inv[100005];
int power(int n,int p)
{
if(p==0)
{
return 1;
}
long long tmp=power(n,p/2);
if(p%2==0)
{
return tmp*tmp%mod;
}
return tmp*tmp%mod*n%mod;
}
void dfs1(int n1)
{
s[n1]=1;
z[n1]=0;
f[n1]=val[n1]=1;
zero[n1]=0;
for(int i=0;i<a[n1].size();i++)
{
d[a[n1][i]]=d[n1]+1;
dfs1(a[n1][i]);
s[n1]+=s[a[n1][i]];
f[n1]=f[n1]*(f[a[n1][i]]+1)%mod;
if((f[a[n1][i]]+1)%mod)
{
val[n1]=val[n1]*(f[a[n1][i]]+1)%mod;
}
else
{
zero[n1]++;
}
if(s[a[n1][i]]>s[z[n1]])
{
z[n1]=a[n1][i];
}
}
inv[n1]=power(f[n1]+1,998244351);
}
void dfs2(int n1)
{
if(z[n1])
{
h[z[n1]]=h[n1];
dfs2(z[n1]);
}
sum[n1]=(sum[z[n1]]+f[n1])%mod;
for(int i=0;i<a[n1].size();i++)
{
if(a[n1][i]!=z[n1])
{
h[a[n1][i]]=a[n1][i];
dfs2(a[n1][i]);
}
}
}
void dp(int n1,int fa)
{
if(fa)
{
if((f[n1]+1)%mod)
{
g[n1]=f[n1]*(g[fa]*inv[n1]%mod+1)%mod;
if((g[fa]*inv[n1]%mod+1)%mod)
{
val[n1]=val[n1]*(g[fa]*inv[n1]%mod+1)%mod;
}
else
{
zero[n1]++;
}
}
else
{
if(zero[fa]>1)
{
g[n1]=0;
zero[n1]++;
}
else
{
g[n1]=f[n1]*(val[fa]+1)%mod;
if((val[fa]+1)%mod)
{
val[n1]=val[n1]*(val[fa]+1)%mod;
}
else
{
zero[n1]++;
}
}
}
}
for(int i=0;i<a[n1].size();i++)
{
dp(a[n1][i],n1);
}
}
int main()
{
//freopen("example.in","r",stdin);
ios::sync_with_stdio(false);
cin.tie(NULL);
int T;
cin >> T;
while(T--)
{
int n,q;
cin >> n >> q;
for(int i=1;i<=n;i++)
{
a[i].clear();
}
for(int i=2;i<=n;i++)
{
cin >> fa[i];
a[fa[i]].push_back(i);
}
d[1]=1;
dfs1(1);
g[1]=f[1];
dp(1,0);
h[1]=1;
dfs2(1);
for(int i=1;i<=q;i++)
{
long long ans=0;
int u,v,x;
cin >> u >> v;
while(h[u]!=h[v])
{
if(d[h[u]]<d[h[v]])
{
swap(u,v);
}
ans=(ans+sum[h[u]]-sum[z[u]])%mod;
u=fa[h[u]];
}
if(d[u]<d[v])
{
swap(u,v);
}
x=v;
ans=(ans+sum[z[v]]-sum[z[u]])%mod;
ans=(ans+g[x])%mod;
cout<<(ans+mod)%mod<<"\n";
}
}
return 0;
}