非常有意义的一道题,虽然不算太难。非常好题目,英雄联盟,爱来自瓷器。
题意
给一定一个 \(n\) 个点的树,每个点有一个颜色,点 \(u\) 的颜色为 \(w_u\)。给定一个 \(P_1,P_2,\cdots,P_c\),表示要按这样的顺序收集颜色分别为 \(P_1,P_2,\cdots,P_c\) 的点。给定 \(q\) 次询问,每次给定 \(s,t\),求从 \(s\) 到 \(t\) 的路径上最多能收集前多少个点。注意:如果没收集到颜色为 \(P_i\) 的点就不能先收集颜色为 \(P_{i+1}\) 的点。\(P_1,P_2,\cdots,P_c\) 互不相等。
数据范围:\(1\leq n,q\leq 2\times 10^5,1\leq c\leq m\leq 5\times 10^4\)。
题解
方法一
摘要:不需要什么高级算法的方法,很容易写的方法。
首先,从 25 pts 的部分分开始思考。这部分很无脑,只要暴力求 LCA 然后暴力往上跳计算即可。时间复杂度 \(\Theta(qn\log n)\)。
然后思考 50 pts 的做法。发现 \(m\leq 300\),则可以设 \(f(u,x)\) 表示点 \(u\) 往上跳最近的颜色为 \(x\) 的点。这个东西可以在 \(\Theta(nm)\) 的时间复杂度内预处理出来,查询暴力跳就行,时间复杂度 \(\Theta(nm+q(c+\log n))\)。
下一步思考这个如何优化。容易发现这个做法的瓶颈在于无法预处理出 \(f(u,x)\),那么有没有什么办法可以省略 \(x\) 呢?观察查询时往上跳的过程,发现是 \(u\to x_0\to x_1\to\cdots\to x_k\)。除了 \(u\to x_0\) 的那一步,其他都是 \(x_i\to x_{i+1}\) 的!那么首先预处理出 \(U_x\) 表示 \(x\) 往上跳到 \(P_1\) 的点的位置。在剩下只要维护从一个点跳到下一个点即可!而因为 \(P_1,P_2,\cdots,P_c\) 互不相等可以直接知道每一个点的颜色对应了 \(P\) 中的第几个数,这样只要维护每一个点跳到下一个在 \(P\) 中的点的位置即可,这个可以用倍增优化。
具体来说,把 \(u\to\text{lca}(u,v)\) 和 \(\text{lca}(u,v)\to v\) 分开计算,分别维护 \(f(u,x)\) 表示从 \(u\) 开始走 \(2^x\) 个 \(P\) 里的颜色能到的点和 \(g(u,x)\) 表示从 \(u\) 开始走 \(2^x\) 个 \(P\) 里的颜色能到的点。把询问离线下来挂在 \(v\) 点上 DFS 一遍计算答案。但是还有一个问题,如果不知道最终的答案,如何处理 \(v\) 向上跳呢?解决这个问题只要二分答案即可。时间复杂度 \(\Theta(n\log n+q\log n\log m)\)。
代码
#include<bits/stdc++.h>
#define int long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rept(i,n) for(int i=1;i<=n;i++)
#define repe(i,l,r) for(int i=l;i<=r;i++)
#define FOR(i,r,l) for(int i=r;i>=l;i--)
#define pii pair<int,int>
#define mpr make_pair
#define pb push_back
#define sz(v) (int)(v.size())
using namespace std;
int fast(int a,int b,int P){int res=1;if(P<=0){while(b){if(b&1)res=res*a;a=a*a;b>>=1;}}else{while(b){if(b&1)res=res*a%P;a=a*a%P;b>>=1;}}return res;}
template<typename T>void chmax(T& a,T b){if(a<b)a=b;return;}
template<typename T>void chmin(T& a,T b){if(a>b)a=b;return;}
const int N=2e5+10;
int n,m,c,P[N],w[N],rnk[N],s[N],t[N],ans[N];
vector<int> G[N],Q[N];
int fa[N][18],cur[N],up[N],up1[N][18],up2[N][18],dep[N];
void dfs(int u,int p){
int tmp=cur[w[u]];
cur[w[u]]=u;
fa[u][0]=p;
dep[u]=dep[p]+1;
up[u]=cur[P[1]];
up1[u][0]=cur[P[rnk[w[u]]+1]];
up2[u][0]=cur[P[rnk[w[u]]-1]];
for(auto v:G[u]){
if(v==p)continue;
dfs(v,u);
}
cur[w[u]]=tmp;
}
int lca(int u,int v){
if(dep[u]<dep[v])swap(u,v);
FOR(i,17,0)if(dep[fa[u][i]]>=dep[v])u=fa[u][i];
if(u==v)return u;
FOR(i,17,0)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int calc(int x){
int u=s[x],v=t[x];
int c=lca(u,v);
int l=0,r=m;
if(dep[up[u]]>=dep[c]){
u=up[u];
FOR(i,17,0)if(dep[up1[u][i]]>=dep[c])u=up1[u][i];
l=rnk[w[u]];
}
int tmp=l;
while(l<r){
int mid=(l+r+1)>>1LL;
if(mid==tmp){
l=mid;
continue;
}
int node=v;
if(dep[cur[P[mid]]]<dep[c]){
r=mid-1;
continue;
}
node=cur[P[mid]];
FOR(i,17,0)if(dep[up2[node][i]]>=dep[c])node=up2[node][i];
if(rnk[w[node]]<=tmp+1)l=mid;
else r=mid-1;
}
return l;
}
void solve(int u,int p){
int tmp=cur[w[u]];
cur[w[u]]=u;
for(auto v:Q[u])ans[v]=calc(v);
for(auto v:G[u]){
if(v==p)continue;
solve(v,u);
}
cur[w[u]]=tmp;
}
signed main(){
cin>>n>>m>>c;
rept(i,c){
cin>>P[i];
rnk[P[i]]=i;
}
rept(i,n)cin>>w[i];
rept(i,n-1){
int u,v;cin>>u>>v;
G[u].pb(v);
G[v].pb(u);
}
dfs(1,0);
rept(j,17){
rept(i,n){
fa[i][j]=fa[fa[i][j-1]][j-1];
up1[i][j]=up1[up1[i][j-1]][j-1];
up2[i][j]=up2[up2[i][j-1]][j-1];
}
}
int q;cin>>q;
rept(i,q){
cin>>s[i]>>t[i];
Q[t[i]].pb(i);
}
memset(cur,0,sizeof(cur));
solve(1,0);
rept(i,q)cout<<ans[i]<<"\n";
return 0;
}