平衡树系列
splay
点击查看代码
#include<bits/stdc++.h>
using namespace std;
// Splay tree stored in parallel arrays; node index 0 is the null node.
// siz[x]: total multiplicity in x's subtree, cnt[x]: copies of val[x] held by x,
// fa[x]: parent, ch[x][0]/ch[x][1]: left/right child, val[x]: key,
// root: current root (0 while empty), tot: last allocated node index,
// n: number of operations read by main.
const int N=2e6+10;
int siz[N],cnt[N],tot,fa[N],ch[N][2],val[N],root,n;
// Reads one decimal integer from stdin, skipping any non-digit characters
// before it; a '-' seen while skipping makes the result negative.
// Returns 0 if input ends before any digit appears.
// `ch` is an int so EOF is actually detected (a char never compares equal
// to EOF reliably, which made truncated input spin forever); the `register`
// keyword is gone because C++17 removed it.
inline int read(){
    int ans=0;
    int ch=getchar();
    bool flag=false;
    while(ch<'0'||ch>'9'){
        if(ch==EOF)return 0;
        if(ch=='-')flag=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return flag?(~ans+1):ans;
}
// Writes x in decimal to stdout with no trailing separator.
// The magnitude is negated as unsigned so that x == INT_MIN (e.g. the
// sentinel reported as a missing predecessor) does not overflow; `register`
// is dropped because C++17 removed it.
void print(int x){
    unsigned int mag=static_cast<unsigned int>(x);
    if(x<0)putchar('-'),mag=0u-mag;
    if(mag>9)print(static_cast<int>(mag/10));
    putchar(static_cast<int>('0'+mag%10));
}
// Recomputes x's subtree size: its own multiplicity plus both children's sizes.
void update(int x){
    const int left=ch[x][0];
    const int right=ch[x][1];
    siz[x]=siz[left]+siz[right]+cnt[x];
}
// Single rotation: lifts x above its parent y (z is y's parent), keeping
// in-order key order and refreshing the sizes of the two nodes that moved.
void rotate(int x){
int y=fa[x],z=fa[y];
int k=ch[y][1]==x; // 1 if x is y's right child, 0 if left
ch[z][ch[z][1]==y]=x; // x takes y's slot under z (writes into ch[0][..] when y is the root)
fa[x]=z;
ch[y][k]=ch[x][k^1]; // x's inner child crosses over to y
fa[ch[x][k^1]]=y; // NOTE(review): writes fa[0] when that child is null; appears harmless here
ch[x][k^1]=y;
fa[y]=x;
update(y);update(x); // y is now x's child, so it must be refreshed first
}
// Splays x upward until its parent is s; s == 0 means "make x the root".
// When x and its parent lean the same way the parent is rotated first
// (zig-zig); otherwise x is rotated twice (zig-zag).
void splay(int x,int s){
    for(int p=fa[x];p!=s;p=fa[x]){
        const int g=fa[p];
        if(g!=s){
            const bool sameSide=((ch[g][1]==p)==(ch[p][1]==x));
            if(sameSide)rotate(p);
            else rotate(x);
        }
        rotate(x);
    }
    if(s==0)root=x;
}
// Searches for key x and splays the last node visited to the root: the node
// holding x if present, otherwise the node where the search path ran out.
// Does nothing on an empty tree.
void find(int x){
    if(root==0)return;
    int cur=root;
    for(;;){
        if(val[cur]==x)break;
        const int next=ch[cur][x>val[cur]];
        if(next==0)break;
        cur=next;
    }
    splay(cur,0);
}
void insert(int x){
int u=root,y=0;
while(u&&val[u]!=x){
y=u;
u=ch[u][x>val[u]];
}
if(u)cnt[u]++;
else{
u=++tot;
if(y)ch[y][x>val[y]]=u;
ch[u][0]=ch[u][1]=0;
fa[tot]=y;val[tot]=x;
siz[tot]=cnt[tot]=1;
}
splay(u,0);
}
// Returns the node index of x's predecessor (largest key strictly below x).
// After find(x) the root is x or a neighbour of it; if the root is not
// already below x, the answer is the rightmost node of its left subtree.
int lower(int x){
    find(x);
    int u=root;
    if(!(val[u]<x)){
        u=ch[u][0];
        while(ch[u][1]!=0)u=ch[u][1];
    }
    return u;
}
// Returns the node index of x's successor (smallest key strictly above x).
// After find(x) the root is x or a neighbour of it; if the root is not
// already above x, the answer is the leftmost node of its right subtree.
int upper(int x){
    find(x);
    int u=root;
    if(!(val[u]>x)){
        u=ch[u][1];
        while(ch[u][0]!=0)u=ch[u][0];
    }
    return u;
}
void del(int x){
//删除键值为x的点
int last=lower(x),nxt=upper(x);
splay(last,0);splay(nxt,last);
int d=ch[nxt][0];
if(cnt[d]>1){
cnt[d]--;
splay(d,0);
}else ch[nxt][0]=0;
}
// Returns the key of the k-th smallest stored element (1-based, counting
// duplicates and every stored element), or 0 if fewer than k are stored.
int get_kth(int k){
    int u=root;
    if(siz[u]<k)return 0;
    for(;;){
        const int leftSize=siz[ch[u][0]];
        if(k<=leftSize){
            u=ch[u][0];
            continue;
        }
        k-=leftSize;
        if(k<=cnt[u])return val[u];
        k-=cnt[u];
        u=ch[u][1];
    }
}
int main(){
n=read();
insert(INT_MAX);
insert(INT_MIN);
while(n--){
int op=read(),x=read();
switch(op){
case 1:{insert(x);break;}
case 2:{del(x);break;}
case 3:{find(x),print(siz[ch[root][0]]+(x>val[root])),puts("");break;}
case 4:{print(get_kth(x+1)),putchar('\n');break;}
case 5:{print(val[lower(x)]),putchar('\n');break;}
case 6:{print(val[upper(x)]),putchar('\n');break;}
}
}
return 0;
}
无旋Treap
点击查看代码
#include<bits/stdc++.h>
using namespace std;
// Split/merge (non-rotating) treap in parallel arrays; node index 0 is null.
// rd: random heap priority (merge keeps the smaller one nearer the root),
// ch: children, siz: subtree node count, val: key (one node per inserted copy),
// n: number of operations, root: current root, tot: last allocated index.
const int N=1e5+10;
int rd[N],ch[N][2],siz[N],val[N],n,root,tot;
// Reads one decimal integer from stdin (POSIX getchar_unlocked), skipping any
// non-digit characters before it; a '-' seen while skipping negates the value.
// Returns 0 if input ends before any digit appears.
// `ch` is an int so EOF is detected instead of looping forever; `register`
// is dropped because C++17 removed it.
inline int read(){
    int ans=0;
    int ch=getchar_unlocked();
    bool flag=false;
    while(ch<'0'||ch>'9'){
        if(ch==EOF)return 0;
        if(ch=='-')flag=true;
        ch=getchar_unlocked();
    }
    while(ch>='0'&&ch<='9')ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar_unlocked();
    return flag?(~ans+1):ans;
}
// Writes x in decimal to stdout (POSIX putchar_unlocked), no trailing separator.
// Negation is done on an unsigned copy so INT_MIN does not overflow;
// `register` is dropped because C++17 removed it.
void print(int x){
    unsigned int mag=static_cast<unsigned int>(x);
    if(x<0)putchar_unlocked('-'),mag=0u-mag;
    if(mag>9)print(static_cast<int>(mag/10));
    putchar_unlocked(static_cast<int>('0'+mag%10));
}
// Subtree size = this node plus both children's subtree sizes.
void update(int x){
    const int *kid=ch[x];
    siz[x]=1+siz[kid[0]]+siz[kid[1]];
}
int newpoint(int va){
val[++tot]=va;rd[tot]=rand();
siz[tot]=1;
return tot;
}
// Joins treaps x and y, assuming every key in x is <= every key in y.
// The root with the smaller priority stays on top; returns the new root.
int merge(int x,int y){
    if(x==0)return y;
    if(y==0)return x;
    if(rd[x]<rd[y]){
        ch[x][1]=merge(ch[x][1],y);
        update(x);
        return x;
    }
    ch[y][0]=merge(x,ch[y][0]);
    update(y);
    return y;
}
// Splits treap x by key: a receives every node with val <= va, b the rest.
void va_split(int x,int va,int &a,int &b){
    if(x==0){
        a=0;
        b=0;
        return;
    }
    if(va<val[x]){
        b=x;
        va_split(ch[x][0],va,a,ch[x][0]);
    }else{
        a=x;
        va_split(ch[x][1],va,ch[x][1],b);
    }
    update(x);
}
// Returns the key of the k-th smallest node (1-based) in treap x, or 0 when
// k is out of range. Without the null check, k <= 0 (e.g. asking for the
// largest key of an empty split) descended into node 0 and looped forever.
int get_kth(int x,int k){
    while(x){
        int leftSize=siz[ch[x][0]];
        if(k<=leftSize)x=ch[x][0];
        else if(k>leftSize+1)k-=leftSize+1,x=ch[x][1];
        else return val[x];
    }
    return 0;
}
// Inserts key va: split into (<= va) and (> va), then glue a fresh node between.
void insert(int va){
    int le,gt;
    va_split(root,va,le,gt);
    const int fresh=newpoint(va);
    root=merge(merge(le,fresh),gt);
}
// Removes one node with key va, if any: carve out the run of keys equal to va,
// drop that run's root, then merge everything back together.
void del(int va){
    int lo,eq,hi;
    va_split(root,va,lo,hi);
    va_split(lo,va-1,lo,eq);
    eq=merge(ch[eq][0],ch[eq][1]);
    root=merge(merge(lo,eq),hi);
}
// 1-based rank of va: one more than the number of keys strictly below it.
int get_rank(int va){
    int lo,hi;
    va_split(root,va-1,lo,hi);
    const int result=siz[lo]+1;
    root=merge(lo,hi);
    return result;
}
// Predecessor of va: the largest key strictly below va.
int lower(int va){
    int lo,hi;
    va_split(root,va-1,lo,hi);
    const int result=get_kth(lo,siz[lo]);
    root=merge(lo,hi);
    return result;
}
// Successor of va: the smallest key strictly above va.
int upper(int va){
    int lo,hi;
    va_split(root,va,lo,hi);
    const int result=get_kth(hi,1);
    root=merge(lo,hi);
    return result;
}
// Driver: 1 insert, 2 delete, 3 rank, 4 k-th smallest, 5 predecessor,
// 6 successor; queries print one value per line.
int main(){
    n=read();
    while(n--){
        const int op=read();
        const int x=read();
        if(op==1){insert(x);continue;}
        if(op==2){del(x);continue;}
        int out;
        if(op==3)out=get_rank(x);
        else if(op==4)out=get_kth(root,x);
        else if(op==5)out=lower(x);
        else if(op==6)out=upper(x);
        else continue;
        print(out);
        putchar_unlocked('\n');
    }
    return 0;
}
树链剖分
点击查看代码
#include<bits/stdc++.h>
using namespace std;
// Left/right child of segment-tree node x (heap layout). NOTE: these macros
// expand to whatever variable named `x` is in scope where they are used.
#define ld (x<<1)
#define rd (x<<1|1)
const int N=3e4+10; // vertex capacity; presumably n <= 3e4 — confirm against the problem
// Parses one decimal integer from stdin; a '-' among the characters skipped
// before the first digit makes the result negative.
inline int read(){
    bool neg=false;
    char c=getchar();
    for(;c<'0'||c>'9';c=getchar()){
        if(c=='-')neg=true;
    }
    int num=0;
    for(;c>='0'&&c<='9';c=getchar()){
        num=num*10+(c-'0');
    }
    return neg?-num:num;
}
// Writes x in decimal to stdout. Negation happens on an unsigned copy so that
// INT_MIN prints correctly; `register` is dropped because C++17 removed it.
void print(int x){
    unsigned int mag=static_cast<unsigned int>(x);
    if(x<0)putchar('-'),mag=0u-mag;
    if(mag>9)print(static_cast<int>(mag/10));
    putchar(static_cast<int>(mag%10)|48);
}
// Segment-tree node: covered DFS-position range [l, r], maximum m and sum of that range.
struct stu{
int l,r,m,sum;
}s[N<<2];
// Per-vertex HLD data: top (chain head), fa (parent), son (heavy child),
// siz (subtree size), w (vertex weight). Adjacency list: h (head edge),
// to (edge target), nt (next edge), tot (edge count).
int top[N],fa[N],son[N],siz[N],to[N<<1],nt[N<<1],tot,h[N],w[N];
// n vertices, m queries; dfn[v]: DFS position of v, rnk[p]: vertex at position p,
// dfntot: positions assigned so far, dep[v]: depth (root = 1, 0 = unvisited).
int n,m,rnk[N],dfn[N],dfntot,dep[N];
void push_up(int x){
s[x].m=max(s[ld].m,s[rd].m);
s[x].sum=s[ld].sum+s[rd].sum;
}
// Builds segment-tree node x over DFS positions [l, r]; each leaf is seeded
// with the weight of the vertex occupying that position.
void build(int x,int l,int r){
    s[x].l=l;
    s[x].r=r;
    if(l!=r){
        const int mid=(l+r)>>1;
        build(ld,l,mid);
        build(rd,mid+1,r);
        push_up(x);
        return;
    }
    s[x].m=w[rnk[l]];
    s[x].sum=s[x].m;
}
int query(int x,int l,int r){
if(l<=s[x].l&&s[x].r<=r)return s[x].sum;
int ans=0,mid=s[x].l+s[x].r>>1;
if(l<=mid)ans+=query(ld,l,r);
if(r>mid)ans+=query(rd,l,r);
return ans;
}
// Sum of vertex weights on the tree path x..y (inclusive): repeatedly jump the
// endpoint whose chain head is deeper, then finish within the shared chain.
int query_x(int x,int y){
    int total=0;
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        total+=query(1,dfn[top[x]],dfn[x]);
    }
    const int upperEnd=(dep[x]>dep[y])?y:x;
    const int lowerEnd=(dep[x]>dep[y])?x:y;
    return total+query(1,dfn[upperEnd],dfn[lowerEnd]);
}
// Maximum weight over positions [l, r] restricted to node x's range.
int get_max(int x,int l,int r){
    if(s[x].l>=l&&s[x].r<=r)return s[x].m;
    const int mid=(s[x].l+s[x].r)>>1;
    int best=INT_MIN;
    if(l<=mid){
        const int sub=get_max(ld,l,r);
        if(sub>best)best=sub;
    }
    if(r>mid){
        const int sub=get_max(rd,l,r);
        if(sub>best)best=sub;
    }
    return best;
}
// Maximum vertex weight on the tree path x..y (inclusive), using the same
// chain-jumping walk as the path-sum query.
int get_max_x(int x,int y){
    int best=INT_MIN;
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        best=max(best,get_max(1,dfn[top[x]],dfn[x]));
    }
    const int shallow=(dep[x]>dep[y])?y:x;
    const int deep=(dep[x]>dep[y])?x:y;
    return max(best,get_max(1,dfn[shallow],dfn[deep]));
}
// Point assignment: sets position pos to val and refreshes every ancestor.
void change(int x,int pos,int val){
    if(s[x].l!=s[x].r){
        const int mid=(s[x].l+s[x].r)>>1;
        if(pos>mid)change(rd,pos,val);
        else change(ld,pos,val);
        push_up(x);
        return;
    }
    s[x].m=val;
    s[x].sum=val;
}
// Prepends directed edge x -> y to x's adjacency list.
void add(int x,int y){
    ++tot;
    to[tot]=y;
    nt[tot]=h[x];
    h[x]=tot;
}
// First HLD pass: subtree sizes, parents, depths and heavy children.
// A vertex counts as unvisited while dep is 0, so the caller must set the
// root's depth to a nonzero value first.
void dfs1(int x){
    siz[x]=1;
    for(int e=h[x];e!=0;e=nt[e]){
        const int v=to[e];
        if(dep[v])continue;
        fa[v]=x;
        dep[v]=dep[x]+1;
        dfs1(v);
        siz[x]+=siz[v];
        if(siz[v]>siz[son[x]])son[x]=v;
    }
}
// Second HLD pass: assigns DFS positions (heavy child first, so each chain is
// contiguous) and records each vertex's chain head t.
void dfs2(int x,int t){
    top[x]=t;
    dfn[x]=++dfntot;
    rnk[dfntot]=x;
    if(son[x]==0)return;
    dfs2(son[x],t);
    for(int e=h[x];e;e=nt[e]){
        const int v=to[e];
        if(v==fa[x]||v==son[x])continue;
        dfs2(v,v);
    }
}
int main(){
n=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
add(x,y);add(y,x);
}
for(int i=1;i<=n;i++)w[i]=read();
dep[1]=1;dfs1(1);dfs2(1,1);
build(1,1,n);
m=read();
while(m--){
string s;cin>>s;
int x=read(),y=read();
if(s[0]=='Q'){
if(s[1]=='M')print(get_max_x(x,y)),putchar('\n');
else print(query_x(x,y)),putchar('\n');
}
else change(1,dfn[x],y);
}
return 0;
}
主席树
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
// Persistent segment tree over compressed value ranks 1..len.
// sum[v]: how many inserted values fall in node v's range; ld/rd: child node
// indices; tot: nodes allocated; n: array length; m: number of queries.
int sum[N<<5],ld[N<<5],rd[N<<5],tot,n,m;
// a: input array, b: sorted distinct copy of a (len entries),
// rt[i]: root of the version containing a[1..i] (only n+1 entries are used).
int a[N],b[N],len,rt[N<<5];
// Parses one decimal integer from stdin; any '-' skipped before the first
// digit negates the result.
inline int read(){
    char c;
    bool negative=false;
    while(true){
        c=getchar();
        if(c>='0'&&c<='9')break;
        if(c=='-')negative=true;
    }
    int value=0;
    do{
        value=value*10+(c-'0');
        c=getchar();
    }while(c>='0'&&c<='9');
    return negative?-value:value;
}
// Allocates an all-zero skeleton tree over ranks [l, r]; x receives its root.
void build(int &x,int l,int r){
    x=++tot;
    if(l!=r){
        const int mid=(l+r)>>1;
        build(ld[x],l,mid);
        build(rd[x],mid+1,r);
    }
}
// Returns the root of a new version equal to version x plus one occurrence of
// rank pos; only the nodes on the root-to-pos path are copied.
int update(int x,int l,int r,int pos){
    const int cur=++tot;
    sum[cur]=sum[x]+1;
    ld[cur]=ld[x];
    rd[cur]=rd[x];
    if(l==r)return cur;
    const int mid=(l+r)>>1;
    if(pos>mid)rd[cur]=update(rd[x],mid+1,r,pos);
    else ld[cur]=update(ld[x],l,mid,pos);
    return cur;
}
// Returns the rank of the k-th smallest value present in version y but not in
// version x (i.e. the values inserted after x up to y), walking both trees in step.
int query(int x,int y,int l,int r,int k){
    while(l!=r){
        const int mid=(l+r)>>1;
        const int leftCount=sum[ld[y]]-sum[ld[x]];
        if(k<=leftCount){
            x=ld[x];
            y=ld[y];
            r=mid;
        }else{
            k-=leftCount;
            x=rd[x];
            y=rd[y];
            l=mid+1;
        }
    }
    return l;
}
// Static range k-th smallest: compress values, build one version per prefix,
// then answer each [l, r] query from versions l-1 and r.
int main(){
    n=read();
    m=read();
    for(int i=1;i<=n;i++){
        a[i]=read();
        b[i]=a[i];
    }
    sort(b+1,b+n+1);
    len=int(unique(b+1,b+n+1)-(b+1));
    build(rt[0],1,len);
    for(int i=1;i<=n;i++){
        const int pos=int(lower_bound(b+1,b+len+1,a[i])-b);
        rt[i]=update(rt[i-1],1,len,pos);
    }
    while(m--){
        const int l=read();
        const int r=read();
        const int k=read();
        printf("%d\n",b[query(rt[l-1],rt[r],1,len,k)]);
    }
    return 0;
}
Tarjan
求强连通分量
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
// n vertices; dfn/low: Tarjan timestamps and low-links, cnt: timestamp counter;
// to/nt/h/tot: adjacency list (one edge per vertex is added in main).
int n,dfn[N],low[N],to[N],nt[N],tot,h[N],cnt;
// t: number of SCCs found, siz[k]: size of the k-th SCC, vis[v]: v is on stack q.
int t,siz[N];bool vis[N];
stack<int>q; // Tarjan's stack of vertices whose SCC is not finished yet
// Prepends directed edge x -> y to x's adjacency list.
void add(int x,int y){
    const int id=++tot;
    to[id]=y;
    nt[id]=h[x];
    h[x]=id;
}
// Tarjan's SCC algorithm rooted at x. Each finished SCC is popped off q and
// its size is recorded in siz[++t].
void tarjan(int x){
    dfn[x]=low[x]=++cnt;
    vis[x]=1;q.push(x);
    for(int i=h[x];i;i=nt[i]){
        int y=to[i];
        if(!dfn[y]){
            tarjan(y);
            low[x]=min(low[x],low[y]);
        }else if(vis[y]){
            // Only a y still on the stack belongs to an unfinished SCC. The
            // check must be on y: vis[x] is always true here, which let edges
            // into already-completed SCCs wrongly lower low[x].
            low[x]=min(low[x],dfn[y]);
        }
    }
    if(low[x]==dfn[x]){
        t++;int y;
        do{
            y=q.top();
            q.pop();
            vis[y]=0;
            siz[t]++;
        }while(x^y);
    }
}
// Each vertex i has exactly one outgoing edge (to the value read for it).
// Prints the size of the smallest SCC with more than one vertex.
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        int target;
        scanf("%d",&target);
        add(i,target);
    }
    for(int i=1;i<=n;i++){
        if(!dfn[i])tarjan(i);
    }
    int best=INT_MAX;
    for(int k=1;k<=t;k++){
        if(siz[k]!=1&&siz[k]<best)best=siz[k];
    }
    printf("%d",best);
    return 0;
}
求割点
#include<bits/stdc++.h>
using namespace std;
const int N=150;
// Edges are read until EOF with no stated limit, so the adjacency list is
// sized for a simple graph on up to N vertices (each undirected edge is
// stored twice); N<<1 only fit about 149 edges.
int dfn[N],to[N*N],nt[N*N],h[N],tot;
// low: Tarjan low-link, cut[v]: v is an articulation point, cnt: timestamp
// counter, root: root of the DFS tree currently being explored.
int low[N],cut[N],n,m,cnt,root;
// Prepends directed edge x -> y to x's adjacency list.
void add(int x,int y){
    tot+=1;
    to[tot]=y;
    nt[tot]=h[x];
    h[x]=tot;
}
// Articulation-point search from x. A non-root x is a cut vertex when some DFS
// child y cannot reach above x (low[y] >= dfn[x]); the DFS root instead needs
// at least two DFS children.
void tarjan(int x){
    cnt++;
    dfn[x]=cnt;
    low[x]=cnt;
    int children=0;
    for(int e=h[x];e;e=nt[e]){
        const int y=to[e];
        if(dfn[y]){
            low[x]=min(low[x],dfn[y]);
            continue;
        }
        children++;
        tarjan(y);
        low[x]=min(low[x],low[y]);
        if(low[y]>=dfn[x]&&(x!=root||children>1))cut[x]=1;
    }
}
// Reads n and an undirected edge list (until input ends), then prints the
// number of articulation points followed by each one, in increasing order.
int main(){
    scanf("%d",&n);
    int x,y;
    // "== 2" also stops on malformed input, where "!= EOF" would spin forever
    // because scanf keeps returning 0 without consuming the bad token.
    while(scanf("%d%d",&x,&y)==2){
        add(x,y);add(y,x);
    }
    // Articulation points do not depend on which vertex roots the DFS tree,
    // so one DFS per connected component suffices (previously a full DFS was
    // rerun from every vertex).
    for(int i=1;i<=n;i++){
        if(!dfn[i]){
            root=i;
            tarjan(i);
        }
    }
    int total=0;
    for(int i=1;i<=n;i++)total+=cut[i];
    printf("%d\n",total);
    for(int i=1;i<=n;i++){
        if(cut[i])printf("%d\n",i);
    }
    return 0;
}
高斯消元
慢一点，但是可以区分无解、无穷多解和唯一解三种情况。
高斯消元
#include<bits/stdc++.h>
using namespace std;
const int N=150;
// a: 1-indexed augmented matrix (column m+1 holds the right-hand side),
// eps: magnitude below which a value is treated as zero,
// ans: solution vector written by solve() when the solution is unique.
double a[N][N],eps=1e-6,ans[N];
int n; // number of equations / unknowns read by main
// Gaussian elimination with partial pivoting on the n x m system stored in
// a[1..n][1..m+1], where column m+1 is the right-hand side. Modifies a.
// Returns -1 if the system is inconsistent, the number of free variables
// (> 0) if it has infinitely many solutions, or 0 if the solution is unique,
// in which case it is written to ans[1..m].
int solve(int n,int m){
    int c=1,r=1;
    for(;r<=n&&c<=m;r++,c++){
        // Partial pivoting: use the row with the largest |a[i][c]|.
        int mr=r;
        for(int i=r+1;i<=n;i++){
            if(fabs(a[i][c])>fabs(a[mr][c]))mr=i;
        }
        if(mr!=r)swap(a[r],a[mr]);
        if(fabs(a[r][c])<eps){
            // No pivot in this column: keep the row, move on to the next column.
            r--;
            continue;
        }
        for(int i=r+1;i<=n;i++){
            if(fabs(a[i][c])>eps){
                double k=a[i][c]/a[r][c];
                for(int j=c;j<=m+1;j++)a[i][j]-=a[r][j]*k;
                a[i][c]=0;
            }
        }
    }
    // Rows r..n now have (numerically) zero coefficients, so a nonzero
    // right-hand side there reads 0 = nonzero: no solution. Check the
    // constants column m+1 over every remaining row up to n; the previous
    // check (column c, rows r..m) only coincided with this when n == m.
    for(int i=r;i<=n;i++){
        if(fabs(a[i][m+1])>eps)return -1;
    }
    // rank = r-1 < m: the remaining m-r+1 variables are free.
    if(r<=m)return m-r+1;
    // Full rank and no skipped columns: pivots sit on the diagonal.
    for(int i=m;i;i--){
        for(int j=i+1;j<=m;j++)a[i][m+1]-=a[i][j]*ans[j];
        ans[i]=a[i][m+1]/a[i][i];
    }
    return 0;
}
// Reads an n x (n+1) augmented matrix and prints "-1" (no solution),
// "0" (infinitely many) or each x_i to two decimals (unique solution).
int main(){
    scanf("%d",&n);
    for(int row=1;row<=n;row++){
        for(int col=1;col<=n+1;col++)scanf("%lf",&a[row][col]);
    }
    const int status=solve(n,n);
    if(status==-1){
        printf("-1");
        return 0;
    }
    if(status!=0){
        printf("0");
        return 0;
    }
    // NOTE: tiny negative values print as "-0.00"; an earlier variant
    // special-cased |ans[i]| < eps for that reason.
    for(int row=1;row<=n;row++)printf("x%d=%.2lf\n",row,ans[row]);
    return 0;
}
快一点，但是只能判断是否有唯一解。
upd：实测并没有快多少。