题目描述
给定一棵 \(n\) 个节点的无根树,共有 \(m\) 个操作,操作分为两种:
- 将节点 \(a\) 到节点 \(b\) 的路径上的所有点(包括 \(a\) 和 \(b\))都染成颜色 \(c\)。
- 询问节点 \(a\) 到节点 \(b\) 的路径上的颜色段数量。
颜色段的定义是极长的连续相同颜色被认为是一段。例如 112221
由三段组成:11
、222
、1
。
数据规模与约定
对于 \(100\%\) 的数据,\(1 \leq n, m \leq 10^5\),\(1 \leq w_i, c \leq 10^9\),\(1 \leq a, b, u, v \leq n\),\(op\) 一定为 C
或 Q
,保证给出的图是一棵树。
除原数据外,还存在一组不计分的 hack 数据。
思路:
首先仔细观察一下题目,要求一条路径上的颜色段个数,这个东西看起来就很树链剖分。因为重链剖分有一个很优秀的性质:一条重链上的 \(dfs\) 序是连续的。
所以我们不妨将一条路径拆成若干个重链和一些不太完整的重链。然后就可以进行修改和询问操作了。
有一个细节需要注意,就是在代码实现的时候需要记录一下两个端点的颜色,在线段树 \(pushup,pushdown\) 的时候需要注意交点处颜色是否相同。
然后在查询的时候,一次跳重链算出所有的链内部的段数,然后对于重链相交的地方,在遍历一遍,如果相交的地方颜色相等,答案就要减一。
点击查看代码
#include<bits/stdc++.h>
#define int long long
#define mem(a) memset(a,0,sizeof(a))
#define set(a,b) memset(a,b,sizeof(a))
#define ls i<<1
#define rs i<<1|1
#define pb push_back
#define pt putchar
#define All(a) a.begin(),a.end()
#define T int t;cin>>t;while(t--)
#define rand RAND
using namespace std;
char buf[1<<20],*p1,*p2;
#define gc()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
template<class Typ> Typ &re(Typ &x){char ch=gc(),sgn=0; x=0;for(;ch<'0'||ch>'9';ch=gc()) sgn|=ch=='-';for(;ch>='0'&&ch<='9';ch=gc()) x=x*10+(ch^48);return sgn&&(x=-x),x;}
template<class Typ> void wt(Typ x){if(x<0) putchar('-'),x=-x;if(x>9) wt(x/10);putchar(x%10^48);}
const int inf=0x3f3f3f3f3f;
const int maxn=1e5+5;
const int mod=1e9+7;
int seed = 19243;
unsigned rand(){return seed=(seed*48271ll)%2147483647;}
int n,m;
int a[maxn];
vector<int>G[maxn];
int fa[maxn],dep[maxn],sz[maxn],son[maxn];
void dfs(int u,int f){
son[u]=-1;
fa[u]=f;
dep[u]=dep[f]+1;
sz[u]=1;
for(int v:G[u]){
if(v==f)continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]||son[u]==-1)son[u]=v;
}
}
int idx,dfn[maxn],top[maxn],rnk[maxn];
int col[maxn];
void dfs1(int u,int tp){
top[u]=tp;
dfn[u]=++idx;
rnk[idx]=u;
col[idx]=a[u];
if(son[u]==-1)return ;
dfs1(son[u],tp);
for(int v:G[u]){
if(v==son[u] or v==fa[u])continue;
dfs1(v,v);
}
}
struct Seg{
int sum[maxn<<2],tag[maxn<<2];
void push_up(int i,int mid){
sum[i]=sum[ls]+sum[rs];
if(col[mid]==col[mid+1])sum[i]--;
return ;
}
void build(int l,int r,int i){
if(l==r){
sum[i]=1;
return ;
}
int mid=(l+r)>>1;
build(l,mid,ls);
build(mid+1,r,rs);
push_up(i,mid);
}
void push_down(int i,int mid){
if(!tag[i])return ;
tag[ls]=tag[rs]=tag[i];
col[mid]=col[mid+1]=tag[i];
sum[ls]=sum[rs]=1;
tag[i]=0;
return ;
}
void update(int l,int r,int L,int R,int x,int i){
if(L<=l&&R>=r){
sum[i]=1;
col[l]=col[r]=tag[i]=x;
return ;
}
int mid=(l+r)>>1;
push_down(i,mid);
if(L<=mid)update(l,mid,L,R,x,ls);
if(R>mid)update(mid+1,r,L,R,x,rs);
push_up(i,mid);
}
int query(int l,int r,int L,int R,int i){
if(L<=l&&R>=r){
return sum[i];
}
int mid=(l+r)>>1;
push_down(i,mid);
int ans=0;
if(L<=mid)ans+=query(l,mid,L,R,ls);
if(R>mid)ans+=query(mid+1,r,L,R,rs);
if(L<=mid&&R>mid&&col[mid]==col[mid+1])ans--;
return ans;
}
}tree;
void change(int x,int y,int k){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
tree.update(1,n,dfn[top[x]],dfn[x],k,1);
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y);
tree.update(1,n,dfn[x],dfn[y],k,1);
return ;
}
int get(int x,int y){
int u=x,v=y,ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans+=tree.query(1,n,dfn[top[x]],dfn[x],1);
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y);
ans+=tree.query(1,n,dfn[x],dfn[y],1);
while(top[u]!=top[v]){
// cout<<u<<endl;
if(dep[top[u]]<dep[top[v]])swap(u,v);
if(col[dfn[top[u]]]==col[dfn[fa[top[u]]]])ans--;
u=fa[top[u]];
}
return ans;
}
signed main(){
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].pb(v);
G[v].pb(u);
}
dfs(1,0);
dfs1(1,1);
tree.build(1,n,1);
while(m--){
char op;
cin>>op;
int a,b,c;
if(op=='C'){
cin>>a>>b>>c;
change(a,b,c);
}
else{
cin>>a>>b;
cout<<get(a,b)<<endl;
}
}
return 0;
}