#include<bits/stdc++.h>
using namespace std;
using LL = long long;
const int N = 100000 + 5, mod = 1000000000 + 7;

int n, a[N];            // node count and per-node values, 1-indexed (filled in main)
std::vector<int> g[N];  // g[x] lists the children of node x (built in main from the parent list)
LL _2[N], tmp;          // _2[i] = 2^i mod `mod` (filled in main); tmp = accumulator written by dfs

// Adds to `tmp` the bit-m contribution of node u together with its direct
// children. Every power is read already reduced from _2[], and at most two
// residues (< mod) are multiplied before reducing again, so no intermediate
// value can exceed 2^63.
static void add_node_term(int u, int m)
{
    const bool self = (a[u] >> m) & 1;
    int cnt = self;  // how many of {u} + children(u) have bit m set
    for (int v : g[u]) cnt += (a[v] >> m) & 1;
    const int sz = 1 + static_cast<int>(g[u].size());  // size of {u} + children(u)
    if (cnt == 0 || sz == 1) return;

    // 2^(cnt-2)*2^(sz-cnt) == 2^(cnt-1)*2^(sz-1-cnt) == 2^(sz-2) and
    // 2^(cnt-1)*2^(sz-cnt) == 2^(sz-1). Using the combined exponent avoids the
    // former three-residue product (x*y - c)*z, which overflowed long long.
    LL term;
    if (u == 1)
    {
        if (self && cnt >= 2)
            term = (_2[sz - 2] - 1 + mod) % mod * _2[n - sz] % mod;
        else if (!self)
            term = _2[sz - 2] * _2[n - sz] % mod;
        else
            return;  // root has the bit but none of its children do: no contribution
    }
    else
    {
        // 2^(sz-1) mod p can be smaller than cnt; adding mod keeps the value non-negative.
        term = (_2[sz - 1] - cnt + mod) % mod * _2[n - sz - 1] % mod;
    }
    tmp = (tmp + term) % mod;
}

// Accumulates into `tmp` the bit-m contribution of every node in the subtree
// rooted at u. Each node's term depends only on a[] and its own child list, and
// every term is reduced into [0, mod), so visiting order does not matter. An
// explicit stack replaces recursion so a path-shaped tree (~1e5 nodes deep)
// cannot exhaust the call stack.
void dfs(int u, int m)
{
    std::vector<int> pending{u};
    while (!pending.empty())
    {
        const int x = pending.back();
        pending.pop_back();
        pending.insert(pending.end(), g[x].begin(), g[x].end());
        add_node_term(x, m);
    }
}
int main()
{
#ifdef LOCAL
    // Local debugging only. Unconditionally redirecting stdin/stdout to
    // in.in/out.out made the program ignore the real standard streams (and lose
    // its input entirely when in.in is absent).
    freopen("in.in", "r", stdin);
    freopen("out.out", "w", stdout);
#endif
    scanf("%d", &n);
    // _2[i] = 2^i mod `mod`; indices up to n-1 are used by dfs, up to 30 for bit weights.
    _2[0] = 1;
    for (int i = 1; i <= std::max(n, 30); i++) _2[i] = (_2[i - 1] << 1) % mod;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    // Nodes 2..n each read their parent x; record i as a child of x.
    for (int i = 2; i <= n; i++)
    {
        int x;
        scanf("%d", &x);
        g[x].push_back(i);
    }
    LL ans = 0;
    // Each bit is counted independently and weighted by 2^bit
    // (bits 0..30 cover any non-negative int value).
    for (int bit = 0; bit <= 30; bit++)
    {
        tmp = 0;
        dfs(1, bit);
        ans = (ans + _2[bit] * tmp % mod) % mod;
    }
    printf("%lld\n", ans);
    return 0;
}
// Tags: tmp, sz, cnt, int, saved code, mod
// From: https://www.cnblogs.com/ppllxx-9G/p/18436385