线段树的综合应用
接下来,以洛谷P6242 【模板】线段树 3(超级毒瘤)为例,来看一下线段树的综合应用。
先来看一下此题题意,很熟悉的题面:
题目描述
给出一个长度为 \(n\) 的数列 \(A\),同时定义一个辅助数组 \(B\),\(B\) 开始与 \(A\) 完全相同。接下来进行了 \(m\) 次操作,操作有五种类型,按以下格式给出:
1 l r k
:对于所有的 \(i\in[l,r]\),将 \(A_i\) 加上 \(k\)(\(k\) 可以为负数)。2 l r v
:对于所有的 \(i\in[l,r]\),将 \(A_i\) 变成 \(\min(A_i,v)\)。3 l r
:求 \(\sum\limits_{i=l}^{r}A_i\)。4 l r
:对于所有的 \(i\in[l,r]\),求 \(A_i\) 的最大值。5 l r
:对于所有的 \(i\in[l,r]\),求 \(B_i\) 的最大值。在每一次操作后,我们都进行一次更新,让 \(B_i\gets\max(B_i,A_i)\)。
数据规模与约定
- 对于全部测试数据,保证 \(1\leq n,m\leq 5\times 10^5\),\(-5\times10^8\leq A_i\leq 5\times10^8\),\(op\in[1,5]\),\(1 \leq l\leq r \leq n\),\(-2000\leq k\leq 2000\),\(-5\times10^8\leq v\leq 5\times10^8\)。
初看此题,挺简单啊,就是操作有点多(竟然还是紫题)。
但是你细看这个操作,2 l r v
:对于所有的 \(i\in[l,r]\),将 \(A_i\) 变成 \(\min(A_i,v)\)。
怎么修改?
这也没法 \(\texttt{lazy_tag}\) 大法呀。
要是硬修改,那么恭喜你写出了一个区间修改复杂度为 \(\mathcal{O}(n\log n)\) 的优秀线段树。
这跟没加 \(\texttt{lazy_tag}\) 的区间加不一样了吗...
那怎么办?
就要用到大名鼎鼎的 吉司机线段树 了!
其实也跟普通线段树没什么区别。
只是在线段树中多维护两个值:
次大值 和 最大值的个数。
下面分别用 \(sem\) 和 \(cnt\) 表示(最大值为 \(maxa\))。
这有什么用呢?
在进行修改操作时,遇到某一个节点(代表了一个区间),要对于所有的 \(i\in[l,r]\),将 \(A_i\) 变成 \(\min(A_i,v)\):
- 当此区间的 \(maxa \leqslant v\) 时,此区间肯定不用修改,直接
return;
- 当此区间满足 \(sem \leqslant v < maxa\) 时,就只用将所有最大值改为 \(cnt\) ,将 \(sum\) 减去 \(cnt \times (maxa - v)\) ,再打上标记即可。
- 否则无法修改,继续向左右子节点递归。
那么,这样我们一次操作的复杂度就是 \(\mathcal{O}(\log^2 n)\) (这是结论,具体证明就不说了)
接下来,我们就来看一下具体的实现。
结点维护的信息
因为这题实属毒瘤,所以我们需要维护 \(\rm{4}\) 个 \(\texttt{lazy_tag}\) :
- \(\rm{add1}\) : \(A\) 数组中最大值要加的
- \(\rm{add2}\) : \(A\) 数组中非最大值要加的
- \(\rm{add3}\) : \(B\) 数组中最大值要加的(其实是 \(A\) 历史中最大值加的最多的一次)
- \(\rm{add4}\) : \(B\) 数组中非最大值要加的(其实是 \(A\) 历史中非最大值加的最多的一次)
\(\text{Q} :\) 为什么要用 \(\rm{4}\) 个 \(\texttt{lazy_tag}\) 呢?
\(\text{A} :\) 因为当进行取 \(\min\) 操作时,节点只会更改最大值,这样最大值要加的和非最大值要加的就不一样了,所以需要分最大值和非最大值,于是就被迫使用 \(\rm{4}\) 个 \(\texttt{lazy_tag}\) 。
\(\text{Code}\)
struct Node{
int l, r;
ll sum, add1, add2, add3, add4;
ll maxa, maxb, sem;
//不开 long long 见祖宗
int cnt;
}tr[4 * N];
而真正折磨人的,在后面。
修改操作实现
这里一定要仔细理解,有很多分类讨论,也容易被误导。
-
\(\rm{pushup}\)
inline void pushup(int u){
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
tr[u].maxa = Max(tr[u << 1].maxa, tr[u << 1 | 1].maxa);
tr[u].maxb = Max(tr[u << 1].maxb, tr[u << 1 | 1].maxb);
if(tr[u << 1].maxa > tr[u << 1 | 1].maxa){
// 最大值在左边
tr[u].sem = Max(tr[u << 1].sem, tr[u << 1 | 1].maxa);
// 那么次大值就是左次大值和右最大值中大的一个
tr[u].cnt = tr[u << 1].cnt;
// 同时最大值的个数就是左子节点中最大值的个数
}
else if(tr[u << 1].maxa == tr[u << 1 | 1].maxa){
// 左右两边最大值相同的这种情况需要特判一下
tr[u].sem = Max(tr[u << 1].sem, tr[u << 1 | 1].sem);
tr[u].cnt = tr[u << 1].cnt + tr[u << 1 | 1].cnt;
// 这里最大值的个数就是两边的 cnt 加起来了
}
else{
// 同上,就是最大值在右边的情况
tr[u].sem = Max(tr[u << 1].maxa, tr[u << 1 | 1].sem);
tr[u].cnt = tr[u << 1 | 1].cnt;
}
}
-
\(\rm{pushdown}\)
inline void pushdown(int u){
ll maxn = Max(tr[u << 1].maxa, tr[u << 1 | 1].maxa);
if(tr[u << 1].maxa == maxn) change(u << 1, tr[u].add1, tr[u].add2, tr[u].add3, tr[u].add4);
// 当最大值在左边时,4 个 lazy_tag 可直接传入 change (change的定义见下)
else change(u << 1, tr[u].add2, tr[u].add2, tr[u].add4, tr[u].add4);
// 当最大值不在左边时,就全都是非最大值,传入两个非最大值的 lazy_tag
if(tr[u << 1 | 1].maxa == maxn) change(u << 1 | 1, tr[u].add1, tr[u].add2, tr[u].add3, tr[u].add4);
else change(u << 1 | 1, tr[u].add2, tr[u].add2, tr[u].add4, tr[u].add4);
// 此处同理
tr[u].add1 = tr[u].add2 = tr[u].add3 = tr[u].add4 = 0;
// 记得清空
}
\(\rm{change}\)
为了方便,定义一个 \(\rm{change}\) 函数:
inline void change(int u, ll a1, ll a2, ll a3, ll a4){ /* a1:A 数组中最大值要加的 a2:A 数组中非最大值要加的 a3:B 数组中最大值要加的 a4:B 数组中非最大值要加的 */ tr[u].sum += a2 * (tr[u].r - tr[u].l + 1 - tr[u].cnt) + a1 * tr[u].cnt; tr[u].maxb = Max(tr[u].maxb, tr[u].maxa + a3); // 因为 a3 实质上是 A 历史中最大值加的最多的一次,所以此处应与 tr[u].maxa + a3 比较取 max tr[u].add3 = Max(tr[u].add3, tr[u].add1 + a3); tr[u].add4 = Max(tr[u].add4, tr[u].add2 + a4); // 此处同理 tr[u].maxa += a1; if(tr[u].sem != -1e16) tr[u].sem += a2; // 只有当此节点存在次大值时更新 tr[u].add1 += a1; tr[u].add2 += a2; // 这两处更新一定要放在最后,因为前面的更新要用到这两个变量 }
-
$ A_i\gets\min(A_i,v)$
void update_min(int u, int l, int r, int k){
if(tr[u].maxa <= k) return;
// 此时修改肯定不会影响到此节点,直接不管 (返回)
if(l <= tr[u].l && tr[u].r <= r && tr[u].sem <= k){
// 注意添加一个直接修改的条件:tr[u].sem <= k
ll t = tr[u].maxa - k;
// 最好先记录减小值
tr[u].maxa = k;
tr[u].sum -= tr[u].cnt * t;
tr[u].add1 -= t;
// 这应该没什么好说的了...(-_-')
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(l <= mid) update_min(u << 1, l, r, k);
if(r > mid) update_min(u << 1 | 1, l, r, k);
pushup(u);
}
总结
其他就没什么难的了,查询之类都很容易。
\(\text{完整 Code}\)
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 500010;
int n, m;
int a[N];
struct Node{
int l, r;
ll sum, add1, add2, add3, add4;
ll maxa, maxb, sem;
int cnt;
}tr[4 * N];
// 个人习惯,手写 Min, Max 函数
inline ll Max(ll a, ll b){return a > b ? a : b;}
inline ll Min(ll a, ll b){return a < b ? a : b;}
inline void pushup(int u){
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
tr[u].maxa = Max(tr[u << 1].maxa, tr[u << 1 | 1].maxa);
tr[u].maxb = Max(tr[u << 1].maxb, tr[u << 1 | 1].maxb);
if(tr[u << 1].maxa > tr[u << 1 | 1].maxa){
tr[u].sem = Max(tr[u << 1].sem, tr[u << 1 | 1].maxa);
tr[u].cnt = tr[u << 1].cnt;
}
else if(tr[u << 1].maxa == tr[u << 1 | 1].maxa){
tr[u].sem = Max(tr[u << 1].sem, tr[u << 1 | 1].sem);
tr[u].cnt = tr[u << 1].cnt + tr[u << 1 | 1].cnt;
}
else{
tr[u].sem = Max(tr[u << 1].maxa, tr[u << 1 | 1].sem);
tr[u].cnt = tr[u << 1 | 1].cnt;
}
}
inline void change(int u, ll a1, ll a2, ll a3, ll a4){
tr[u].sum += a2 * (tr[u].r - tr[u].l + 1 - tr[u].cnt) + a1 * tr[u].cnt;
tr[u].maxb = Max(tr[u].maxb, tr[u].maxa + a3);
tr[u].add3 = Max(tr[u].add3, tr[u].add1 + a3);
tr[u].add4 = Max(tr[u].add4, tr[u].add2 + a4);
tr[u].maxa += a1;
if(tr[u].sem != -1e16) tr[u].sem += a2;
tr[u].add1 += a1;
tr[u].add2 += a2;
}
inline void pushdown(int u){
ll maxn = Max(tr[u << 1].maxa, tr[u << 1 | 1].maxa);
if(tr[u << 1].maxa == maxn) change(u << 1, tr[u].add1, tr[u].add2, tr[u].add3, tr[u].add4);
else change(u << 1, tr[u].add2, tr[u].add2, tr[u].add4, tr[u].add4);
if(tr[u << 1 | 1].maxa == maxn) change(u << 1 | 1, tr[u].add1, tr[u].add2, tr[u].add3, tr[u].add4);
else change(u << 1 | 1, tr[u].add2, tr[u].add2, tr[u].add4, tr[u].add4);
tr[u].add1 = tr[u].add2 = tr[u].add3 = tr[u].add4 = 0;
}
void build(int u, int l, int r){
// 后面这几个处初值要注意 (栽跟头 * n)
tr[u].l = l, tr[u].r = r;
tr[u].add1 = tr[u].add2 = tr[u].add3 = tr[u].add4 = 0;
if(l == r){
tr[u].sum = tr[u].maxa = tr[u].maxb = a[l];
tr[u].cnt = 1;
tr[u].sem = -1e16;
return;
}
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update_add(int u, int l, int r, int k){
if(l <= tr[u].l && tr[u].r <= r){
// 此处可以直接调用 change 偷懒
change(u, k, k, k, k);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(l <= mid) update_add(u << 1, l, r, k);
if(r > mid) update_add(u << 1 | 1, l, r, k);
pushup(u);
}
void update_min(int u, int l, int r, int k){
if(tr[u].maxa <= k) return;
if(l <= tr[u].l && tr[u].r <= r && tr[u].sem <= k){
ll t = tr[u].maxa - k;
tr[u].maxa = k;
tr[u].sum -= tr[u].cnt * t;
tr[u].add1 -= t;
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(l <= mid) update_min(u << 1, l, r, k);
if(r > mid) update_min(u << 1 | 1, l, r, k);
pushup(u);
}
ll query_sum(int u, int l, int r){
if(l <= tr[u].l && tr[u].r <= r){
return tr[u].sum;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
ll sum = 0;
// 之前此处没赋初值 0 ,栽跟头 * n
if(l <= mid) sum = query_sum(u << 1, l, r);
if(r > mid) sum += query_sum(u << 1 | 1, l, r);
return sum;
}
ll query_A_max(int u, int l, int r){
if(l <= tr[u].l && tr[u].r <= r) return tr[u].maxa;
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
ll res = -1e16;
if(l <= mid) res = query_A_max(u << 1, l, r);
if(r > mid) res = Max(res, query_A_max(u << 1 | 1, l, r));
return res;
}
ll query_B_max(int u, int l, int r){
if(l <= tr[u].l && tr[u].r <= r) return tr[u].maxb;
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
ll res = -1e16;
if(l <= mid) res = query_B_max(u << 1, l, r);
if(r > mid) res = Max(res, query_B_max(u << 1 | 1, l, r));
return res;
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
build(1, 1, n);
int op, l, r, x;
while(m--){
scanf("%d%d%d", &op, &l, &r);
if(op == 1){
scanf("%d", &x);
update_add(1, l, r, x);
}
else if(op == 2){
scanf("%d", &x);
update_min(1, l, r, x);
}
else if(op == 3) printf("%lld\n", query_sum(1, l, r));
else if(op == 4) printf("%lld\n", query_A_max(1, l, r));
else printf("%lld\n", query_B_max(1, l, r));
}
return 0;
}
温馨提示(给自己)
- 不开 \(\text{long long}\) 见祖宗
- 不赋初值见祖宗
- 写错函数名见祖宗
- 修改乱序见祖宗
最后,让我们愉快地通过这个水题 。