树套树
顾名思义，就是在一棵树的每个节点内部再套一棵树。
例如：线段树套平衡树——外层线段树的每个节点都用一棵平衡树维护该节点对应区间内的所有元素。
常用组合：
- 外层：线段树、树状数组
- 内层：平衡树、线段树（只需要前驱/后继等操作时，内层平衡树一般可以直接用 STL 的 set/multiset 代替）
例题：
- AcWing 2488
  思路很直接：线段树套 multiset（每个节点存放其区间内的所有数，查前驱时用 lower_bound）。
#include <bits/stdc++.h>
using namespace std;

// AcWing 2488: segment tree over positions; every node keeps a multiset of
// the values currently stored in its position range.
//   op 1 pos x   : set w[pos] = x
//   op 2 l r x   : largest value < x among w[l..r] (-INF if there is none)
const int N = 50005, M = N << 2;
const int INF = 0x3f3f3f3f;

int n, m;

struct Tree {
    int l, r;         // position range [l, r] covered by this node
    multiset<int> s;  // w[l..r] plus the sentinels -INF and INF
} tr[M];

int w[N];  // current array, 1-indexed

// Builds node u covering positions [l, r].
void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    tr[u].s.insert(-INF);
    tr[u].s.insert(INF);
    for (int i = l; i <= r; i++) tr[u].s.insert(w[i]);
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}

// Replaces the old value w[p] by x in every node whose range contains p.
// The caller must update w[p] only AFTER this call (the old value is needed).
void change(int u, int p, int x) {
    auto &s = tr[u].s;
    s.erase(s.find(w[p]));  // erase one copy only; erase(value) would drop all duplicates
    s.insert(x);
    if (tr[u].l == tr[u].r) return;
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (p <= mid) change(u << 1, p, x);
    else change(u << 1 | 1, p, x);
}

// Returns the largest value strictly less than x among w[a..b],
// or -INF when no such value exists.
int query(int u, int a, int b, int x) {
    if (tr[u].l >= a && tr[u].r <= b) {
        const auto &s = tr[u].s;
        auto it = s.lower_bound(x);
        // Guard: stepping back from begin() is undefined (only possible when x <= -INF).
        return it == s.begin() ? -INF : *prev(it);
    }
    int mid = (tr[u].l + tr[u].r) >> 1;
    int res = -INF;
    if (a <= mid) res = max(res, query(u << 1, a, b, x));
    if (b > mid) res = max(res, query(u << 1 | 1, a, b, x));
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> w[i];
    build(1, 1, n);

    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {  // was `op == x`, comparing against an uninitialized variable
            int a, x;
            cin >> a >> x;
            change(1, a, x);
            w[a] = x;
        } else {
            int a, b, x;
            cin >> a >> b >> x;
            cout << query(1, a, b, x) << '\n';  // '\n' avoids a flush per answer
        }
    }
    return 0;
}
- P3380 【模板】树套树
  线段树套 Splay；查询第 k 小时在值域上二分，并借助“查排名”操作判定。
#include <bits/stdc++.h>
using namespace std;

// P3380: segment tree over positions; segment-tree node u owns a splay tree
// (root T[u]) holding w[L[u]..R[u]] plus the sentinels -INF and INF.
// All splay nodes come from the shared pool tr[1..idx]; index 0 is the null node.
// Nodes removed by update() are not recycled.
const int N = 2000005, INF = 2147483647;
const int MAX_V = 1e8;  // upper bound for the k-th value binary search; assumes values lie in [0, 1e8]

int n, m;

struct Node {
    int s[2], p, v;  // children, parent, key
    int sz;          // subtree size
    void init(int _v, int _p) {
        v = _v, p = _p;
        sz = 1;
    }
} tr[N];

int L[N], R[N], T[N], idx;
int w[N];

void pushup(int x) {
    tr[x].sz = tr[tr[x].s[0]].sz + tr[tr[x].s[1]].sz + 1;
}

// Rotates x above its parent y (z is y's parent).
void rotate(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;  // side of y that x hangs on
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y, tr[y].p = x;
    pushup(y), pushup(x);
}

// Splays x until its parent is k; with k == 0, x becomes the root of `root`.
void splay(int &root, int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k) {
            // zig-zag: rotate x twice; zig-zig: rotate y first, then x
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
    if (!k) root = x;
}

