转化一下题意,变成求 \(x\) 在只经过编号 \(\in [l, r]\) 的点,能走到多少种颜色。
考虑建出点分树。一个结论是原树上的一个连通块,一定存在一个点,使得它在点分树上的子树完全包含这个连通块的所有点。证明考虑点分治的过程,一个连通块如果没被其中一个点剖开就一定在同一个子树。
所以对于一个询问 \((l, r, x)\),我们可以把 \(x\) 跳到点分树上最浅的与 \(x\) 在同一连通块的点 \(y\)。判断两个点是否在同一连通块只需求路径最小值和最大值,可以树剖 + ST 表解决。
这样问题变为子树内的问题了,即一次询问要求点分树上 \(y\) 的子树内,到 \(y\) 的路径最大值 \(L \ge l\),最小值 \(R \le r\) 的这些点中不同的颜色 \(c\) 数量。
因为点分树上 \(\sum sz_u = O(n \log n)\) 所以可以暴力遍历子树求出所有的 \((L, R, c)\)。套路地离线扫描线,扫右端点 \(r\),同时加入所有 \(R \le r\) 的点。设 \(b_x\) 为颜色 \(x\) 的左端点最大值。那么一种颜色 \(c\) 对于一个 \(l\) 当 \(b_c \ge l\) 时有 \(1\) 的贡献。树状数组维护即可。
时间复杂度 \(O(n \log^2 n + q \log n)\)。
code
// Problem: P5311 [Ynoi2011] 成都七中
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P5311
// Memory Limit: 250 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;
const int maxn = 100100;
const int logn = 22;
int n, m, a[maxn], b[maxn], c[maxn], ans[maxn];
int A[logn][maxn], B[logn][maxn], fa[maxn], dep[maxn], son[maxn], top[maxn], dfn[maxn], tim;
vector<int> G[maxn];
int f[maxn], sz[maxn], rt;
bool vis[maxn];
struct node {
int l, r, i;
node(int a = 0, int b = 0, int c = 0) : l(a), r(b), i(c) {}
};
vector<node> qq[maxn], vc[maxn];
void dfs2(int u, int fa, int t) {
sz[u] = 1;
f[u] = 0;
for (int v : G[u]) {
if (v == fa || vis[v]) {
continue;
}
dfs2(v, u, t);
sz[u] += sz[v];
f[u] = max(f[u], sz[v]);
}
f[u] = max(f[u], t - sz[u]);
if (!rt || f[u] < f[rt]) {
rt = u;
}
}
void dfs3(int u, int fa, int l, int r, int rt) {
vc[rt].pb(l, r, a[u]);
for (int v : G[u]) {
if (v == fa || vis[v]) {
continue;
}
dfs3(v, u, min(l, v), max(r, v), rt);
}
}
void dfs(int u) {
vis[u] = 1;
dfs3(u, -1, u, u, u);
for (int v : G[u]) {
if (vis[v]) {
continue;
}
rt = 0;
dfs2(v, u, sz[v]);
dfs2(rt, -1, sz[v]);
// printf("%d -> %d\n", rt, u);
b[rt] = u;
dfs(rt);
}
}
int dfs4(int u, int f, int d) {
fa[u] = f;
sz[u] = 1;
dep[u] = d;
int mx = -1;
for (int v : G[u]) {
if (v == f) {
continue;
}
sz[u] += dfs4(v, u, d + 1);
if (sz[v] > mx) {
son[u] = v;
mx = sz[v];
}
}
return sz[u];
}
void dfs5(int u, int tp) {
top[u] = tp;
dfn[u] = ++tim;
A[0][tim] = B[0][tim] = u;
if (!son[u]) {
return;
}
dfs5(son[u], tp);
for (int v : G[u]) {
if (!dfn[v]) {
dfs5(v, v);
}
}
}
inline int qmin(int l, int r) {
int k = __lg(r - l + 1);
return min(A[k][l], A[k][r - (1 << k) + 1]);
}
inline int qmax(int l, int r) {
int k = __lg(r - l + 1);
return max(B[k][l], B[k][r - (1 << k) + 1]);
}
inline int querymin(int x, int y) {
int res = x;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) {
swap(x, y);
}
res = min(res, qmin(dfn[top[x]], dfn[x]));
x = fa[top[x]];
}
if (dep[x] > dep[y]) {
swap(x, y);
}
res = min(res, qmin(dfn[x], dfn[y]));
return res;
}
inline int querymax(int x, int y) {
int res = x;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) {
swap(x, y);
}
res = max(res, qmax(dfn[top[x]], dfn[x]));
x = fa[top[x]];
}
if (dep[x] > dep[y]) {
swap(x, y);
}
res = max(res, qmax(dfn[x], dfn[y]));
return res;
}
namespace BIT {
int c[maxn];
inline void update(int x, int d) {
for (int i = x; i; i -= (i & (-i))) {
c[i] += d;
}
}
inline int query(int x) {
int res = 0;
for (int i = x; i <= n; i += (i & (-i))) {
res += c[i];
}
return res;
}
inline void clear(int x) {
for (int i = x; i; i -= (i & (-i))) {
c[i] = 0;
}
}
}
void solve() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
dfs2(1, -1, n);
dfs2(rt, -1, n);
dfs(rt);
dfs4(1, -1, 1);
dfs5(1, 1);
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
A[j][i] = min(A[j - 1][i], A[j - 1][i + (1 << (j - 1))]);
B[j][i] = max(B[j - 1][i], B[j - 1][i + (1 << (j - 1))]);
}
}
for (int i = 1, l, r, x; i <= m; ++i) {
scanf("%d%d%d", &l, &r, &x);
int y = x;
while (b[y]) {
y = b[y];
if (l <= querymin(x, y) && querymax(x, y) <= r) {
x = y;
}
}
// printf("%d %d %d %d\n", x, l, r, i);
qq[x].pb(l, r, i);
}
for (int i = 1; i <= n; ++i) {
if (qq[i].empty()) {
continue;
}
sort(vc[i].begin(), vc[i].end(), [&](const node &a, const node &b) {
return a.r < b.r;
});
sort(qq[i].begin(), qq[i].end(), [&](const node &a, const node &b) {
return a.r < b.r;
});
int j = 0;
for (node u : qq[i]) {
while (j < (int)vc[i].size() && vc[i][j].r <= u.r) {
if (c[vc[i][j].i]) {
BIT::update(c[vc[i][j].i], -1);
}
c[vc[i][j].i] = max(c[vc[i][j].i], vc[i][j].l);
BIT::update(c[vc[i][j].i], 1);
++j;
}
ans[u.i] = BIT::query(u.l);
}
for (node u : vc[i]) {
BIT::clear(u.l);
c[u.i] = 0;
}
}
for (int i = 1; i <= m; ++i) {
printf("%d\n", ans[i]);
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}