\(~~\)因为我们注意到颜色c的取值范围只有[1,60],所以我们考虑状态压缩,将颜色映射到二进制位上,每次维护颜色的时候直接按位或即可维护该区间内有无这种颜色题目链接
题目大意
\(~~~~\) 1. 1\(~\)u\(~\)c:将以\(~\)u\(~\)为根的子树上的所有节点的颜色改为\(~\)c。
\(~~~~\) 2. 2\(~\)u\(~\):询问以\(~\)u\(~\)为根的子树上的所有节点的颜色数量。
题目思路
\(~~\)另一个问题是因为是树形结构,如果想用线段树来维护,就需要我们转化为线性结构,所以我们选择用dfs序来转换,保存每个节点的l[u],r[u]表示以\(~\)u\(~\)为根的
节点所能到达的最远的点,也就是他整个子树区间[\(~\)l[u]\(~\),\(~\)r[u]\(~\)]
\(~~\)~~~然后这个问题就基本解决了~~~
\(~~\)不过需要注意的是,在下方lazy标记的时候,因为这里的0,1分别表示两种状态,所以在初始化没有标记的状态时要避免这个影响,将其初始化为-1或者其他非0,1的数
# include<bits/stdc++.h>
using namespace std;
#define endl "\n"
# define int long long
# define ls u<<1
# define rs u<<1|1
const int N = 4e5 + 10;
int a[N], p, n, m;
vector<int> g[N];
int pos[N];
struct segtree {
int sum[4 * N], lazy[4 * N], ans;
segtree() {
ans = 0;
memset(lazy, 0, sizeof lazy);
memset(sum, 0, sizeof lazy);
}
void pushup(int u) //维护区间颜色
{
sum[u] = sum[ls] | sum[rs];
}
void build(int u, int l, int r) {
if (l == r) {
sum[u] = 1ll << a[pos[l]];
return;
}
int mid = l + r >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
pushup(u);
}
void pushdown(int u) {
if (lazy[u]) {
lazy[ls] = lazy[rs] = lazy[u];
sum[rs] = sum[ls] = 1ll << lazy[u];
lazy[u] = 0;
}
}
void modify(int u, int l, int r, int L, int R, int c) {
if (L <= l && r <= R) {
sum[u] = 1ll << c;
lazy[u] = c;
return;
}
int mid = l + r >> 1;
pushdown(u);
if (L <= mid) modify(ls, l, mid, L, R, c);
if (mid + 1 <= R) modify(rs, mid + 1, r, L, R, c);
pushup(u);
}
int query(int u, int l, int r, int L, int R) {
if (l >= L && r <= R) {
return sum[u];
}
pushdown(u);
int mid = l + r >> 1;
int val = 0;
if (L <= mid) val |= query(ls, l, mid, L, R);
if (R > mid) val |= query(rs, mid + 1, r, L, R);
pushup(u);
return val;
}
} tr;
int l[N], r[N], tot;
void dfs(int u, int fa) //dfs序
{
l[u] = ++tot;
pos[tot] = u;
for (auto v : g[u]) {
if (v == fa) continue;
dfs(v, u);
}
r[u] = tot;
}
int cnt(int val) //计算区间有多少种颜色即查询二进制位有多少个1
{
int ans = 0;
while (val) {
ans += val & 1;
val >>= 1;
}
return ans;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> a[i];
for (int i = 1; i < n; ++i) {
int a, b;
cin >> a >> b;
g[a].push_back(b);
g[b].push_back(a);
}
dfs(1, 0);
tr.build(1, 1, n);
while (m--) {
int op, x;
cin >> op >> x;
if (op == 1) {
int c;
cin >> c;
tr.modify(1, 1, n, l[x], r[x], c);
} else {
cout << cnt(tr.query(1, 1, n, l[x], r[x])) << endl;
}
}
return 0;
}