平衡树
平衡树是一类二叉查找树,因为普通的二叉查找树可能会因为特殊的数据的构造变成链,导致原本应该是 \(\mathcal O(\log n)\) 的查找速度退化成为 \(\mathcal O(n)\),损失大量效率。为了解决这个问题,就有了平衡树这一数据结构。
平衡树,就是对二叉查找树进行一些变形,使得这个二叉查找树尽量的平衡,使得单次操作的时间复杂度尽量靠近于 \(\mathcal O(\log n)\)。下面以 Luogu P3369 为例,对一些常用的平衡树进行一些讲解。
非旋 Treap(待填坑)
Splay
Splay 是一种神奇的平衡树,给一个关键词就是:转转转(
虽然时间复杂度是 \(\mathcal O(n\log n)\) 的,但是常数很大,可能有 \(8\) 左右,所以不算一个比较高效的平衡树。
Splay 首先需要记录每个节点的值,值出现次数,节点左右儿子,父亲节点,节点子树大小(有什么用之后会了解)。
定义
namespace Splay{
#define lc(x) ch[x][0]
#define rc(x) ch[x][1]
const int _SIZE=1e5;
int root,tot;
int val[_SIZE+5],cnt[_SIZE+5],ch[_SIZE+5][2],sz[_SIZE+5],fa[_SIZE+5];
#undef lc
#undef rc
}
这段代码进行了宏定义,可以比较好的简化代码,提升可读性。因为是定义在命名空间中,所以建议在命名空间结束的时候取消定义,养成一个好习惯,方便以后可能的更大的代码项目。
三个基础操作
void maintain(int x) {sz[x]=sz[lc(x)]+sz[rc(x)]+cnt[x];}//更新子树大小
bool get(int x) {return x==rc(fa[x]);}//判断x为哪个子树
void clear(int x) {val[x]=cnt[x]=lc(x)=rc(x)=sz[x]=fa[x]=0;}//清除x节点
这三个函数的作用应该是显而易见的了。这里不做过多解释。
旋转
Splay 的一个基本操作就是旋转 rotate
。rotate
操作会将节点 x
向其父节点 y
旋转。具体操作(以 x
为 y
的左儿子为例):将 x
设为 y
的父节点,然后将 x
的右儿子给设置为 y
的左儿子。不难发现,这样的旋转操作不会破坏 BST 的平衡性。
void rotate(int x,int &rt=root)
{
int y=fa[x],z=fa[y],chk=get(x);//chk用来确定x是在哪一个子树
ch[y][chk]=ch[x][chk^1];
if (ch[x][chk^1]) fa[ch[x][chk^1]]=y;//x的儿子与y连边
ch[x][chk^1]=y,fa[y]=x,fa[x]=z;//x与y父子关系反转
if (z) ch[z][y==rc(z)]=x;//如果y有父节点z需要将z的儿子改到x
else rt=x;//如果没有就将这个子树的根改为x
maintain(y),maintain(x);//更新size
}
Splay
这是 Splay 这种平衡树的最关键的操作,是其时间复杂度的保证。具体的做法就是将某一个节点 x
给旋转到整棵 BST 的一个节点下(一般为根节点)。
void splay(int x,int &rt=root)
{
int y=fa[x];
for (;x!=rt;rotate(x,rt),y=fa[x])//没到就一直转
if (y!=rt) rotate(get(x)==get(y)?y:x,rt);//如果y不是根节点就需要判断x,y是否是在一条链上的,如果是就先旋转y再旋转x,否则旋转两次x;如果y是根节点就只需要旋转一次x
rt=x;//更新新的子树的根节点
}
需要注意的是,在 Splay 的所有的可能更改平衡树结构的操作时,都需要将新更改的节点 splay
到根节点,否则将无法保证时间复杂度的正确性。
插入
相比于前面的两个函数,接下来的操作就是普通的 BST 也支持的东西了,实现也会比较简单一些了,就直接给出代码了(记住 splay
)。
void insert(int k)
{
if (!root)//树是空的
{
root=++tot,cnt[tot]++,val[tot]=k;
return maintain(root);
}
int cur=root,f=0;
while (1)//非递归实现
{
if (k==val[cur])//存在节点
{
cnt[cur]++;
maintain(cur),maintain(f);
return splay(cur);//记得splay
}
f=cur,cur=ch[cur][k>val[cur]];
if (!cur)//新建节点
{
cnt[++tot]++,val[tot]=k;
fa[tot]=f,ch[f][k>val[f]]=tot;
maintain(tot),maintain(f);
return splay(tot);//splay
}
}
}
根据 Val 查询排名
int rk(int k)
{
int res=0,cur=root;//res用于存储目前的排名
while (1)
{
if (k<val[cur]) cur=lc(cur);
else
{
res+=sz[lc(cur)];//不在左子树,就将左子树的全部节点个数统计入排名
if (k==val[cur]) {splay(cur);return res+1;}
res+=cnt[cur],cur=rc(cur);
}
}
}
根据排名查询 Val
int kth(int k)
{
int cur=root;
while (1)
{
if (lc(cur) && k<=sz[lc(cur)]) cur=lc(cur);
else
{
k-=sz[lc(cur)]+cnt[cur];//直接减,如果减成负数就证明是当前节点
if (k<=0) {splay(cur);return val[cur];}
cur=rc(cur);
}
}
}
查询前驱
查询 x
前驱的操作可以变成插入 x
,然后将 x
splay
到根节点,此时左子树中的最大值就是 x
的前驱,最后再将 x
删除即可。
这里给出查找根节点左子树最大值的函数。
int pre()
{
int cur=lc(root);
if (!cur) return cur;
while (rc(cur)) cur=rc(cur);//只要有右儿子就一直向右下走
splay(cur);//将cur旋转到根节点
return cur;//返回节点编号
}
查询后继
与查询前驱基本一致,插入 x
,在根节点右子树查找最小值,删除 x
。
int nxt()
{
int cur=rc(root);
if (!cur) return cur;
while (lc(cur)) cur=lc(cur);
splay(cur);
return cur;
}
删除操作
假设删除的数为 x
,那么先将 x
splay
至根节点,然后删除该数。如果 x
节点的 cnt
值被减为了 \(0\),那么就删除根节点,合并根节点的左右子树(此时 x
已经被 splay
到根节点了)。
假设合并的两棵平衡树为 \(A,B\),那么分情况进行讨论(需要注意,此处的 \(A,B\) 是需要满足 \(\max\{A_i\}<\min\{B_j\}\),即 \(A\) 的最大值小于 \(B\) 的最小值):
- 如果 \(A,B\) 二者中有一者为空,就将非空者作为新的平衡树。
- 如果 \(A,B\) 二者均非空,那么将 \(A\) 树的最大值
splay
到根节点(此时显然根节点是没有右子树的),然后将 \(A\) 树根节点的右子树接到 \(B\) 树的根节点上,并更新节点信息。
void remove(int k)
{
rk(k);
if (cnt[root]>1) {cnt[root]--;return maintain(root);}
if (!lc(root) && !rc(root)) {clear(root);root=0;return;}//删后为空树
if (!lc(root))//没有左子树
{
int cur=root;root=rc(root),fa[root]=0;
return clear(cur);
}
if (!rc(root))//没有右子树
{
int cur=root;root=lc(root),fa[root]=0;
return clear(cur);
}
int cur=root,x=pre();//利用之前写的pre函数获取到左子树的最大值
fa[rc(cur)]=x,rc(x)=rc(cur);//连接A B根节点
clear(cur),maintain(root);//清楚原来根节点的信息,更新新根节点的信息
}
总代码
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a)
//#define int long long
using namespace std;
template<typename T> void read(T &k)
{
k=0;T flag=1;char b=getchar();
while (!isdigit(b)) {flag=(b=='-')?-1:1;b=getchar();}
while (isdigit(b)) {k=k*10+b-48;b=getchar();}
k*=flag;
}
template<typename T> void write(T k) {if (k<0) {putchar('-'),write(-k);return;}if (k>9) write(k/10);putchar(k%10+48);}
template<typename T> void writewith(T k,char c) {write(k);putchar(c);}
namespace Splay{
#define lc(x) ch[x][0]
#define rc(x) ch[x][1]
const int _SIZE=1e5;
int root,tot;
int val[_SIZE+5],cnt[_SIZE+5],ch[_SIZE+5][2],sz[_SIZE+5],fa[_SIZE+5];
void maintain(int x) {sz[x]=sz[lc(x)]+sz[rc(x)]+cnt[x];}
bool get(int x) {return x==rc(fa[x]);}
void clear(int x) {val[x]=cnt[x]=lc(x)=rc(x)=sz[x]=fa[x]=0;}
void rotate(int x,int &rt=root)
{
int y=fa[x],z=fa[y],chk=get(x);
ch[y][chk]=ch[x][chk^1];
if (ch[x][chk^1]) fa[ch[x][chk^1]]=y;
ch[x][chk^1]=y,fa[y]=x,fa[x]=z;
if (z) ch[z][y==rc(z)]=x;
else rt=x;
maintain(y),maintain(x);
}
void splay(int x,int &rt=root)
{
int y=fa[x];
for (;x!=rt;rotate(x,rt),y=fa[x])
if (y!=rt) rotate(get(x)==get(y)?y:x,rt);
rt=x;
}
void insert(int k)
{
if (!root)
{
root=++tot,cnt[tot]++,val[tot]=k;
return maintain(root);
}
int cur=root,f=0;
while (1)
{
if (k==val[cur])
{
cnt[cur]++;
maintain(cur),maintain(f);
return splay(cur);
}
f=cur,cur=ch[cur][k>val[cur]];
if (!cur)
{
cnt[++tot]++,val[tot]=k;
fa[tot]=f,ch[f][k>val[f]]=tot;
maintain(tot),maintain(f);
return splay(tot);
}
}
}
int rk(int k)
{
int res=0,cur=root;
while (1)
{
if (k<val[cur]) cur=lc(cur);
else
{
res+=sz[lc(cur)];
if (k==val[cur]) {splay(cur);return res+1;}
res+=cnt[cur],cur=rc(cur);
}
}
}
int kth(int k)
{
int cur=root;
while (1)
{
if (lc(cur) && k<=sz[lc(cur)]) cur=lc(cur);
else
{
k-=sz[lc(cur)]+cnt[cur];
if (k<=0) {splay(cur);return val[cur];}
cur=rc(cur);
}
}
}
int pre()
{
int cur=lc(root);
if (!cur) return cur;
while (rc(cur)) cur=rc(cur);
splay(cur);
return cur;
}
int nxt()
{
int cur=rc(root);
if (!cur) return cur;
while (lc(cur)) cur=lc(cur);
splay(cur);
return cur;
}
void remove(int k)
{
rk(k);
if (cnt[root]>1) {cnt[root]--;return maintain(root);}
if (!lc(root) && !rc(root)) {clear(root);root=0;return;}
if (!lc(root))
{
int cur=root;root=rc(root),fa[root]=0;
return clear(cur);
}
if (!rc(root))
{
int cur=root;root=lc(root),fa[root]=0;
return clear(cur);
}
int cur=root,x=pre();
fa[rc(cur)]=x,rc(x)=rc(cur);
clear(cur),maintain(root);
}
#undef lc
#undef rc
}using namespace Splay;
int n;
signed main()
{
read(n);
for (int i=1;i<=n;i++)
{
int opt,x;read(opt),read(x);
if (opt==1) insert(x);
if (opt==2) remove(x);
if (opt==3) writewith(rk(x),'\n');
if (opt==4) writewith(kth(x),'\n');
if (opt==5) insert(x),writewith(val[pre()],'\n'),remove(x);
if (opt==6) insert(x),writewith(val[nxt()],'\n'),remove(x);
}
return 0;
}
替罪羊树
替罪羊树是平衡树中效率极其优秀的,常数较小(可能是仅次于红黑树的),缺点是不如 Splay 和非旋 Treap 那么通用。
定义
先给出替罪羊树的变量声明。
namespace SGT{
const int _SIZE=1e5;
const double alpha=0.7;//一个常数,后文会提到
int tot,root,lc[_SIZE+5],rc[_SIZE+5],val[_SIZE+5];//按顺序,节点个数,根节点,左右儿子,节点权值
int cnt[_SIZE+5],sz[_SIZE+5],szv[_SIZE+5],szd[_SIZE+5];//节点权值个数,子树大小(每个节点记1次),子树大小(每个节点记cnt次),非空子树大小(非空节点记1次)
}
信息更新
根据定义,可以很简单的得出更新方式。
void maintain(int x)
{
sz[x]=sz[lc[x]]+sz[rc[x]]+1;
szv[x]=szv[lc[x]]+szv[rc[x]]+cnt[x];
szd[x]=szd[lc[x]]+szd[rc[x]]+(cnt[x]!=0);
}
判断是否需要重构
替罪羊树引入了一个常数 \(\alpha\)(一般为 \(0.6\sim 0.7\),通常取 \(0.7\)),当某个节点 x
的某一子树 y
大小占到了 x
的 \(\alpha\),那么就将这棵子树重构。并且,如果一棵子树的空节点(空节点定义为一个 cnt=0
的有权值 val
的节点)占到了该子树的 \(\alpha\),那么也将这棵子树重构。
bool canRbd(int x)
{
return cnt[x] && (alpha*sz[x]<=(double)max(sz[lc[x]],sz[rc[x]])) || sz[x]*alpha>=(double)szd[x];
}
拍扁重构
这是替罪羊树最核心的操作,时间复杂度是由这个操作来保证的。当判断了某个子树是否需要重构后,就需要进行重构操作。
将操作分为两步:拍扁、重构。
拍扁
拍扁就是将某个子树按照中序遍历的顺序存储到一个 vector
中(一般用数组模拟,而不用 vector
,否则就可能失去替罪羊树小常数的优势),一般就用普通的二叉树的中序遍历的方式就行。
int ldr[_SIZE+5];//模拟的vector
void Rbd_flatten(int &ldc,int x)//ldc是vector的尾下标
{
if (!x) return;//不存在x这个节点
Rbd_flatten(ldc,lc[x]);
if (cnt[x]) ldr[ldc++]=x;//只要该节点不是空节点就加入vector中
Rbd_flatten(ldc,rc[x]);
}
重构
直接用二分的方式递归 vector
建树即可。
int Rbd_build(int l,int r)//注意有返回值,返回根节点的新编号,此处l,r是前闭后开
{
int mid=l+r>>1;
if (l>=r) return 0;
lc[ldr[mid]]=Rbd_build(l,mid);//重构左子树
rc[ldr[mid]]=Rbd_build(mid+1,r);//右子树
maintain(ldr[mid]);//更新节点信息
return ldr[mid];
}
总代码:
int ldr[_SIZE+5];
void Rbd_flatten(int &ldc,int x)
{
if (!x) return;
Rbd_flatten(ldc,lc[x]);
if (cnt[x]) ldr[ldc++]=x;
Rbd_flatten(ldc,rc[x]);
}
int Rbd_build(int l,int r)
{
int mid=l+r>>1;
if (l>=r) return 0;
lc[ldr[mid]]=Rbd_build(l,mid);
rc[ldr[mid]]=Rbd_build(mid+1,r);
maintain(ldr[mid]);
return ldr[mid];
}
void Rebuild(int &x)
{
int ldc=0;
Rbd_flatten(ldc,x);
x=Rbd_build(0,ldc);
}
插入
与普通 BST 相同,采用递归实现。
void insert(int &k,int p)
{
if (!k)//
{
k=++tot;
if (!root) root=1;
val[k]=p,lc[k]=rc[k]=0;
sz[k]=szv[k]=szd[k]=cnt[k]=1;
}
else
{
if (val[k]==p) cnt[k]++;
else if (val[k]<p) insert(rc[k],p);
else insert(lc[k],p);
maintain(k);
if (canRbd(k)) Rebuild(k);//每次对BST的结构更改的时候都需要判断是否需要重构
}
}
删除
替罪羊树的删除是采用类似懒删除的方式,只将对应节点的 cnt--
,而不判断删除后是否成为了空节点,当空节点的数目很多的时候才会使用拍扁重构来清除这些空节点。
void remove(int &k,int p)
{
if (!k) return;//删除的节点不存在,忽略
if (val[k]==p) cnt[k]--;
else if (val[k]<p) remove(rc[k],p);
else remove(lc[k],p);
maintain(k);
if (canRbd(k)) Rebuild(k);//判断是否应该重构
}
Upper_bound 和 Upper_greater
uprbd
函数用于找到最小的大于某个值的节点的位置,uprgr
函数用于找到最大的小于某个值的节点的位置,实现方式不讲(BST 基本操作)。
int uprbd(int k,int p)
{
if (!k) return 1;
if (val[k]==p && cnt[k]) return szv[lc[k]]+cnt[k]+1;
else if (p<val[k]) return uprbd(lc[k],p);
else return szv[lc[k]]+cnt[k]+uprbd(rc[k],p);
}
int uprgr(int k,int p)
{
if (!k) return 0;
if (val[k]==p && cnt[k]) return szv[lc[k]];
else if (p<val[k]) return uprgr(lc[k],p);
else return szv[lc[k]]+cnt[k]+uprgr(rc[k],p);
}
排名与权值相互查询
查询排名可以用上面的 uprgr
函数直接得到。
int getRank(int x) {return uprgr(root,x)+1;}
查询权值也很简单。
int getVal(int k,int p)
{
if (!k) return 0;
if (szv[lc[k]]<p && p<=szv[lc[k]]+cnt[k]) return val[k];
else if (szv[lc[k]]+cnt[k]<p) return getVal(rc[k],p-szv[lc[k]]-cnt[k]);
else return getVal(lc[k],p);
}
查询前驱后继
直接用 getVal
函数与 uprbd
和 uprgr
函数组合即可。
int getPre(int k,int p) {return getVal(k,uprgr(k,p));}
int getNxt(int k,int p) {return getVal(k,uprbd(k,p));}
总代码
因为 Dev-C++ 不好用中文写注释所以就干脆写英文了(逃
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a)
//#define int long long
using namespace std;
template<typename T> void read(T &k)
{
k=0;T flag=1;char b=getchar();
while (!isdigit(b)) {flag=(b=='-')?-1:1;b=getchar();}
while (isdigit(b)) {k=k*10+b-48;b=getchar();}
k*=flag;
}
template<typename T> void write(T k) {if (k<0) {putchar('-'),write(-k);return;}if (k>9) write(k/10);putchar(k%10+48);}
template<typename T> void writewith(T k,char c) {write(k);putchar(c);}
namespace SGT{
const int _SIZE=1e5;
const double alpha=0.7;
int tot,root,lc[_SIZE+5],rc[_SIZE+5],val[_SIZE+5];
int cnt[_SIZE+5],sz[_SIZE+5],szv[_SIZE+5],szd[_SIZE+5];
void maintain(int x)
{
sz[x]=sz[lc[x]]+sz[rc[x]]+1;//size of tree, each node count as 1
szv[x]=szv[lc[x]]+szv[rc[x]]+cnt[x];//each node count as its cnt
szd[x]=szd[lc[x]]+szd[rc[x]]+(cnt[x]!=0);//each node count as 1,deleted node not counted
}
bool canRbd(int x)//can rebuild
{
return cnt[x] && (alpha*sz[x]<=(double)max(sz[lc[x]],sz[rc[x]])) || sz[x]*alpha>=(double)szd[x];
}
int ldr[_SIZE+5];
void Rbd_flatten(int &ldc,int x)//flatten function in rebuild
{//ldc->tail of ldr(temp vector for flatten)
if (!x) return;//node x not exist
Rbd_flatten(ldc,lc[x]);//into left son
if (cnt[x]) ldr[ldc++]=x;//node x isnt empty
Rbd_flatten(ldc,rc[x]);//right son
}
int Rbd_build(int l,int r)//returns the new number of node
{
int mid=l+r>>1;
if (l>=r) return 0;
lc[ldr[mid]]=Rbd_build(l,mid);//get left son node, and rebuild
rc[ldr[mid]]=Rbd_build(mid+1,r);
maintain(ldr[mid]);
return ldr[mid];
}
void Rebuild(int &x)
{
int ldc=0;
Rbd_flatten(ldc,x);//flatten into vector
x=Rbd_build(0,ldc);//rebuild vector into BST
}
void insert(int &k,int p)
{
if (!k)//p not in BST
{
k=++tot;//newnode
if (!root) root=1;//empty tree
val[k]=p,lc[k]=rc[k]=0;//new node
sz[k]=szv[k]=szd[k]=cnt[k]=1;//init
}
else
{
if (val[k]==p) cnt[k]++;//current node is p,add 1 to cnt
else if (val[k]<p) insert(rc[k],p);//p in right son
else insert(lc[k],p);// p in left son
maintain(k);//update
if (canRbd(k)) Rebuild(k);//rebuild
}
}
void remove(int &k,int p)
{
if (!k) return;//k not in BST
if (val[k]==p) cnt[k]--;//current node is p,del it
else if (val[k]<p) remove(rc[k],p);//p in right son
else remove(lc[k],p);//p in left son
maintain(k);//update
if (canRbd(k)) Rebuild(k);//rebuild
}
int uprbd(int k,int p)//same as upper_bound, finds the smallest elements greater than p
{
if (!k) return 1;//p is smaller than any one
if (val[k]==p && cnt[k]) return szv[lc[k]]+cnt[k]+1;//p equals to current node, the ans is first one in right son
else if (p<val[k]) return uprbd(lc[k],p);//p in left son
else return szv[lc[k]]+cnt[k]+uprbd(rc[k],p);//p in right son
}
int uprgr(int k,int p)//finds the largerst elements smaller than p
{
if (!k) return 0;
if (val[k]==p && cnt[k]) return szv[lc[k]];
else if (p<val[k]) return uprgr(lc[k],p);
else return szv[lc[k]]+cnt[k]+uprgr(rc[k],p);
}
int getRank(int x) {return uprgr(root,x)+1;}//rank is the rank of upper_greater(p) + 1
int getVal(int k,int p) // get value by rank
{
if (!k) return 0;
if (szv[lc[k]]<p && p<=szv[lc[k]]+cnt[k]) return val[k];
else if (szv[lc[k]]+cnt[k]<p) return getVal(rc[k],p-szv[lc[k]]-cnt[k]);
else return getVal(lc[k],p);
}
int getPre(int k,int p) {return getVal(k,uprgr(k,p));}
int getNxt(int k,int p) {return getVal(k,uprbd(k,p));}
void print(int x)
{
if (lc[x]) print(lc[x]);
if (cnt[x]) writewith(val[x],' ');
if (rc[x]) print(rc[x]);
}
} using namespace SGT;
int n;
signed main()
{
read(n);
for (int i=1;i<=n;i++)
{
int opt,x;read(opt),read(x);
if (opt==1) insert(root,x);
if (opt==2) remove(root,x);
if (opt==3) writewith(getRank(x),'\n');
if (opt==4) writewith(getVal(root,x),'\n');
if (opt==5) writewith(getPre(root,x),'\n');
if (opt==6) writewith(getNxt(root,x),'\n');
}
return 0;
}
标签:cnt,cur,int,笔记,学习,rc,平衡,root,节点
From: https://www.cnblogs.com/hanx16msgr/p/16725631.html