题意简述
给定一个\(N\)个节点的树,接下来有\(q\)次询问。每次询问给定\(a,b,c\),请问存在多少个节点\(i\),使得这棵树在以\(i\)为根的情况下,\(a\)和\(b\)的LCA是\(c\)。
解题思路
首先通过分析样例,我们发现:\(a,b\)的LCA一定在它们之间的简单路径上,所以如果\(c\)不在\(a,b\)之间的简单路径上,则输出\(0\)。
进一步分析\(c\)在\(a,b\)间简单路径上的情况,我们可以归纳出三种情况:
- 如果\(c=lca(a,b)\),那么答案是\(n-siz[jump(a,c)]-siz[jump(b,c)]\),其中\(siz[u]\)表示以\(u\)为根的子数大小,\(jump(u,v)\)表示\(u\)点向上跳到\(v\)的下一层所到达的节点。
- 如果\(c\)在\(a\)到\(lca(a,b)\)的路径上,那么答案是\(siz[c]-siz[jump(a,c)]\)。
- 如果\(c\)在\(b\)到\(lca(a,b)\)的路径上,那么答案是\(siz[c]-siz[jump(b,c)]\)。
接下来我们思考代码实现。怎样知道\(c\)是哪一种情况呢?
- 第一种情况:\(c=lca(a,b)\)。
- 第二种情况:\(lca(a,c)=c\)且\(lca(a,b)=lca(b,c)\)。
- 第三种情况:\(lca(b,c)=c\)且\(lca(a,b)=lca(a,c)\)。
代码实现中,\(jump()\)函数我们可以通过倍增的思想在\(O(\log N)\)的时间复杂度内完成。而求\(lca()\)函数同样可以用倍增达到\(O( \log N)\)的时间复杂度。
点击查看代码
#include<bits/stdc++.h>
#define N 500010
using namespace std;
int n,q,dep[N],fa[N][20],siz[N];
vector<int> G[N];
void dfs(int u,int father){
siz[u]=1;
dep[u]=dep[father]+1;
fa[u][0]=father;
for(int i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i:G[u])
if(i!=father) dfs(i,u),siz[u]+=siz[i];
}
int lca(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
if(u==v) return v;
for(int i=19;i>=0;i--)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int jump(int a,int b){
if(dep[a]==dep[b]) return 0;
//计算a跳到b下一层的位置
for(int i=19;i>=0;i--)
if(dep[fa[a][i]]>dep[b])
a=fa[a][i];
return a;
}
int main(){
cin>>n>>q;
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
G[u].emplace_back(v);
G[v].emplace_back(u);
}
dfs(1,0);
while(q--){
int a,b,c;
cin>>a>>b>>c;
int ab=lca(a,b),ac=lca(a,c),bc=lca(b,c);
if(ab==c){
cout<<n-siz[jump(a,c)]-siz[jump(b,c)]<<"\n";
}else if(ac==c&&ab==bc){
cout<<siz[c]-siz[jump(a,c)]<<"\n";
}else if(bc==c&&ab==ac){
cout<<siz[c]-siz[jump(b,c)]<<"\n";
}else{
cout<<"0\n";
}
}
return 0;
}