P5588 小猪佩奇爬树:
P5588 小猪佩奇爬树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
v用来存储各个颜色的节点
一.v[i].size()=0时,不再赘述
二.v[i].size()=1时,
此时便是把 i 这个节点看成根节点,求他子节点所有两两之乘
显然把所有i进行dfs会超时O(n^2)
有一个巧妙的方法
void dfs(int x,int f) { size[x]=1; dep[x]=dep[f]+1; fa[x][0]=f; for(int i=1;i<=lg[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for(int i=lin[x];i;i=e[i].next) { int v=e[i].y; if(v==f) continue; dfs(v,x); ans1[x]+=size[x]*size[v]; 这步还是很妙的 size[x]+=size[v]; } ans1[x]+=size[x]*(n-size[x]); 最后这步别忘了 }ans1[x]便是v[x].size()=1时的情况了
三.v[i].size>=2时
此时,所有的点都应该在一条路径上,否则无解
这是便有两个问题:如何判断所有点在一条路径,又如何获得两个端点是哪两个
别急,此时有在一条路径上有两种情况: 1.一条链组成 2.两条链组成
1.所有的点都在一条链上
此时 3,4,5这三个相同颜色的点在一条链上
那么如何判断这些点在所有链上:
首先我们把这些点存到v[i]上,接着按照深度从深到浅排序,那么此时v[i]={5,4,3};
那么链上的时候两个端点分别是最深的那个点和最浅的那个点
我们把最深的点拿出来,如果LCA(其他点,最深的点)=其他这个点,那么就能保证所有点在一条链上(这步还是很妙的)
int l=v[i][0],r,flag1=1,flag2=1;//l便是最深的点,flag1表示是不是链的状态,flag2表示其他在一条路上的状态 for(int j=1;j<v[i].size();j++) if(LCA(l,v[i][j])!=v[i][j]) { r=v[i][j];flag1=0;break;} //判断出不是链的状态
此外这题还有很细节的地方,比如统计时r时(另一个端点,在链中是链里最上面的点),
此时不是(n-size[r]+1)而是(n-size[Son(r)])
如图符合的点是1,2,3,9.不是1,2,3
那门如何求Son(r)呢,我们暴力一点,直接从 l最深的点开始往上推,直到fa[][0]== break掉
if(flag1) { for(r=v[i][0];fa[r][0]!=v[i].back();r=fa[r][0]); //找最上面的点的儿子点 ans=size[l]*(n-size[r]); //统计答案 }
2.所有的点在一条路径上,但不在一条链上
这种情况,这条路其实是两条链,而两个端点是两条链的最深的点
我们先回来看判断链的代码
int l=v[i][0],r,flag1=1,flag2=1;//l便是最深的点,flag1表示是不是链的状态,flag2表示其他在一条路上的状态 for(int j=1;j<v[i].size();j++) if(LCA(l,v[i][j])!=v[i][j]) { r=v[i][j];flag1=0;break;} //判断出不是链的状态
如何不是一条链,我们用 r 记录第一个不在这个链的点
由于已经由深到浅排序,那么这个r点就是另一条链的最深的点
接下来就是判断是否是这种情况,由于其他点是在这两条链上的,我们还是用上面的LCA判断
else { for(int j=1;j<v[i].size();j++) if(LCA(l,v[i][j])!=v[i][j]&&LCA(r,v[i][j])!=v[i][j]) {flag2=0;break; } //这个点都不在这两条链上,那么就不是一条路径 if(flag2) ans=size[l]*size[r];} //这种情况的答案就很好统计了
}
那么这题就做完了
Code:
#include<bits/stdc++.h> using namespace std; #define ll long long #define mp make_pair #define pb push_back //vector函数 #define popb pop_back //vector函数 #define fi first #define se second const int N=1e6+10; //const int M=; //const int inf=0x3f3f3f3f; //一般为int赋最大值,不用于memset中 //const ll INF=0x3ffffffffffff; //一般为ll赋最大值,不用于memset中 int n,m,len=0,root,dep[N],lg[N],lin[N],size[N],num[N],fa[N][20]; int l[N],r1[N],r2[N]; ll ans1[N]; bool vis[N]; struct edge{ int next,y; }e[N<<1]; vector<int> v[N]; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*f; } void insert(int xx,int yy) { e[++len].next=lin[xx]; lin[xx]=len; e[len].y=yy; } void dfs(int x,int f) { size[x]=1; dep[x]=dep[f]+1; fa[x][0]=f; for(int i=1;i<=lg[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for(int i=lin[x];i;i=e[i].next) { int v=e[i].y; if(v==f) continue; dfs(v,x); ans1[x]+=size[x]*size[v]; size[x]+=size[v]; } ans1[x]+=size[x]*(n-size[x]); } int LCA(int x,int y) { if(dep[x]<dep[y]) swap(x,y); while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]]; if(x==y) return x; for(int i=lg[dep[x]];i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } int main() { // freopen("","r",stdin); // freopen("","w",stdout); n=read(); for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1; for(int i=1;i<=n;i++) { int col=read(); v[col].pb(i); } for(int i=1;i<n;i++) { int x=read(),y=read(); insert(x,y);insert(y,x); } dfs(1,0); for(int i=1;i<=n;i++) { ll ans=0; if(!v[i].size()) ans=1LL*n*(n-1)/2; else if(v[i].size()==1) ans=ans1[v[i][0]]; else { reverse(v[i].begin(),v[i].end()); int l=v[i][0],r,flag1=1,flag2=1; for(int j=1;j<v[i].size();j++) if(LCA(l,v[i][j])!=v[i][j]) { r=v[i][j];flag1=0;break; } if(flag1) { for(r=v[i][0];fa[r][0]!=v[i].back();r=fa[r][0]); ans=size[l]*(n-size[r]); } else { for(int j=1;j<v[i].size();j++) if(LCA(l,v[i][j])!=v[i][j]&&LCA(r,v[i][j])!=v[i][j]) { flag2=0;break; } if(flag2) ans=size[l]*size[r]; } } printf("%lld\n",ans); } return 0; }
标签:dep,int,路径,define,问题,fa,链上,树上,size From: https://www.cnblogs.com/Willette/p/17036762.html