点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int w[N];//w
vector<int> g[N];
vector<int>qr[N];
int ans[N];
int son[N],sz[N];//重儿子
int L[N],R[N],dfn,id[N],dep[N];
void dfs1(int u,int fa){
son[u]=1;
L[u]=++dfn;
id[dfn]=u;
for(auto v:g[u]){
if(v==fa) continue;
dfs1(v,u);
son[u]+=son[v];
if(son[v]>son[sz[u]]) sz[u]=v;
}
R[u]=dfn;
}
int color[N],totcol;
void add(int v,int col){
color[col]++;
if(color[col]==1) totcol++;
}
void del(int col){
color[col]--;
if(color[col]==0) totcol--;
}
void dfs2(int u,int fa,bool op){
for(int v:g[u]){
if(v==fa||v==sz[u]) continue;
dfs2(v,u,false);
}
if(sz[u]){
dfs2(sz[u],u,true);
}
for(int v:g[u]){
if(v==fa||v==sz[u]) continue;
for(int i=L[v];i<=R[v];i++){
add(v,w[id[i]]);
}
}
add(u,w[u]);
for(auto v:qr[u]){
ans[v]=totcol;
}
if(!op){
for(int i=L[u];i<=R[u];i++){
del(w[id[i]]);
}
}
}
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
int n,q;
cin>>n;
for(int i=1;i<n;i++){
int x,y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
}
for(int i=1;i<=n;i++) cin>>w[i];
cin>>q;
for(int i=1;i<=q;i++){
int x;
cin>>x;
qr[x].push_back(i);
}
dfs1(1,0);
dfs2(1,0,false);
for(int i=1;i<=q;i++){
cout<<ans[i]<<'\n';
}
return 0;
}