前言
线段树绝对是出题人最爱考的高级数据结构了。它快、灵活、码量也大,相当考验 OIer 的综合能力。所以好好学习一下线段树是相当必要的。
基础
线段树是基于二叉树的。通过为二叉树的每个节点赋予线段的意义,线段树可以维护很多的区间信息,包括但不限于区间和、区间最大值、区间第 k 大(这玩意是可持久化线段树维护的)。另外,由于它是基于二叉树的,所以它可以在 \(\Theta(\log n)\) 的时间复杂度内实现区间修改和区间查询。
这里借用 OI wiki 的一张图来帮助理解。(其中d[i]
表示线段树的节点,“\(=\)”后面维护的是区间和,红色方括号括起来的是节点表示的线段):
以下以区间修改求区间和为例,给出线段树的示例代码。
洛谷 P3372 【模板】线段树 1
给出一个长度为 \(n\) 的序列 \(a\),并进行以下操作共 \(m\) 次:
- 将区间 \([l..r]\) 的每一个数加上 \(k\);
- 求区间 \([l..r]\) 的和。
更新答案
通过惊人的注意力,你应该可以发现:在上图中,每一个节点 \(u\) 的左儿子编号是 \(u \times 2\),右儿子编号是 \(u\times 2 + 1\),所以找左右儿子就可以简化为:
//使用位运算可以达到同样的效果,且常数更优
#define ls(u) (u << 1)
#define rs(u) (u << 1 | 1)
那么更新答案的函数就是:
// u 表示正在更新的节点编号
void pushup(int u) { sum[u] = sum[ls(u)] + sum[rs(u)]; }
建树
递归即可:
void build(int u, int l, int r) { //在节点 u 对区间 [l, r] 建树
if (l == r) return sum[u] = a[l], void(); //return 的值是最后一个,这里就是 void() 占位符,其实就是没有返回值的意思
int mid = (l + r) >> 1; //防溢出可以 mid = l + ((r - l) >> 1)
build(ls(u), l, mid), build(rs(u), mid + 1, r);
pushup(u);
}
注意 sum
的空间应当开为 a
的 \(4\) 倍(准确来说是 \(2^{\left\lceil \log n \right\rceil + 1}\) 倍),即 int a[N], sum[N << 2]
。
区间查询
假如我们现在要查询 \([3..5]\) 的区间和,可以用 \([3..3]\) 的区间和加上 \([4..5]\) 的区间和。
一般地,如果查询区间 \([l..r]\),则可以将其分割成 \(O(\log n)\) 个极大区间,使得这些区间的答案合并得到区间 \([l..r]\) 的答案。
代码如下:
//当前节点为 u,u 所代表的区间为 [l..r],要查询的区间为 [ql..qr]
int query(int u, int l, int r, int ql, int qr) {
if (l > qr || r < ql) return 0; //完全不是要找的区间,退回
if (l >= ql && r <= qr) return sum[u]; //当前区间是询问区间的一个极大子集
int mid = l + ((r - l) >> 1), res = 0;
if (ql <= mid) res += query(ls(u), l, mid, ql, qr);
if (qr > mid) res += query(rs(u), mid + 1, r, ql, qr);
return res;
}
区间修改
如果每一次都遍历所有 \([l..r]\) 的极大子区间来修改,时间复杂度难以接受。
还记得堆的删除中用的懒标记吗?我们考虑往线段树里也引入懒标记。具体地,我们给每一个节点新开一个 tag
数组来记录节点对应的修改,而只有当访问到该节点时,它所持有的懒标记才生效。每次生效一个标记的同时,往这个节点的儿子下传标记。
本题的增加操作是可累加的,所以多个标记的合并用累加即可。(如果是区间改为 \(k\),那么标记的合并就应该用覆盖;具体情况具体分析。)
打标记代码:
//为表示 [l..r] 的节点 u 打上加 d 的标记
void addtag(int u, int l, int r, int d) {
tag[u] += d;
sum[u] += d * (r - l + 1);
//当然也可以不在这里修改,而是到叶节点修改后靠 pushup 更新上面的答案
}
下传标记代码:
void pushdown(int u, int l, int r) {
int mid = (l + r) >> 1;
addtag(ls(u), l, mid);
addtag(rs(u), mid + 1, r);
tag[u] = 0;
}
区间修改代码:
void update(int u, int l, int r, int ql, int qr, int x) {
if (ql <= l && qr >= r) return addtag(u, l, r, x), void();
pushdown(u, l, r);
int mid = (l + r) >> 1;
if (ql <= mid) update(ls(u), l, mid, ql, qr, x);
if (qr > mid) update(rs(u), mid + 1, r, ql, qr, x);
pushup(u);
}
并且因为有了 pushdown
这个操作,区间询问的代码也得有点微调:
int query(int u, int l, int r, int ql, int qr) {
if (l >= ql && qr >= r) return sum[u];
pushdown(u, l, r);
int mid = (l + r) >> 1, res = 0;
if (ql <= mid)
res += query(ls(u), l, mid, ql, qr);
if (qr > mid)
res += query(rs(u), mid + 1, r, ql, qr);
return res;
}
于是这道题的代码就写完了:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 100010;
int n, m;
int a[N], tree[N << 2], tag[N << 2];
#define ls(u) (u << 1)
#define rs(u) (u << 1 | 1)
void pushup(int p) {
tree[p] = tree[ls(p)] + tree[rs(p)];
}
void build(int p, int pl, int pr) {
tag[p] = 0;
if (pl == pr) return tree[p] = a[pl], void();
int mid = (pl + pr) >> 1;
build(ls(p), pl, mid), build(rs(p), mid + 1, pr);
pushup(p);
}
void addtag(int p, int pl, int pr, int d) {
tag[p] += d;
tree[p] += (pr - pl + 1) * d;
}
void pushdown(int p, int pl, int pr) {
if (tag[p]) {
int mid = (pl + pr) >> 1;
addtag(ls(p), pl, mid, tag[p]), addtag(rs(p), mid + 1, pr, tag[p]);
tag[p] = 0;
}
}
void update(int L, int R, int p, int pl, int pr, int d) {
if (L <= pl && R >= pr)
return addtag(p, pl, pr, d), void();
pushdown(p, pl, pr);
int mid = (pl + pr) >> 1;
if (L <= mid)
update(L, R, ls(p), pl, mid, d);
if (R > mid)
update(L, R, rs(p), mid + 1, pr, d);
pushup(p);
}
int query(int L, int R, int p, int pl, int pr) {
if (L <= pl && R >= pr)
return tree[p];
pushdown(p, pl, pr);
int res = 0;
int mid = (pl + pr) >> 1;
if (L <= mid)
res += query(L, R, ls(p), pl, mid);
if (R > mid)
res += query(L, R, rs(p), mid + 1, pr);
return res;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while (m--) {
int opt, x, y, z;
cin >> opt;
if (opt == 1) cin >> x >> y >> z, update(x, y, 1, 1, n, z);
else cin >> x >> y, cout << query(x, y, 1, 1, n) << "\n";
}
return 0;
}
例题
下面再来看几道例题吧:
P3373 【模板】线段树 2
题意简述:给出一个序列,支持区间加、区间乘、区间求和;\(n, q\le 10^5\)。
题意解析:就是 P3372 【模板】线段树 1 多一个区间乘操作,额外开一个懒标记表示乘即可。
注意到加和乘不是同级运算。使用乘法分配律 \(a(b + c) = ab + ac\) 来同时操作两个 tag
即可。
AC 代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
inline int rd() {
int a = 0, sgn = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') sgn = -1;
for (; isdigit(c); c = getchar()) a = (a << 3) + (a << 1) + c - '0';
return a * sgn;
}
inline void wt(int a) {
if (a < 0) putchar('-'), a = -a;
int sta[35], top = 0;
do { sta[top++] = a % 10; } while (a /= 10);
while (top--) putchar(sta[top] + '0');
}
const int N = 1e5 + 3;
int a[N], n, m, Mod;
struct SegmentTree {
int tree[N << 2];
int tag1[N << 2], tag2[N << 2];
int ls(int k) { return k << 1; }
int rs(int k) { return k << 1 | 1; }
void pushup(int u) { tree[u] = ((tree[ls(u)] %= Mod) + (tree[rs(u)] %= Mod)) % Mod; }
void build(int u, int l, int r) {
tag1[u] = 0, tag2[u] = 1;
if (l == r) { return tree[u] = a[l], void(); }
int mid = (l + r) >> 1;
build(ls(u), l, mid), build(rs(u), mid + 1, r);
pushup(u);
}
void addtag1(int u, int l, int r, int d) {
(tag1[u] += (d % Mod)) %= Mod;
(tree[u] += (d % Mod) * ((r - l + 1) % Mod)) %= Mod;
}
void addtag2(int u, int l, int r, int d) {
(tag1[u] *= d % Mod) %= Mod;
(tag2[u] *= d % Mod) %= Mod;
(tree[u] *= d % Mod) %= Mod;
}
void pushdown(int u, int l, int r) {
int mid = (l + r) >> 1;
addtag2(ls(u), l, mid, tag2[u]);
addtag1(ls(u), l, mid, tag1[u]);
addtag2(rs(u), mid + 1, r, tag2[u]);
addtag1(rs(u), mid + 1, r, tag1[u]);
tag1[u] = 0, tag2[u] = 1;
}
void update1(int u, int l, int r, int ql, int qr, int d) {
if (ql <= l && qr >= r)
return addtag1(u, l, r, d), void();
pushdown(u, l, r);
int mid = (l + r) >> 1;
if (ql <= mid) update1(ls(u), l, mid, ql, qr, d);
if (qr > mid) update1(rs(u), mid + 1, r, ql, qr, d);
pushup(u);
}
void update2(int u, int l, int r, int ql, int qr, int d) {
if (ql <= l && qr >= r)
return addtag2(u, l, r, d), void();
pushdown(u, l, r);
int mid = (l + r) >> 1;
if (ql <= mid) update2(ls(u), l, mid, ql, qr, d);
if (qr > mid) update2(rs(u), mid + 1, r, ql, qr, d);
pushup(u);
}
int query(int u, int l, int r, int ql, int qr) {
if (l >= ql && qr >= r) return tree[u];
pushdown(u, l, r);
int mid = (l + r) >> 1;
int res = 0;
if (ql <= mid) res += query(ls(u), l, mid, ql, qr);
if (qr > mid) res += query(rs(u), mid + 1, r, ql, qr);
return res;
}
} st;
main() {
n = rd(), m = rd(), Mod = rd();
for (int i = 1; i <= n; i++)
a[i] = rd();
st.build(1, 1, n);
for (int type; m--; ) {
type = rd();
if (type == 3) {
int l = rd(), r = rd();
wt(st.query(1, 1, n, l, r) % Mod), puts("");
}
if (type == 2) {
int l = rd(), r = rd(), c = rd();
st.update1(1, 1, n, l, r, c);
}
if (type == 1) {
int l = rd(), r = rd(), c = rd();
st.update2(1, 1, n, l, r, c);
}
}
}
P3178 [HAOI2015] 树上操作
题意简述:给一颗树,以 \(1\) 为根,树有点权。支持单点点权增加,点到根路径点权增加,点到根路径点权求和;\(n, m\le 10^5\)。
题意解析:求出这棵树的欧拉序(不懂欧拉序的看这篇博客或者期待一下以后本蒟蒻写一篇)。利用欧拉序的性质将点到根的路径转化为欧拉序的前缀。对区间加 \(k\) 的实际影响是 \(k\times(cnt_+ - cnt_-)\)(\(cnt_+\) 表示区间内进入节点的次数(即 \(+x\) 的个数),\(cnt_-\) 表示区间内离开节点的次数(即 \(-x\) 的个数)。
AC 代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
inline int rd() {
int a = 0, sgn = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') sgn = -1;
for (; isdigit(c); c = getchar()) a = (a << 3) + (a << 1) + c - '0';
return a * sgn;
}
inline void wt(int a) {
if (a < 0) putchar('-'), a = -a;
int sta[35], top = 0;
do { sta[top++] = a % 10; } while (a /= 10);
while (top--) putchar(sta[top] + '0');
}
const int N = 100003;
#define ls(x) (x) << 1
#define rs(x) (x) << 1 | 1
int n, m, a[N], tid[N], tif[N], num[N << 1], stamp;
int ecnt, head[N], vet[N << 1], nxt[N << 1];
int val[N << 3], tag[N << 3];
struct DFN {
int f, v;
} dfspath[N << 1];
void add(int u, int v) {
vet[++ecnt] = v;
nxt[ecnt] = head[u];
head[u] = ecnt;
}
void dfs(int u, int f) {
dfspath[tid[u] = ++stamp].f = 1;
dfspath[stamp].v = u;
for (int e = head[u]; e; e = nxt[e])
if (vet[e] != f)
dfs(vet[e], u);
dfspath[tif[u] = ++stamp].f = -1;
dfspath[stamp].v = u;
}
void pushdown(int u, int l, int m, int r) {
if (tag[u]) {
val[ls(u)] += tag[u] * (num[m] - num[l - 1]);
val[rs(u)] += tag[u] * (num[r] - num[m]);
tag[ls(u)] += tag[u];
tag[rs(u)] += tag[u];
tag[u] = 0;
}
}
void build(int u, int l, int r) {
if (l == r)
return val[u] = dfspath[l].f * a[dfspath[l].v], void();
int m = (l + r) >> 1;
build(ls(u), l, m), build(rs(u), m + 1, r);
val[u] = val[ls(u)] + val[rs(u)];
}
void update(int u, int l, int r, int x, int c) {
val[u] += c;
if (l == r)
return;
int m = (l + r) >> 1;
if (x <= m)
update(ls(u), l, m, x, c);
else
update(rs(u), m + 1, r, x, c);
}
void change(int u, int l, int r, int p, int q, int c) {
if (p <= l && r <= q) {
tag[u] += c;
val[u] += (int)c * (num[r] - num[l - 1]);
return;
}
int m = (l + r) >> 1;
pushdown(u, l, m, r);
if (p <= m)
change(ls(u), l, m, p, q, c);
if (q > m)
change(rs(u), m + 1, r, p, q, c);
val[u] = val[ls(u)] + val[rs(u)];
}
int query(int u, int l, int r, int p, int q) {
if (p <= l && r <= q)
return val[u];
int m = (l + r) >> 1;
int res = 0;
pushdown(u, l, m, r);
if (p <= m)
res += query(ls(u), l, m, p, q);
if (q > m)
res += query(rs(u), m + 1, r, p, q);
return res;
}
main() {
n = rd(), m = rd();
for (int i = 1; i <= n; ++i) a[i] = rd();
for (int i = 1, u, v; i < n; ++i) u = rd(), v = rd(), add(u, v), add(v, u);
dfs(1, 0);
for (int i = 1; i <= stamp; ++i) num[i] = num[i - 1] + dfspath[i].f;
build(1, 1, n + n);
for (int opt, x, a; m--;) {
opt = rd();
if (opt == 1)
x = rd(), a = rd(), update(1, 1, n + n, tid[x], +a), update(1, 1, n + n, tif[x], -a);
else if (opt == 2)
x = rd(), a = rd(), change(1, 1, n + n, tid[x], tif[x], a);
else
x = rd(), wt(query(1, 1, n + n, 1, tid[x])), puts("");
}
}
结语
线段树还有很多进阶的知识点。碍于篇幅,本篇仅介绍了线段树的基本原理和初级运用。想看更多有关线段树的内容可以期待一下以后写的博客,这里就先咕咕掉了。
标签:rs,int,线段,基本知识,mid,初级,void,区间,ql From: https://www.cnblogs.com/qazsedcrfvgyhnujijn-blog/p/18281401