二叉查找树
- 一定是一颗二叉树
- 左子树任意节点的值 \(<\) 当前节点的值 \(<\) 右子树任意节点的值
基本思想
Splay 的基本思想就是越经常访问的节点的深度就越低,也就是说,如果我们访问了一个节点,我们就要将其 Splay 到根。
所以 Splay 的复杂度其实是均摊 \(\mathcal{O}(\log n)\) 的。
Splay 相比其他平衡树优点主要在好写,扩展性强,空间复杂度低。
旋转
引用一个比较经典的图片。
![[Pasted image 20230123163624.png]]
我们具体分析一下旋转过程。
左旋和右旋的区别其实就是我们当前旋转的节点是左儿子还是右儿子,不用太多去管。
比如我们要旋转节点 \(2\)。
节点 \(2\) 是节点 \(1\) 的左儿子,所以这是右旋操作,因为经过旋转之后,之前的根节点变成了当前根节点的右儿子。
但是节点 \(5\) 要怎么办呢?
因为节点 \(2\) 原来是节点 \(1\) 的左儿子,所以说节点 \(1\) 是没有左儿子的,所以我们只要把当前旋转的节点的右儿子作为他父亲节点的左儿子就好了。
分析一下正确性,发现右旋前的中序遍历是 \(4,2,5,1,3\) 右旋后的中序遍历还是 \(4,2,5,1,3\)。
这就很神奇了,因为我们考虑中序遍历是先访问的左儿子,所以实际上我们把左儿子上移并不会影响中序遍历,同时原来的右儿子也被移到了右子树第一个会被访问到的地方,很合理。
左旋同理。
所以旋转本质上就是将旋转节点与其父节点交换后再进行调整使其满足二叉搜索树的性质。
Splay 操作
其实就是将一个节点一直旋转,直到成为根节点。
具体要分 \(6\) 种情况讨论 (zig, zag, zig-zig, zag-zag, zig-zag, zag-zig)
由于 zig 和 zag 有对称性,我们只讨论 zig, zig-zig, zig-zag。
记当前节点为 \(x\),其父亲为 \(p\),父亲的父亲为 \(g\)。
zig
当 \(p\) 为根节点时,旋转 \(x\)。
![[Pasted image 20230124104337.png]]
zig-zig
当 \(p\) 不为根节点且 \(p,x\) 为同侧节点时 ,其先旋转 \(p\),再旋转 \(x\)。
![[Pasted image 20230124103948.png]]
zig-zag
当 \(p\) 不为根节点,且 \(p,x\) 不为同侧节点,那么就旋转两次 \(x\),将 \(p\) 和 \(g\) 变为 \(x\) 的儿子。
![[Pasted image 20230124105250.png]]
复杂度分析以后再补。
插入
跟普通平衡树插入差不多,不过不同的是要对插入后的节点做 Splay 操作。
Rank
类似于二分的操作,累加左儿子和中间节点的大小,找到最后节点之后要进行 Splay 操作。
Kth
类似于二分的操作,找到节点之后要进行 Splay 操作。
查找前驱
可以看作找小于 \(x\) 的第一个数,可以把 \(x\) 插入之后找其左子树的最右节点,之后再删除。
查找后继
与前驱同理。
删除
首先我们将要删除的数 Splay 到根。
然后如果我们要删除的数有很多个,我们就将其 \(cnt - 1\)。
如果只有一个,删除这个节点然后合并其左右子树。
因为左子树的最大值一定小于右子树的最小值,
所以我们可以把左子树的最大值 Splay 到根后将右子树挂到左子树的右儿子上。
现在的左子树就是新的树。
普通平衡树 Splay 版本
需要注意的细节都写代码里了。
#include <cstdio>
#include <functional>
using namespace std;
/**
*
* Splay Template By luanmenglei
*
* Credit oi-wiki.org
*
**/
namespace Splay {
const int N = 1e5 + 10;
int ch[N][2], sze[N], val[N], cnt[N], fa[N], tot, rt;
void update(int x) { sze[x] = sze[ch[x][0]] + sze[ch[x][1]] + cnt[x]; }
int get(int x) { return ch[fa[x]][1] == x; } // return 1 if x is right child
void clear(int x) { ch[x][0] = ch[x][1] = fa[x] = val[x] = sze[x] = cnt[x] = 0; }
void set(int parent, int x, int side) { fa[x] = parent, ch[parent][side] = x; } // need update size
void rotate(int x) {
int y = fa[x], z = fa[y], sidex = get(x), sidey = get(y);
set(y, ch[x][sidex ^ 1], sidex), set(z, x, sidey), set(x, y, sidex ^ 1);
fa[0] = 0, ch[0][0] = ch[0][1] = 0; // must clear
update(y), update(x); // must update y before x because y is x's child
}
void splay(int x) {
for (int f = fa[x], g = fa[f]; f; rotate(x), f = fa[x], g = fa[f]) if (g) rotate(get(x) == get(f) ? f : x); // 是否同侧
rt = x;
}
int create(int k) { return val[++ tot] = k, cnt[tot] = 1, sze[tot] = 1, tot; }
int insert(int k) { // insert a number
if (!rt) return rt = create(k);
int cur = rt, f = 0;
while (true) {
if (val[cur] == k) return ++ cnt[cur], update(cur), update(f), splay(cur), rt;
f = cur, cur = ch[f][val[f] < k];
if (!cur) return cur = create(k), set(f, cur, val[f] < k), update(f), splay(cur), rt;
}
}
int rk(int k) { // return the rank of k
int rank = 0, cur = rt;
while (true) {
if (k < val[cur]) cur = ch[cur][0];
else {
rank += sze[ch[cur][0]];
if (k == val[cur]) return splay(cur), rank + 1;
rank += cnt[cur], cur = ch[cur][1];
}
}
}
int kth(int k) { // return the kth node's value
int cur = rt;
while (true) {
if (ch[cur][0] && k <= sze[ch[cur][0]]) cur = ch[cur][0];
else {
k -= cnt[cur] + sze[ch[cur][0]];
if (k <= 0) return splay(cur), val[cur];
cur = ch[cur][1];
}
}
}
int travel(int side) {
int cur = ch[rt][side];
if (!cur) return cur;
while (ch[cur][side ^ 1]) cur = ch[cur][side ^ 1];
return splay(cur), cur;
}
int get_lr() { return travel(0); } // get the left tree's rightest node
int get_rl() { return travel(1); } // get the right tree's leftest node
void del(int k) {
auto clear = [&](int x) { sze[x] = cnt[x] = val[x] = fa[x] = ch[x][0] = ch[x][1] = 0; };
rk(k);
if (cnt[rt] > 1) -- cnt[rt], -- sze[rt];
else if (ch[rt][0] && ch[rt][1]) { // merge the left tree and the right tree
int tmp = rt, maxl = get_lr();
set(maxl, ch[tmp][1], 1);
clear(tmp), update(rt);
} else if (ch[rt][0]) {
int tmp = ch[rt][0];
clear(rt);
rt = tmp, fa[rt] = 0;
} else if (ch[rt][1]) {
int tmp = ch[rt][1];
clear(rt);
rt = tmp, fa[rt] = 0;
} else clear(rt), rt = 0;
}
int pre(int k) {
insert(k);
int ret = get_lr();
del(k);
return val[ret];
}
int nxt(int k) {
insert(k);
int ret = get_rl();
del(k);
return val[ret];
}
}
using Splay::insert;
using Splay::pre;
using Splay::nxt;
using Splay::rk;
using Splay::kth;
using Splay::del;
const function<void(int)> SPLAY_FUNC[] = {
[](int x) { insert(x); },
[](int x) { del(x); },
[](int x) { printf("%d\n", rk(x)); },
[](int x) { printf("%d\n", kth(x)); },
[](int x) { printf("%d\n", pre(x)); },
[](int x) { printf("%d\n", nxt(x)); }
};
int main() {
int q; scanf("%d", &q);
while (q --) {
int op, x; scanf("%d%d", &op, &x);
SPLAY_FUNC[op - 1](x);
}
return 0;
}