概念
Splay 树(伸展树),是一种平衡BST
它通过伸展操作不断将某个节点旋转到根节点,使得整棵树仍然满足BST的性质,能够在均摊 \(O(\log n)\) 时间内完成插入,查找和删除操作,并且保持平衡而不至于退化为链。
实现
rotate
其保证
- 不破坏BST的性质
- 不破坏节点维护的信息
- root必须指向旋转后的根节点
在Splay中旋转分为左旋和右旋
具体分析旋转过程(令需要旋转的节点为 \(x\) ,其父亲为 \(y\) ,以右旋为例)
- 将 \(y\) 的左儿子指向 \(x\) 的右儿子,且 \(x\) 的右儿子的父亲指向 \(y\)
- 将 \(x\) 的右儿子指向 \(y\) ,且 \(y\) 的父亲指向 \(x\)
- 将 \(y\) 的父亲 \(z\) 指向 \(y\) 的儿子的信息指向 \(x\) 并将 \(x\) 的父亲指向 \(z\)
inline void rotate(int x) {
int y = fa[x], z = fa[y], d = dir(x);
ch[y][d] = ch[x][d ^ 1];
if (ch[x][d ^ 1])
fa[ch[x][d ^ 1]] = y;
ch[x][d ^ 1] = y;
fa[y] = x, fa[x] = z;
if (z)
ch[z][y == ch[z][1]] = x;
pushup(x), pushup(y);
}
splay
定义:每次访问一个节点后都多次使用 splay 操作强制旋转到根
splay操作步骤有三种,具体分为六种情况
-
zig: \(y\) 是根节点。直接旋转即可
-
zig-zig:\(x,y\) 都是其父亲的左儿子或右儿子。先把 \(y\) 旋上去,再把 \(x\) 旋上去
- zig-zag: \(x, y\) 不都是其父亲的左儿子或右儿子。把 \(x\) 旋上去两次即可
代码实现:
inline void splay(int x) {
for (int f = fa[x]; f; rotate(x), f = fa[x])
if (fa[f])
rotate(dir(x) == dir(f) ? f : x);
root = x;
}
合并
合并两棵splay树
设两棵树的根节点为 \(x, y\) ,令 \(x\) 中的最大值小于 \(y\) 中的最小值
将 \(x\) 树的最大值splay到根,将其右子树设为 \(y\) 即可
插入
设插入值为 \(k\)
- 若树空,则直接插入根并退出
- 若当前节点权值等于 \(k\) ,则增加当前节点大小并更新信息
- 否则按照 BST 的性质向下找,找到空节点插入即可
inline void insert(int k) {
if (!root) {
val[++tot] = k;
cnt[tot] = 1;
root = tot;
pushup(root);
return ;
}
int cur = root, f = 0;
for (;;) {
if (val[cur] == k) {
++cnt[cur];
pushup(cur), pushup(f);
splay(cur);
break;
}
f = cur, cur = ch[cur][val[cur] < k];
if (!cur) {
val[++tot] = k;
cnt[tot] = 1;
fa[tot] = f;
ch[f][val[f] < k] = tot;
pushup(tot), pushup(f);
splay(tot);
break;
}
}
}
查询 \(x\) 的排名
与 BST 类似
inline int rnk(int k) {
int res = 0, cur = root;
for (;;) {
if (k < val[cur])
cur = ch[cur][0];
else {
res += siz[ch[cur][0]];
if (k == val[cur]) {
splay(cur);
return res + 1;
}
res += cnt[cur], cur = ch[cur][1];
}
}
}
查询排名为 \(x\) 的数
与 BST 类似
inline int kth(int k) {
int cur = root;
for (;;) {
if (ch[cur][0] && k <= siz[ch[cur][0]])
cur = ch[cur][0];
else {
k -= siz[ch[cur][0]] + cnt[cur];
if (k <= 0) {
splay(cur);
return val[cur];
}
cur = ch[cur][1];
}
}
}
查询前驱 / 后继
前驱定义为小于 \(x\) 的最大数,那么我们可以先插入 \(x\) ,前驱即为 \(x\) 左子树中最右节点,最后删除 \(x\) 即可
后继定义为大于 \(x\) 的最小数,查询方法类似前驱: \(x\) 的右子树中的最左节点
找根节点的前驱后继代码:
inline int near(int sign) {
int cur = ch[root][sign];
if (!cur)
return cur;
while (ch[cur][sign ^ 1])
cur = ch[cur][sign ^ 1];
splay(cur);
return cur;
}
删除
首先将 \(x\) 旋转到根
- 若有不止一个 \(x\) ,则直接将该点数量减 \(1\) 即可
- 否则,合并左右子树即可
inline void remove(int k) {
rnk(k);
if (cnt[root] > 1)
--cnt[root], pushup(root);
else if (!ch[root][0] && !ch[root][1])
clear(root), root = 0;
else if (!ch[root][0]) {
int cur = root;
root = ch[root][1];
fa[root] = 0;
clear(cur);
} else if (!ch[root][1]) {
int cur = root;
root = ch[root][0];
fa[root] = 0;
clear(cur);
} else {
int cur = root, x = near(0);
fa[ch[cur][1]] = x;
ch[x][1] = ch[cur][1];
clear(cur);
pushup(root);
}
}
应用
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;
namespace Splay {
int ch[N][2];
int fa[N], siz[N], cnt[N];
int val[N];
int root, tot;
inline void pushup(int x) {
siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x];
}
inline int get(int x) {
return x == ch[fa[x]][1];
}
inline void clear(int x) {
ch[x][0] = ch[x][1] = fa[x] = val[x] = siz[x] = cnt[x] = 0;
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], d = get(x);
ch[y][d] = ch[x][d ^ 1];
if (ch[x][d ^ 1])
fa[ch[x][d ^ 1]] = y;
ch[x][d ^ 1] = y;
fa[y] = x, fa[x] = z;
if (z)
ch[z][y == ch[z][1]] = x;
pushup(x), pushup(y);
}
inline void splay(int x) {
for (int f = fa[x]; f = fa[x], f; rotate(x))
if (fa[f])
rotate(get(x) == get(f) ? f : x);
root = x;
}
inline void insert(int k) {
if (!root) {
val[++tot] = k;
cnt[tot] = 1;
root = tot;
pushup(root);
return ;
}
int cur = root, f = 0;
for (;;) {
if (val[cur] == k) {
++cnt[cur];
pushup(cur), pushup(f);
splay(cur);
break;
}
f = cur, cur = ch[cur][val[cur] < k];
if (!cur) {
val[++tot] = k;
cnt[tot] = 1;
fa[tot] = f;
ch[f][val[f] < k] = tot;
pushup(tot), pushup(f);
splay(tot);
break;
}
}
}
inline int rnk(int k) {
int res = 0, cur = root;
for (;;) {
if (k < val[cur])
cur = ch[cur][0];
else {
res += siz[ch[cur][0]];
if (k == val[cur]) {
splay(cur);
return res + 1;
}
res += cnt[cur], cur = ch[cur][1];
}
}
}
inline int kth(int k) {
int cur = root;
for (;;) {
if (ch[cur][0] && k <= siz[ch[cur][0]])
cur = ch[cur][0];
else {
k -= siz[ch[cur][0]] + cnt[cur];
if (k <= 0) {
splay(cur);
return val[cur];
}
cur = ch[cur][1];
}
}
}
inline int near(int sign) {
int cur = ch[root][sign];
if (!cur)
return cur;
while (ch[cur][sign ^ 1])
cur = ch[cur][sign ^ 1];
splay(cur);
return cur;
}
inline void remove(int k) {
rnk(k);
if (cnt[root] > 1)
--cnt[root], pushup(root);
else if (!ch[root][0] && !ch[root][1])
clear(root), root = 0;
else if (!ch[root][0]) {
int cur = root;
root = ch[root][1];
fa[root] = 0;
clear(cur);
} else if (!ch[root][1]) {
int cur = root;
root = ch[root][0];
fa[root] = 0;
clear(cur);
} else {
int cur = root, x = near(0);
fa[ch[cur][1]] = x;
ch[x][1] = ch[cur][1];
clear(cur);
pushup(root);
}
}
}
int m;
signed main() {
scanf("%d", &m);
for (int op, x; m; --m) {
scanf("%d%d", &op, &x);
if (op == 1)
Splay::insert(x);
else if (op == 2)
Splay::remove(x);
else if (op == 3)
printf("%d\n", Splay::rnk(x));
else if (op == 4)
printf("%d\n", Splay::kth(x));
else if (op == 5) {
Splay::insert(x);
printf("%d\n", Splay::val[Splay::near(0)]);
Splay::remove(x);
}
else {
Splay::insert(x);
printf("%d\n", Splay::val[Splay::near(1)]);
Splay::remove(x);
}
}
return 0;
}
扩展
区间翻转
我们以编号为下标建立一棵 Splay
当我们翻转区间 \([l, r]\) ,时,我们可以考虑利用 Splay 的性质,将 \(l - 1\) 翻转至根节点,再将 \(r + 1\) 翻转至其右儿子,这样 \(r + 1\) 的左儿子就是所有 \([l, r]\) 的数了
此时,我们对这个节点打上标记,有需要时再翻转即可
为了方便,我们在树的两端各插入 \(\pm \infty\) ,防止在翻转 \([1, n]\) 时出现问题
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7;
namespace Splay {
int ch[N][2];
int fa[N], siz[N], val[N], tag[N];
int root, tot;
inline void pushup(int x) {
siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + 1;
}
inline void pushdown(int x) {
if (tag[x]) {
tag[ch[x][0]] ^= 1;
tag[ch[x][1]] ^= 1;
swap(ch[x][0], ch[x][1]);
tag[x] = 0;
}
}
inline int get(int x) {
return x == ch[fa[x]][1];
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], d = get(x);
pushdown(y), pushdown(x);
ch[y][d] = ch[x][d ^ 1];
if (ch[x][d ^ 1])
fa[ch[x][d ^ 1]] = y;
ch[x][d ^ 1] = y;
fa[y] = x, fa[x] = z;
if (z)
ch[z][y == ch[z][1]] = x;
pushup(y), pushup(x);
}
inline void splay(int x, int goal = 0) {
for (int f = fa[x]; f != goal; rotate(x), f = fa[x])
if (fa[f] != goal)
rotate(get(x) == get(f) ? f : x);
if (!goal)
root = x;
}
inline void insert(int k) {
if (!root) {
val[++tot] = k;
root = tot;
pushup(root);
return;
}
for (int x = root, f = 0;;) {
f = x, x = ch[x][val[x] < k];
if (!x) {
val[++tot] = k;
fa[tot] = f, ch[f][val[f] < k] = tot;
pushup(tot), pushup(f);
splay(tot);
break;
}
}
}
inline int find(int k) {
for (int x = root;;) {
pushdown(x);
if (k <= siz[ch[x][0]])
x = ch[x][0];
else if (k == siz[ch[x][0]] + 1)
return x;
else
k -= siz[ch[x][0]] + 1, x = ch[x][1];
}
}
inline void reverse(int l, int r) {
l = find(l - 1), r = find(r + 1);
splay(l, 0), splay(r, l);
tag[ch[ch[root][1]][0]] ^= 1;
}
inline void dfs(int x) {
if (!x)
return;
pushdown(x);
dfs(ch[x][0]);
if (val[x] != -inf && val[x] != inf)
printf("%d ", val[x]);
dfs(ch[x][1]);
}
}
int n, m;
signed main() {
scanf("%d%d", &n, &m);
Splay::insert(-inf), Splay::insert(inf);
for (int i = 1; i <= n; ++i)
Splay::insert(i);
for (int l, r; m; --m) {
scanf("%d%d", &l, &r);
Splay::reverse(l + 1, r + 1);
}
Splay::dfs(Splay::root);
return 0;
}
区间移动
将区间 \([l, r]\) 扔到 \(c\) 后面
首先,类似区间翻转,我们拿出区间 \([l, r]\) ,并将 \(c\) 旋转到根,将 \(c + 1\) 旋转至 \(c\) 的右儿子,接着把 \([l, r]\) 设为 \(c + 1\) 的左儿子即可
标签:ch,cur,int,tot,Splay,fa,root From: https://www.cnblogs.com/wshcl/p/splay.html