静态邻域数颜色 - 题目 - Zhengrui Online Judge
题目描述
静态树上邻域数颜色。
给一棵 \(n\) 个点的无根树,第 \(i\) 个点颜色为 \(a_i\) 。有 \(q\) 次询问,每次询问如下:给定 \(x,d\) ,考虑所有距离 \(x\) 不超过 \(d\) 的点,求有多少种不同的颜色。形式化地,给定 \(x,d\) ,求 \(|\{a_y:\text{dis}(x,y)\le d\}|\) ,其中 \(\text{dis}(x,y)\) 表示 \(x\) 到 \(y\) 经过的边数。
\(n\leq 10^5\),测三组数据,\(5\sim 6\) 秒,可以强制在线。
solution
本题做法很多,下面手写一种。其它的做法我考虑复制过来。不过好像有本质相同的做法?
1
没写完不放
2
考虑对于每种颜色单独计算,首先建出虚树,虚树上每个点 \(f_u\) 表示点 \(u\) 到这种颜色的最近的距离。对于一组询问 \(x,d\)。如果能有贡献当且仅当存在一个虚树上的点 \(u\) 满足 \(f_u+dis(u,x)≤d\) ,我们发现满足条件的点在虚树上正好构成一个连通块。所以把 \(1\) 的贡献拆分成点数减去边数。边数指的是到两个点都满足条件,因为两个点的领域交可以转化为一个点的领域,所以就变成的统计 \(f_u+dis(u,x)≤d\) 的点 \(u\) 的点权和的情况,可以通过点分树在线的回答,复杂度 \(O(n\log_2n)\)。
- 对于一条边两个点都满足条件,要求到两端的该色点距离 \(\max\) 小于一个数,所以考虑把贡献算到中点,因此每条边要在中间多建一个点,防止没有中点。
3
先来个预防针:标算不是 \(O(nq/w)\) 状物,是 polylog 的。
考虑每种颜色对答案的贡献,问题转化为求有多少种颜色,满足到 \(x\) 最短距离 \(\leq d\)。
依次考虑每种颜色,并建出虚树。对虚树上每个点处理出 \(f_i\) 表示到目标颜色的点的最近距离。把点分
成:虚树上的点,在虚边上的点,在虚点子树的点,在虚边子树的点。对虚树上的点的贡献是平凡的。
考虑一条虚边,记为 \((u, v)\),其中 \(u\) 是 \(v\) 的父亲,发现虚边上一个前缀的点最短路经过 \(u\),剩下的经过 \(v\),且最短路长度可以写成关于 \(dp_x\) 的一次函数,直接树上差分即可。
对于在虚点子树的点,最短距离也可以写成关于 \(dp_x\) 的一次函数,与上面同理。
哦,这里的维护可以直接树上差分,每个操作形如给子树或链插入一个一次函数,查询是直接在每个点区间查。
现在只剩最麻烦的情况,在虚边子树的点。发现如果最短距离经过 \(u\) 也是简单的,只要考虑下半部分。
考虑重链剖分,那么可以把操作拆成 \(O(\log n)\) 条重链以及 \(O(\log n)\) 棵子树(子树是重链连接处得到的),子树可以像上面那样直接做,重链也可以通过差分然后暴力回答所有轻子树的询问。
然后就做完了。复杂度 \(O(n\log ^2n)\)。
一些实现细节补充:
首先虚边上的点是不需要另外做的,这样可以直接写成 的子树减去 的某个儿子的子树,非常方便。
然后考虑这样一个基本操作:对于一个点 \(u\),考虑所有 \(u\) 子树中的询问 \((x, d)\),如果 \(0\leq dp_x+b\leq d\),就给这个询问贡献 \(y\)。操作实现可以直接 dfn 差分扫描线。这个操作其实是非常好用的,除了重剖那部分都可以写成这个操作。
接下来考虑重剖那部分。发现操作形如,对于一条重链的前缀所有点 \(u\),把 \(u\) 的所有轻子树插入一个 \(dp_x+b\)。这个可以直接对链从后往前扫,每次加入,然后直接暴力遍历所有轻子树的询问并回答即可。
另一些实现细节补充:
事实上并不需要对每个虚点或者虚边做,可以直接用当前子树减儿子子树。
然后也并不需要每次遍历重链,可以直接先 dfs 轻儿子,然后撤销,再 dfs 重儿子并保留信息。
对一条虚边 \((u, v)\) 求最短路分界点可以解关于 \(dp_x\) 的不等式。
4
感觉邻域就很不好做,考虑拆贡献。对于每个询问考虑每个颜色对其的贡献。对于每一个询问的点求出该点距离一个颜色的点的最近距离,如果这个距离小于等于 \(d\) 那么这个颜色就会对这个询问有 \(1\) 的贡献。
我们将所有询问离线下来,挂在询问的点上。那么假如我们到一个点时,能处理出一个可重集合表示所有颜色距离它的最近距离,那么我们就可以直接简单查了。
首先对于每个颜色分别考虑。对于当前颜色建包含点 \(1\) 的虚树,先预处理出虚树上每个点距离该颜色的最小距离 \(dis_i\)。那么相当于把所有点分成:在虚点上的,在虚边上的,不在虚树上的。在虚点上的贡献是好处理的。
对于某一条虚边上的点 \(u, v\),考虑这条虚边上的点的 \(dis_i\) 的形式。距离它们最近的该颜色的点是形如上面一部分是 \(u\),下面一部分是 \(v\)。也就是说,它们的距离是两段一次函数组合而成。而对于不在虚树上的点,它们的距离被最近的在虚树上的祖先计算到。
对于一条虚边 \(u, v\) 的上半部分,它们会把它们自己和它们不在虚树上的儿子的子树都覆盖了。这些点的 \(dis\) 都是距离 \(u\) 的距离加上 \(u\) 自己的 \(dis\)。这样的形式比较好处理,通过简单的树上差分把问题变成维护一个集合,支持全局加一减一,单点修改,区间查询。这个是好维护的。那么这一部分的总复杂度就是 \(O(n\log n)\)。
相对较难的部分在于虚边的下半部分。上半部分好维护是因为所有贡献都是形如这个点的祖先 \(dis\) 的加上它们之间的距离,但是下半部分就不一定是祖先了,无法直接套用上面的做法。
这里原本题解的做法是启发式合并+重链剖分。我的做法略有不同。
考虑轻重链剖分,把每个“下半部分”拆成 \(O(\log n)\) 条重链区间以及它们的轻儿子的子树,以及 \(O(\log n)\) 个子树。之所以还要拆出若干个子树是因为拆成重链的话会有一些点错位,需要微调一下。一条重链区间和他们轻儿子子树的 \(dis\) 都是距离该重链区间底端的距离加上底端那个点本身的 \(dis\),一个点 \(x\) 的子树中的点的 \(dis\) 是 \(dis_x\) 加上和 \(x\) 的路径长度。
对于那些子树就跟“上半部分”的做法是一样的。
考虑那些重链区间的贡献。首先考虑它们自己对自己的贡献,这个可以 dfs 的时候直接维护,即在 dfs的时候先走重儿子再走轻儿子,那么只需要一次插入一次删除,以及全局加减即可。还要维护轻儿子,这个其实直接暴力做就是对的。我们 dfs 到某一个点,然后把它所有轻儿子的询问全部查询一下。因为轻重链剖分本身的性质所以就是对的。
这一部分复杂度是 \(O(n\log ^2n)\) 的。原题解的做法也一样。
最后总复杂度就是 \(O(n\log ^2n)\) ,空间复杂度线性或者单老哥(要求 lca 啥的)。
据说还有点分治做法?有无大佬教教。
code
做法 134 应该是等同的思路,写的也是。
134
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
namespace seg {
constexpr int N = 3.3e7;
int ch[N][2], s1[N], s2[N], col[N], tot, *s;
void maintain(int p) {
s1[p] = s1[ch[p][0]] + s1[ch[p][1]];
s2[p] = s2[ch[p][0]] + s2[ch[p][1]];
}
int newnode(int q, int c) {
if (col[q] == c) return q;
int p = ++tot;
ch[p][0] = ch[q][0], ch[p][1] = ch[q][1];
s1[p] = s1[q], s2[p] = s2[q], col[p] = c;
return p;
}
int _mdf(int x, int k, int c, int q, int l, int r) {
int p = newnode(q, c);
if (l == r) return s[p] += k, p;
int mid = (l + r) >> 1;
if (x <= mid) ch[p][0] = _mdf(x, k, c, ch[q][0], l, mid);
else ch[p][1] = _mdf(x, k, c, ch[q][1], mid + 1, r);
maintain(p);
return p;
}
int _qry(int L, int R, int p, int l, int r) {
if (!p) return 0;
if (L <= l && r <= R) return s[p];
int mid = (l + r) >> 1, res = 0;
if (L <= mid) res += _qry(L, R, ch[p][0], l, mid);
if (mid < R) res += _qry(L, R, ch[p][1], mid + 1, r);
return res;
}
int mdf1(int x, int k, int c, int q, int l, int r) { return s = s1, _mdf(x, k, c, q, l, r); }
int mdf2(int x, int k, int c, int q, int l, int r) { return s = s2, _mdf(x, k, c, q, l, r); }
int qry1(int L, int R, int p, int l, int r) { return debug("qry1(%d, %d)\n", L, R), s = s1, _qry(L, R, p, l, r); }
int qry2(int L, int R, int p, int l, int r) { return debug("qry2(%d, %d)\n", L, R), s = s2, _qry(L, R, p, l, r); }
};
constexpr int N = 1e5 + 10;
int op, q, n, vis[N], tim, key[N], f[N], rt[N];
basic_string<int> g[N], dts[N], t[N];
int st[17][N], cnt, dfn[N], dep[N];
int fa[N], siz[N], son[N], rnk[N], top[N];
void dfs(int u, int _fa) {// {{{
siz[u] = 1, dep[u] = dep[_fa] + 1, fa[u] = _fa, son[u] = 0;
for (int v : g[u]) if (v != _fa) dfs(v, u), siz[u] += siz[v], siz[v] > siz[son[u]] && (son[u] = v);
}// }}}
void cut(int u, int topf) {// {{{
dfn[u] = ++cnt, st[0][cnt] = fa[u], rnk[cnt] = u, top[u] = topf;
if (son[u]) cut(son[u], topf);
for (int v : g[u]) if (v != fa[u] && v != son[u]) cut(v, v);
}// }}}
bool cmp(int u, int v) {// {{{
return dfn[u] < dfn[v];
}// }}}
int lca(int u, int v) {// {{{
if (u == v) return u;
int l = min(dfn[u], dfn[v]) + 1, r = max(dfn[u], dfn[v]);
int k = 31 - __builtin_clz(r - l + 1);
return min(st[k][l], st[k][r - (1 << k) + 1], cmp);
}// }}}
int dist(int u, int v) {// {{{
return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}// }}}
int jump(int u, int k) {// {{{
int len = dfn[u] - dfn[top[u]] + 1;
return len <= k ? jump(fa[top[u]], k - len) : rnk[dfn[u] - k];
}// }}}
void init() {// {{{
cnt = 0;
dfs(1, 0), cut(1, 1);
for (int j = 1; j <= 16; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
st[j][i] = min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))], cmp);
}
}
}// }}}
void buildvt(basic_string<int> h) {
debug("buildvt: ");
++tim;
for (int x : h) key[x] = tim, debug("%d, ", x);
debug("\n");
auto clr = [&](int x) { if (vis[x] < tim) vis[x] = tim, t[x].clear(); };
auto link = [&](int u, int v) { clr(u), clr(v), t[u] += v, t[v] += u, debug("link(%d, %d)\n", u, v); };
h.push_back(1), clr(1);
sort(h.begin(), h.end(), cmp);
int m = (int)h.size();
h.resize(m * 2 - 1);
for (int i = 0; i + 1 < m; i++) h[i + m] = lca(h[i], h[i + 1]);
sort(h.begin(), h.end(), cmp);
h.erase(unique(h.begin(), h.end()), h.end());
for (int i = 1; i < (int)h.size(); i++) link(h[i], lca(h[i], h[i - 1]));
}
vector<pair<int, int>> qry1[N], qry2[N];
void dfs1(int u, int _fa) {
f[u] = key[u] == tim ? 0 : (int)1e9;
for (int v : t[u]) if (v != _fa) dfs1(v, u), f[u] = min(f[u], f[v] + dep[v] - dep[u]);
}
void upd1(int u, int val, int coe) {
if (val > n || !u) return ;
debug("upd1(subtree(%d), val=%d, coe=%d)\n", u, val, coe);
qry1[dfn[u]].emplace_back(val, coe);
qry1[dfn[u] + siz[u]].emplace_back(val, -coe);
}
void upd2(int l, int r, int val, int coe) {
if (val > n) return ;
debug("upd2(dfn[%d..%d], val=%d, coe=%d)\n", l, r, val, coe);
qry2[l].emplace_back(val, coe);
qry2[r + 1].emplace_back(val, -coe);
}
void dfs2(int u, int _fa) {
for (int v : t[u]) if (v != _fa) f[v] = min(f[v], f[u] + dep[v] - dep[u]), dfs2(v, u);
debug("f[%d] = %d\n", u, f[u]);
upd1(u, f[u] - dep[u], +1);
if (!_fa) return ;
int depq = (f[u] - f[_fa] + dep[u] + dep[_fa] + 1) >> 1;
int q = jump(u, dep[u] - max(dep[_fa] + 1, min(dep[u], depq)));
upd1(q, f[_fa] - dep[_fa], -1);
if (u == q) return ;
int p = fa[u];
upd1(son[p], f[u] + dep[u] - 2 * dep[p], +1);
upd1(u, f[u] + dep[u] - 2 * dep[p], -1);
while (top[q] != top[p]) {
upd2(dfn[top[p]], dfn[p], f[u] + dep[u], +1);
p = top[p];
upd1(p, f[u] + dep[u] - 2 * dep[fa[p]], -1);
p = fa[p];
upd1(son[p], f[u] + dep[u] - 2 * dep[p], +1);
}
upd2(dfn[q], dfn[p], f[u] + dep[u], +1);
}
int mian() {
cin >> n >> q;
for (int i = 1; i <= n; i++) g[i].clear(), dts[i].clear(), qry1[i].clear(), qry2[i].clear();
for (int i = 1, u, v; i < n; i++) cin >> u >> v, g[u] += v, g[v] += u;
for (int i = 1, c; i <= n; i++) cin >> c, dts[c] += i;
init();
for (int i = 1; i <= n; i++) debug("%d%c", rnk[i], " \n"[i == n]);
for (int c = 1; c <= n; c++) debug("col = %d\n", c), buildvt(dts[c]), dfs1(1, 0), dfs2(1, 0);
rt[0] = seg::tot = 0;
for (int i = 1; i <= n; i++) {
rt[i] = rt[i - 1];
for (auto op : qry1[i]) rt[i] = seg::mdf1(op.first, op.second, i, rt[i], -n, n);
for (auto op : qry2[i]) rt[i] = seg::mdf2(op.first, op.second, i, rt[i], -n, n);
}
int lst = 0;
while (q--) {
int x, d;
cin >> x >> d;
x ^= op * lst;
int ans = seg::qry1(-n, d - dep[x], rt[dfn[x]], -n, n);
for (int y = x; y; y = fa[top[y]]) {
ans += seg::qry2(-n, d - dep[x] + 2 * dep[y], rt[dfn[y]], -n, n);
}
cout << (lst = ans) << endl;
}
return 0;
}
int main() {
#ifndef LOCAL
#ifndef NF
freopen("count.in", "r", stdin);
freopen("count.out", "w", stdout);
#endif
cin.tie(nullptr)->sync_with_stdio(false);
#endif
int _t;
cin >> op >> _t;
while (_t--) mian();
return 0;
}
做法 2 std
#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define sz(a) ((int)(a).size())
using namespace std;
const int N = 2e5 + 5, I = 1e9;
int t, n, a[N], op, u, v, k, h[N], tot, fa[N], ct, dfn[N], d[N], f[18][N], g[18][N], s[N], c[N], q, siz[N], ans, dep[N], pa[N], di[19][N];
bool vs[N];
vector<int> b[N], son[N];
struct edge {int to, nxt;}e[N << 1];
void add(int u, int v)
{
e[++tot] = {v, h[u]}; h[u] = tot;
e[++tot] = {u, h[v]}; h[v] = tot;
}
int cmp(int x, int y) {return d[x] < d[y] ? x : y;}
void dfs(int u)
{
d[u] = d[fa[u]] + 1; dfn[u] = ++ct; f[0][ct] = g[0][u] = fa[u];
for(int i = 1; (1 << i) <= d[u]; i++) g[i][u] = g[i - 1][g[i - 1][u]];
for(int i = h[u]; i; i = e[i].nxt)
{
int v = e[i].to;
if(v == fa[u]) continue;
fa[v] = u; dfs(v);
}
}
void work()
{
for(int i = 1; (1 << i) <= n; i++)
for(int j = 1; j + (1 << i) - 1 <= n; j++)
f[i][j] = cmp(f[i - 1][j], f[i - 1][j + (1 << (i - 1))]);
}
int lca(int x, int y)
{
if(x == y) return x;
if(dfn[x] > dfn[y]) swap(x, y);
int l = dfn[x] + 1, r = dfn[y], k = 31 ^ __builtin_clz(r - l + 1);
return cmp(f[k][l], f[k][r - (1 << k) + 1]);
}
int dis(int x, int y) {return d[x] + d[y] - 2 * d[lca(x, y)];}
int kth(int u, int k)
{
for(int i = 0; k; i++)
if((k >> i) & 1) u = g[i][u], k ^= (1 << i);
return u;
}
void solve(int u, int sum, int t)
{
int z = 0;
auto find = [&](auto self, int u, int fa) -> void
{
siz[u] = 1; int mx = 0;
for(int i = h[u]; i; i = e[i].nxt)
{
int v = e[i].to;
if(vs[v] || v == fa) continue;
self(self, v, u); siz[u] += siz[v];
mx = max(mx, siz[v]);
}
mx = max(mx, sum - siz[u]);
if(mx <= sum / 2) z = u;
};
find(find, u, 0);
auto dfs = [&](auto self, int u, int fa) -> void
{
siz[u] = 1;
for(int i = h[u]; i; i = e[i].nxt)
{
int v = e[i].to;
if(vs[v] || v == fa) continue;
di[dep[z]][v] = di[dep[z]][u] + 1;
self(self, v, u); siz[u] += siz[v];
}
};
pa[z] = t; dep[z] = dep[t] + 1; di[dep[z]][z] = 0; dfs(dfs, z, 0); vs[z] = 1;
for(int i = h[z]; i; i = e[i].nxt)
{
int v = e[i].to;
if(vs[v]) continue;
solve(v, siz[v], z);
}
}
vector<pair<int, int>> a1[N], a2[N];
int ask(vector<pair<int, int>> &a, int d)
{
int u = lower_bound(begin(a), end(a), mp(d + 1, -I)) - begin(a) - 1;
return u == -1 ? 0 : a[u].se;
}
void upd(int u, int d, int v)
{
a1[u].push_back({di[dep[u]][u] + d, v});
for(int z = u, i = dep[u] - 1; i >= 1; z = pa[z], i--)
{
a1[pa[z]].push_back({di[i][u] + d, v});
a2[z].push_back({di[i][u] + d, -v});
}
}
namespace IO
{
const int S = (1 << 20);
char in[S], out[S], *p1 = in, *p2 = in, *p3 = out;
inline char gc() {return p1 == p2 && (p2 = (p1 = in) + fread(in, 1, S, stdin), p1 == p2) ? EOF : *p1++;}
void flush() {fwrite(out, 1, p3 - out, stdout); p3 = out;}
void pc(char c)
{
if(p3 == out + S) flush();
*p3++ = c;
}
template <class T> void read(T &x)
{
x = 0; int f = 0; char c = gc();
for(; !isdigit(c); c = gc()) if(c == '-') f = 1;
for(; isdigit(c); c = gc()) x = x * 10 + (c ^ 48);
if(f) x = -x;
}
template <class T> void write(T x, char c = '\n')
{
if(x < 0) pc('-'), x = -x;
static int s[50], t = 0;
do s[++t] = x % 10, x /= 10; while(x);
while(t) pc(s[t--] ^ 48);
pc(c);
}
struct F {~F(){flush();};}f;
}
using IO :: read;
using IO :: write;
void solve()
{
read(n), read(q); tot = ct = ans = 0;
for(int i = 1; i < n + n; i++) h[i] = 0, vs[i] = 0;
for(int i = 1; i < n; i++)
{
read(u), read(v);
add(u, i + n), add(v, i + n);
}
for(int i = 1; i <= n; i++)
{
read(a[i]);
b[a[i]].push_back(i);
}
n = n + n - 1; dfs(1); work(); solve(1, n, 0);
for(int o = 1, t = 0; o <= n; o++)
{
if(!sz(b[o])) continue;
sort(begin(b[o]), end(b[o]), [](int x, int y){return dfn[x] < dfn[y];});
t = 0; s[++t] = 1;
auto add = [](int u, int v) {son[u].push_back(v);};
for(auto u: b[o])
{
if(u == 1) continue;
int v = lca(s[t], u);
while(dfn[s[t - 1]] >= dfn[v]) add(s[t - 1], s[t]), t--;
if(s[t] != v) add(v, s[t]), s[t] = v;
s[++t] = u;
}
while(t > 1) add(s[t - 1], s[t]), t--;
auto dfs = [&](auto self, int u) -> void
{
c[u] = (a[u] == o ? 0 : n);
for(int v: son[u])
{
fa[v] = u; self(self, v);
c[u] = min(c[u], c[v] + d[v] - d[u]);
}
};
auto dfs2 = [&](auto self, int u) -> void
{
upd(u, c[u], 1);
for(int v: son[u])
{
c[v] = min(c[v], c[u] + d[v] - d[u]);
int z = (c[u] + c[v] + d[v] - d[u]) >> 1;
if(z < c[v]) upd(v, c[v], -1);
else if(z > c[v] + d[v] - d[u]) upd(u, c[u], -1);
else upd(kth(v, z - c[v]), z, -1);
self(self, v);
}
vector<int>().swap(son[u]);
};
dfs(dfs, 1); dfs2(dfs2, 1);
vector<int>().swap(b[o]);
}
for(int u = 1; u <= n; u++)
{
sort(begin(a1[u]), end(a1[u])); sort(begin(a2[u]), end(a2[u]));
for(int i = 1; i < sz(a1[u]); i++) a1[u][i].se += a1[u][i - 1].se;
for(int i = 1; i < sz(a2[u]); i++) a2[u][i].se += a2[u][i - 1].se;
}
while(q--)
{
read(u), read(k), k <<= 1;
if(op) u ^= ans;
ans = ask(a1[u], k - di[dep[u]][u]);
for(int z = u, i = dep[u] - 1; i >= 1; z = pa[z], i--)
{
ans += ask(a1[pa[z]], k - di[i][u]);
ans += ask(a2[z], k - di[i][u]);
}
write(ans);
}
for(int i = 1; i <= n; i++)
{
vector<pair<int, int>>().swap(a1[i]);
vector<pair<int, int>>().swap(a2[i]);
}
}
int main()
{
freopen("count.in", "r", stdin);
freopen("count.out", "w", stdout);
read(op), read(t);
while(t--) solve();
return 0;
}