// Inserts key v as a new node (equal keys go to the left) and splays it to the root.
void insert(int &root, int v) {
    int u = root, p = 0;
    while (u) p = u, u = tr[u].s[v > tr[u].v];
    u = ++idx;
    if (p) tr[p].s[v > tr[p].v] = u;
    tr[u].init(v, p);
    splay(root, u, 0);
}

// Number of keys strictly less than v (the -INF sentinel is counted).
int get_k(int root, int v) {
    int u = root, res = 0;
    while (u) {
        if (tr[u].v < v) res += tr[tr[u].s[0]].sz + 1, u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

// Replaces one occurrence of key x by y. The sentinels guarantee x has an
// in-order predecessor l and successor r; after splaying l to the root and r
// directly below it, the node holding x is exactly r's left child.
void update(int &root, int x, int y) {
    int u = root;
    while (u) {
        if (tr[u].v == x) break;
        if (tr[u].v < x) u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    splay(root, u, 0);
    int l = tr[u].s[0], r = tr[u].s[1];
    while (tr[l].s[1]) l = tr[l].s[1];
    while (tr[r].s[0]) r = tr[r].s[0];
    splay(root, l, 0), splay(root, r, l);
    tr[r].s[0] = 0;
    // Child before parent: l's size is computed from r's size.
    pushup(r), pushup(l);
    insert(root, y);
}

void build(int u, int l, int r) {
    L[u] = l, R[u] = r;
    insert(T[u], INF), insert(T[u], -INF);
    for (int i = l; i <= r; i++) insert(T[u], w[i]);
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

// Number of values strictly less than x among w[a..b].
int query(int u, int a, int b, int x) {
    if (L[u] >= a && R[u] <= b) return get_k(T[u], x) - 1;  // drop the -INF sentinel
    int mid = (L[u] + R[u]) >> 1;
    int res = 0;
    if (a <= mid) res += query(u << 1, a, b, x);
    if (b > mid) res += query(u << 1 | 1, a, b, x);
    return res;
}

// Replaces w[p] by x in every splay tree whose range contains p.
// The caller must update w[p] only AFTER this call.
void change(int u, int p, int x) {
    update(T[u], w[p], x);
    if (L[u] == R[u]) return;
    int mid = (L[u] + R[u]) >> 1;
    if (p <= mid) change(u << 1, p, x);
    else change(u << 1 | 1, p, x);
}

// Largest key < v in one splay tree (-INF if none).
int get_pre(int root, int v) {
    int u = root, res = -INF;
    while (u) {
        if (tr[u].v < v) res = max(res, tr[u].v), u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

// Smallest key > v in one splay tree (INF if none).
int get_suc(int root, int v) {
    int u = root, res = INF;
    while (u) {
        if (tr[u].v > v) res = min(res, tr[u].v), u = tr[u].s[0];
        else u = tr[u].s[1];
    }
    return res;
}

int query_pre(int u, int a, int b, int x) {
    if (L[u] >= a && R[u] <= b) return get_pre(T[u], x);
    int mid = (L[u] + R[u]) >> 1;
    int res = -INF;
    if (a <= mid) res = max(res, query_pre(u << 1, a, b, x));
    if (b > mid) res = max(res, query_pre(u << 1 | 1, a, b, x));
    return res;
}

int query_suc(int u, int a, int b, int x) {
    if (L[u] >= a && R[u] <= b) return get_suc(T[u], x);
    int mid = (L[u] + R[u]) >> 1;
    int res = INF;
    if (a <= mid) res = min(res, query_suc(u << 1, a, b, x));
    if (b > mid) res = min(res, query_suc(u << 1 | 1, a, b, x));
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> w[i];
    build(1, 1, n);

    while (m--) {
        int op, a, b, x;
        cin >> op;
        if (op == 1) {  // rank of x in [a, b]
            cin >> a >> b >> x;
            cout << query(1, a, b, x) + 1 << '\n';
        } else if (op == 2) {  // x-th smallest in [a, b]: largest v with rank(v) <= x
            cin >> a >> b >> x;
            int l = 0, r = MAX_V;
            while (l < r) {
                int mid = (l + r + 1) >> 1;
                if (query(1, a, b, mid) + 1 <= x) l = mid;
                else r = mid - 1;
            }
            cout << r << '\n';
        } else if (op == 3) {  // w[a] = x
            cin >> a >> x;
            change(1, a, x);
            w[a] = x;
        } else if (op == 4) {  // predecessor of x in [a, b]
            cin >> a >> b >> x;
            cout << query_pre(1, a, b, x) << '\n';
        } else {  // successor of x in [a, b]
            cin >> a >> b >> x;
            cout << query_suc(1, a, b, x) << '\n';
        }
    }
    return 0;
}
- P3332 [ZJOI2013] K大数查询
  考虑值域线段树套（位置）线段树。
  Tips：内层线段树使用标记永久化，并采用动态开点。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;

typedef long long LL;

// P3332: outer static segment tree over compressed values (heap indices u, u<<1, u<<1|1);
// each outer node u owns a dynamically created inner segment tree (root T[u]) over
// positions [1, n] that counts how many inserted numbers with value in the outer
// node's range sit at each position. Range additions use a permanent tag (never pushed down).
const int N = 50010, P = N * 17 * 17, M = N * 4;

int n, m;

struct Tree {
    int l, r;     // child indices in the node pool (0 = not created yet)
    LL sum, add;  // count over this node's range; permanent per-position tag
} tr[P];

int L[M], R[M], T[M], idx;

struct Query {
    int op, a, b;
    LL c;  // op 1: value to insert; op 2: rank, which may exceed INT_MAX (up to n * m numbers)
} q[N];

vector<int> nums;  // sorted, deduplicated inserted values

// Index of value x in nums.
int get(int x) {
    return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}

void build(int u, int l, int r) {
    L[u] = l, R[u] = r, T[u] = ++idx;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

// Length of the overlap of [a, b] and [c, d]; callers only pass overlapping ranges.
int intersection(int a, int b, int c, int d) {
    return min(b, d) - max(a, c) + 1;
}

// Adds 1 to every position of [pl, pr] in the inner tree node u covering [l, r].
void update(int u, int l, int r, int pl, int pr) {
    tr[u].sum += intersection(l, r, pl, pr);
    if (l >= pl && r <= pr) {
        tr[u].add++;
        return;
    }
    int mid = (l + r) >> 1;
    if (pl <= mid) {
        if (!tr[u].l) tr[u].l = ++idx;
        update(tr[u].l, l, mid, pl, pr);
    }
    if (pr > mid) {
        if (!tr[u].r) tr[u].r = ++idx;
        update(tr[u].r, mid + 1, r, pl, pr);
    }
}

// Inserts value index c at every position of [a, b]: each outer node whose
// value range contains c gets +1 on [a, b].
void change(int u, int a, int b, int c) {
    update(T[u], 1, n, a, b);
    if (L[u] == R[u]) return;
    int mid = (L[u] + R[u]) >> 1;
    if (c <= mid) change(u << 1, a, b, c);
    else change(u << 1 | 1, a, b, c);
}

// Count over [pl, pr] in the inner tree node u covering [l, r]; `add` is the
// sum of permanent tags on the ancestors (kept in 64 bits so products cannot overflow).
LL get_sum(int u, int l, int r, int pl, int pr, LL add) {
    if (l >= pl && r <= pr) return tr[u].sum + (r - l + 1LL) * add;
    int mid = (l + r) >> 1;
    LL res = 0;
    add += tr[u].add;
    if (pl <= mid) {
        if (tr[u].l) res += get_sum(tr[u].l, l, mid, pl, pr, add);
        else res += intersection(l, mid, pl, pr) * add;  // missing child: only tags contribute
    }
    if (pr > mid) {
        if (tr[u].r) res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
        else res += intersection(mid + 1, r, pl, pr) * add;
    }
    return res;
}

// Value index of the c-th largest number stored at positions [a, b].
int query(int u, int a, int b, LL c) {
    if (L[u] == R[u]) return R[u];
    LL k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);  // how many lie in the upper value half
    if (k >= c) return query(u << 1 | 1, a, b, c);
    return query(u << 1, a, b, c - k);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        scanf("%d%d%d%lld", &q[i].op, &q[i].a, &q[i].b, &q[i].c);
        if (q[i].op == 1) nums.push_back(static_cast<int>(q[i].c));
    }
    sort(nums.begin(), nums.end());
    nums.erase(unique(nums.begin(), nums.end()), nums.end());
    build(1, 0, static_cast<int>(nums.size()) - 1);

    for (int i = 0; i < m; i++) {
        int op = q[i].op, a = q[i].a, b = q[i].b;
        LL c = q[i].c;
        if (op == 1) change(1, a, b, get(static_cast<int>(c)));
        else printf("%d\n", nums[query(1, a, b, c)]);
    }
    return 0;
}