/*
在splay中
0不能算作是根节点,只能说是一个标记点
如果谁的父亲是0,那么谁就是根节点
*/
#include <bits/stdc++.h>
using namespace std;
const int M=1e5+5;
const int inf=1e9;
#define t tr
#define size siz
int cnt=0,root=0;
struct splay {
int ch[2],siz,cnt,val,fa;
}tr[M];
int get(int x) {
return tr[tr[x].fa].ch[1]==x;
}
void up(int x) {
tr[x].siz=tr[tr[x].ch[0]].siz+tr[tr[x].ch[1]].siz+tr[x].cnt;
}
void rotate(int x) {
int y=tr[x].fa,z=tr[y].fa;
int d1=get(x),d2=get(y);
int son=tr[x].ch[d1^1];
tr[y].ch[d1]=son;tr[son].fa=y;
tr[z].ch[d2]=x;tr[x].fa=z;
tr[x].ch[d1^1]=y;tr[y].fa=x;
up(y);up(x);
}
void splay(int x,int goal) {
while(tr[x].fa!=goal) {
int y=tr[x].fa,z=tr[y].fa;
int d1=get(x),d2=get(y);
if(z!=goal) {
if(d1==d2)rotate(y);
else rotate(x);
}
rotate(x);
}
if(goal==0)root=x;
}
int find(int val) {
int node=root;
while(tr[node].val!=val&&tr[node].ch[tr[node].val<val])node=tr[node].ch[tr[node].val<val];
return node;
}
void insert(int val) {
int node=root,fa=0;
while(tr[node].val!=val&&node)
fa=node,node=tr[node].ch[tr[node].val<val];
if(node)tr[node].cnt++;
else {
node=++cnt;
if(fa)tr[fa].ch[tr[fa].val<val]=node;
tr[node].siz=tr[node].cnt=1;
tr[node].fa=fa,tr[node].val=val;
}
splay(node,0);
}
int pre(int val,int k) {
splay(find(val),0);
int node=root;
if(k==0&&tr[node].val<val)return node;
if(k==1&&tr[node].val>val)return node;
node = tr[node].ch[k];
while(tr[node].ch[k^1])node=tr[node].ch[k^1];
return node;
}
void del(int val){
int last = pre(val,0), next = pre(val,1);
splay(last , 0); splay(next , last);
if(t[t[next].ch[0]].cnt > 1){
t[t[next].ch[0]].cnt--;
splay(t[next].ch[0] , 0);
}
else t[next].ch[0] = 0;
}
int kth(int k){
int node = root;
if(t[node].size < k) return inf;
while(1){
int son = t[node].ch[0];
if(k <= t[son].size) node = son;
else if(k > t[son].size+t[node].cnt){
k -= t[son].size+t[node].cnt;
node = t[node].ch[1];
}
else return t[node].val;
}
}
int get_rank(int val){
splay(find(val) , 0);
return t[t[root].ch[0]].size;
}
int main() {
insert(-inf);insert(inf);
int q;cin>>q;
while(q--) {
int op,x;
cin>>op>>x;
if(op==1)insert(x);
if(op==2)del(x);
if(op==3)cout<<get_rank(x)<<endl;
if(op==4)cout<<kth(x+1)<<endl;
if(op==5)cout<<tr[pre(x,0)].val<<endl;
if(op==6)cout<<tr[pre(x,1)].val<<endl;
}
return 0;
}
标签:node,ch,val,int,tr,splay,fa,平衡,模板
From: https://www.cnblogs.com/basicecho/p/17320254.html