【学习笔记】线段树合并 & 分裂
前置知识:动态开点线段树
用来解决一些对区间拆分合并的问题。
线段树合并大概可以替代 DSU,但是常数略大。
对于线段树分裂合并的空间复杂度问题,一般内存要开 \(maxq\times t\times \lceil \log_2maxn \rceil\),其中 \(maxq\) 为询问次数,\(t\) 为每次询问的更新操作次数,\(maxn\) 为线段树值域右端点,有分裂的话还要再 \(\times 2\)。如果 \(1 \le maxq, maxn \le 10^5\) 的话,内存大概要开 \(maxn \times 40\)。
动态开点
int newnode(){
if(delcnt) return cache[delcnt--];
return ++cnt;
}
void del(int &x){
cache[++delcnt] = x;
ls = rs = sum[x] = 0;
}
修改操作时需要传取地址符并在不存在的 \(x\) 处新建节点,这样子递归回去的时候左右儿子就建好了。
不需要的时候删除节点可以将其存在一个数组里,要新建节点的时候从这里面取可以省空间。
例如一个最简单的单点修改:
void update(int &x, int l, int r, int p, int v){
if(!x) x = newnode();
sum[x] += v;
if(l == r)
return;
if(p<=mid) update(lson, p, v);
else update(rson, p, v);
}
线段树合并
线段树合并的目的是把两棵权值线段树合并。
相当于就是把对应的节点的权值相加。即将线段树 B 的节点权值加到线段树 A 上,这样子线段树 B 就可以不用了。
对应代码:
int merge(int &x, int &y){
if(!x || !y) return x+y;
son[x][0] = merge(son[x][0], son[y][0]);
son[x][1] = merge(son[x][1], son[y][1]);
sum[x] += sum[y];
del(y);
return x;
}
线段树分裂
线段树分裂的目的是把一棵权值线段树分裂成两颗权值线段树。
分裂有按权值(第 \(k\) 小)分裂和按区间分裂。
按权值(第 \(k\) 小)分裂
设 \(v\) 为 \(x\) 的左子树权值的大小。
- 当 \(v=k\) 时,
swap
后直接返回。 - 当 \(v>k\) 时,右边可以全归 \(y\),此时
swap
后还需要递归分裂左子树,因为中间还有一段没有归到右边。 - 当 \(v<k\) 时,递归分裂右子树。
记得更新权值。
对应代码:
void split(int &x, int l, int r, int &y, int k){ // 分前 k 小,即 1~k-1 和 k~n
if(sum[ls]==k){
swap(son[x][1], son[y][1]);
return;
}
if(!y) y = newnode();
if(sum[ls]>=k){
swap(son[x][1], son[y][1]);
split(lson, son[y][0], k);
}
else
split(rson, son[y][1], k-sum[ls]);
sum[y] = sum[son[y][0]]+sum[son[y][1]];
sum[x] -= sum[y];
}
按区间分裂
像普通线段树一样递归,如果当前区间被查询区间包含,则 swap
后返回。
对应代码:
void split(int &x, int l, int r, int &y, int ql, int qr){
if(ql<=l && r<=qr){
swap(x, y);
// 将 x 存到 y,相当于 y = x, del(x)
return;
}
if(!y) y = newnode();
if(ql<=mid) split(lson, son[y][0], ql, qr);
if(qr>mid) split(rson, son[y][1], ql, qr);
sum[y] = sum[son[y][0]]+sum[son[y][1]];
sum[x] -= sum[y];
// sum[x] 也可以写成左右子树相加形式的
}
注意新建 \(y\) 节点都要在递归边界之后开,这样递归边界的交换才能使 \(x\) 的这部分为空。
P5494 【模板】线段树分裂
模板题。
#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long
constexpr int N = 2e5+5;
struct segtree{
int son[N*35][2], rt[N], cnt, rtcnt=1;
ll sum[N*35];
int cache[N*35], delcnt;
#define ls (son[x][0])
#define rs (son[x][1])
#define mid (((l)+(r))>>1)
#define lson ls, l, mid
#define rson rs, mid+1, r
int newnode(){
if(delcnt) return cache[delcnt--];
return ++cnt;
}
void del(int &x){
cache[++delcnt] = x;
ls = rs = sum[x] = 0;
}
void update(int &x, int l, int r, int p, int v){
if(!x) x = newnode();
sum[x] += v;
if(l == r)
return;
if(p<=mid) update(lson, p, v);
else update(rson, p, v);
}
ll query(int x, int l, int r, int ql, int qr){
if(ql<=l && r<=qr)
return sum[x];
ll res = 0;
if(ql<=mid) res += query(lson, ql, qr);
if(qr>mid) res += query(rson, ql, qr);
return res;
}
int querykth(int x, int l, int r, int k){
if(k > sum[x]) return -1;
if(l == r)
return l;
if(sum[ls]>=k) return querykth(lson, k);
else return querykth(rson, k-sum[ls]);
}
int merge(int &x, int &y){
if(!x || !y) return x+y;
son[x][0] = merge(son[x][0], son[y][0]);
son[x][1] = merge(son[x][1], son[y][1]);
sum[x] += sum[y];
del(y);
return x;
}
void split(int &x, int l, int r, int &y, int ql, int qr){
if(ql<=l && r<=qr){
swap(x, y);
// 将 x 存到 y,相当于 y = x, del(x)
return;
}
if(!y) y = newnode();
if(ql<=mid) split(lson, son[y][0], ql, qr);
if(qr>mid) split(rson, son[y][1], ql, qr);
sum[y] = sum[son[y][0]]+sum[son[y][1]];
sum[x] -= sum[y];
// sum[x] 也可以写成左右子树相加形式的
}
void split(int &x, int l, int r, int &y, int k){ // 分前 k 小,即 1~k-1 和 k~n
if(sum[ls]==k){
swap(son[x][1], son[y][1]);
return;
}
if(!y) y = newnode();
if(sum[ls]>=k){
swap(son[x][1], son[y][1]);
split(lson, son[y][0], k);
}
else
split(rson, son[y][1], k-sum[ls]);
sum[y] = sum[son[y][0]]+sum[son[y][1]];
sum[x] -= sum[y];
}
}T;
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
int n, m; cin>>n>>m;
for(int i=1; i<=n; i++){
int x; cin>>x;
T.update(T.rt[1], 1, n, i, x);
}
for(int i=1; i<=m; i++){
int op; cin>>op;
if(op == 0){
int p, x, y; cin>>p>>x>>y;
T.split(T.rt[p], 1, n, T.rt[++T.rtcnt], x, y);
} else if(op == 1){
int p, t; cin>>p>>t;
T.merge(T.rt[p], T.rt[t]);
} else if(op == 2){
int p, x, q; cin>>p>>x>>q;
T.update(T.rt[p], 1, n, q, x);
} else if(op == 3){
int p, x, y; cin>>p>>x>>y;
cout<<T.query(T.rt[p], 1, n, x, y)<<"\n";
} else if(op == 4){
int p, k; cin>>p>>k;
cout<<T.querykth(T.rt[p], 1, n, k)<<"\n";
}
}
return 0;
}
P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并
原图就是一棵树,对于这棵树上的每个节点有一个桶记录对应种类的添加次数,用权值线段树维护这个桶。
每次询问对于 \(u \rightarrow v\) 上的节点的桶的位置 \(z\) 需要 \(+1\),暴力的话需要在每个节点都更新一次线段树。考虑树上差分,即在 \(u, v\) 处 \(+1\),在 \(lca(u, v), fa_{lca(u, v)}\) 处 \(-1\),最后向上更新(线段树合并)统计答案。
每个节点的线段树需要记录最大值和最大值所在位置。
#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long
const int N = 1e5+5;
int fa[N][19], dep[N];
vector<int> g[N];
int ans[N], R;
struct segtree{
int ls[N*70], rs[N*70], rt[N], cache[N*70], cnt, delcnt;
int mx[N*70], id[N*70];
#define mid ((l)+(r)>>1)
#define lson ls[x], l, mid
#define rson rs[x], mid+1, r
int newnode(){
if(delcnt) return cache[delcnt--];
return ++cnt;
}
void del(int &x){
cache[++delcnt] = x;
ls[x] = rs[x] = mx[x] = id[x] = 0;
}
void pushup(int x){
if(mx[ls[x]]>=mx[rs[x]]) mx[x] = mx[ls[x]], id[x] = id[ls[x]];
else mx[x] = mx[rs[x]], id[x] = id[rs[x]];
}
void update(int &x, int l, int r, int p, int v){
if(!x) x = newnode();
if(l == r){
mx[x] += v; id[x] = l;
return;
}
if(p<=mid) update(lson, p, v);
else update(rson, p, v);
pushup(x);
}
int merge(int &x, int l, int r, int &y){
if(!x || !y) return x+y;
if(l == r){
mx[x] += mx[y];
del(y);
return x;
}
if(!y) y = newnode();
ls[x] = merge(lson, ls[y]);
rs[x] = merge(rson, rs[y]);
pushup(x);
del(y);
return x;
}
}T;
void dfs(int u, int f){
dep[u] = dep[f]+1; fa[u][0] = f;
for(int i=1; i<19; i++) fa[u][i] = fa[fa[u][i-1]][i-1];
for(int v : g[u]) if(v != f) dfs(v, u);
}
int LCA(int u, int v){
if(dep[u] > dep[v]) swap(u, v);
for(int d=dep[v]-dep[u], i=0; d; i++, d>>=1) if(d&1) v = fa[v][i];
if(u == v) return u;
for(int i=18; i>=0; i--) if(fa[u][i]!=fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
void redfs(int u, int f){
if(!T.rt[u]) T.rt[u] = T.newnode(); // 不存在要开,不然 merge 会直接 return
for(int v : g[u]){
if(v != f){
redfs(v, u);
T.merge(T.rt[u], 1, R, T.rt[v]);
}
}
ans[u] = T.id[T.rt[u]];
if(T.mx[T.rt[u]]==0) ans[u] = 0;
//cout<<u<<" "<<T.rt[u]<<" "<<T.mx[T.rt[u]]<<" "<<T.id[T.rt[u]]<<"\n";
}
struct node{
int x, y, z, lca;
}op[N];
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
int n, m; cin>>n>>m;
for(int i=1; i<n; i++){
int u, v; cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
for(int i=1; i<=m; i++){
cin>>op[i].x>>op[i].y>>op[i].z;
op[i].lca = LCA(op[i].x, op[i].y);
R = max(R, op[i].z);
}
for(int i=1; i<=m; i++){
T.update(T.rt[op[i].x], 1, R, op[i].z, 1);
T.update(T.rt[op[i].y], 1, R, op[i].z, 1);
T.update(T.rt[op[i].lca], 1, R, op[i].z, -1);
T.update(T.rt[fa[op[i].lca][0]], 1, R, op[i].z, -1);
}
redfs(1, 0);
for(int i=1; i<=n; i++)
cout<<ans[i]<<"\n";
return 0;
}
P1600 [NOIP2016 提高组] 天天爱跑步
因为对一个点从 \(u\) 到 \(v\) 的模拟每一时刻的贡献复杂度肯定不优,我们倒过来考虑对于每个观察员 \(i\),哪些运动点会在 \(w[i]\) 时刻对其作贡献。
记 \(lca\) 为 \(u\) 和 \(v\) 的最近公共祖先,将一个点从 \(u\) 到 \(v\) 分为 \(u\) 到 \(lca\) 的上行段和 \(lca\) 到 \(v\) 的下行段。
先分析上行段:
对于 \(u\) 到 \(lca\) 路径上的任意一点 \(k\),从 \(u\) 到 \(k\) 需要的时间 \(t=dep_u-dep_k\)。所以这个点上的观察员可以观测到这位玩家的充要条件是 \(w_k=t\),即 \(dep_k+w_k=dep_u\)。
考虑对每个节点维护一个桶,对路径上每个节点的桶的 \(dep_u\) 对应的位置加上 \(1\),最后查询 \(dep_k+w_k\) 上的数量即可。可以树上差分,将 \(u\) 上的桶的 \(dep_u\) 对应的位置加上 \(1\),将 \(fa_{lca}\) 上的桶的 \(dep_u\) 对应的位置减 \(1\)。 最后统计答案从叶子节点向上合并即可。(当然因为是区间合并所以要上线段树合并)
接着分析下行段:
对于 \(lca\) 到 \(v\) 路径上的任意一点 \(k\),从 \(lca\) 到 \(k\) 需要的时间 \(t_1=dep_k-dep_{lca}\),从 \(u\) 到 \(lca\) 需要的时间 \(t_2 = dep_u-dep_{lca}\),所以从 \(u\) 到 \(v\) 需要的时间 \(t_3 = t_1+t_2 = dep_u+dep_k-2\times dep_{lca}\)。所以这个点上的观察员可以观测到这位玩家的充要条件是 \(w_k=t_3\),即 \(dep_k-w_k=2\times dep_{lca}-dep_u\)。
同样的,将 \(v\) 上的桶的 \(2\times dep_{lca}-dep_u\) 对应的位置加 \(1\),将 \(lca\) 上的桶的 \(2\times dep_{lca}-dep_u\) 对应的位置减 \(1\)(抵消上行段的重复计算)。最后查询 \(dep_k-w_k\) 上的数量。
总贡献即为线段树上的两点的值的和。要特判 \(w_u = 0\) 的情况,因为此时会重复算两次。
注意线段树值域要开两倍,因为线段树查询位置 \(dep_u+w_u \le 2\times n\)。
#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long
constexpr int N = 3e5+5;
int n, m, _n;
vector<int> g[N];
int dep[N], fa[N][20];
void dfs(int u, int f){
dep[u] = dep[f]+1; fa[u][0] = f;
for(int i=1; i<=__lg(n); i++) fa[u][i] = fa[fa[u][i-1]][i-1];
for(int v : g[u]) if(v != f) dfs(v, u);
}
int LCA(int u, int v){
if(dep[u] < dep[v]) swap(u, v);
for(int d=dep[u]-dep[v], i=0; d; i++, d>>=1) if(d&1) u = fa[u][i];
if(u == v) return u;
for(int i=__lg(n); i>=0; i--) if(fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
struct segtree{
int rt[N], ls[N*80], rs[N*80], cache[N*80], delcnt, cnt;
int sum[N*80];
#define mid (((l)+(r))>>1)
#define lson ls[x], l, mid
#define rson rs[x], mid+1, r
int newnode(){
if(delcnt) return cache[delcnt--];
return ++cnt;
}
void del(int x){
cache[++cnt] = x;
sum[x] = ls[x] = rs[x] = 0;
}
void pushup(int x){sum[x] = sum[ls[x]]+sum[rs[x]];}
void update(int &x, int l, int r, int p, int v){
if(!x) x = newnode();
if(l == r){
sum[x] += v;
return;
}
if(p<=mid) update(lson, p, v);
else update(rson, p, v);
pushup(x);
}
int query(int x, int l, int r, int p){
if(l == r)
return sum[x];
if(p<=mid) return query(lson, p);
else return query(rson, p);
}
int merge(int &x, int &y){
if(!x || !y) return x+y;
if(!y) y = newnode();
ls[x] = merge(ls[x], ls[y]);
rs[x] = merge(rs[x], rs[y]);
sum[x] += sum[y];
del(y);
return x;
}
}T;
int w[N], ans[N];
void redfs(int u, int f){
if(!T.rt[u]) T.rt[u] = T.newnode();
for(int v : g[u]){
if(v == f) continue;
redfs(v, u);
T.merge(T.rt[u], T.rt[v]);
}
ans[u] = T.query(T.rt[u], -_n, _n, dep[u]+w[u]) + T.query(T.rt[u], -_n, _n, dep[u]-w[u]);
if(w[u] == 0) ans[u] /= 2;
}
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
cin>>n>>m; _n = n*2;
for(int i=1; i<n; i++){
int u, v; cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
for(int i=1; i<=n; i++)
cin>>w[i];
dfs(1, 0);
for(int i=1; i<=m; i++){
int u, v; cin>>u>>v;
int lca = LCA(u, v);
T.update(T.rt[u], -_n, _n, dep[u], 1);
T.update(T.rt[v], -_n, _n, 2*dep[lca]-dep[u], 1);
T.update(T.rt[lca], -_n, _n, 2*dep[lca]-dep[u], -1); // dep[u] 也可
T.update(T.rt[fa[lca][0]], -_n, _n, dep[u], -1);
}
redfs(1, 0);
for(int i=1; i<=n; i++)
cout<<ans[i]<<" ";
return 0;
}