思想
随机二分即随机在当前的二分区间内找出一个元素作为 \(mid\),并和普通二分一样收缩左右端点。
由于每次合法区间长度期望折半,于是复杂度仍然正确,\(O(\log n)\) 次收缩即可使区间中只有一个元素。
在元素容易比较,容易求排名,而难以根据排名求元素时可以考虑随机二分。
简单应用
数组找第 k 大
最简单的应用是从无序数组中找出第 \(k\) 大(即 nth_element
,不过 std::nth_element
用的好像是 Introselect)。
显然的想法是排序,但是随机二分可以做到更优的期望复杂度:
- 每次随机一个元素,并遍历当前数组,按照比当前元素 小/大/相等 分类,小的放左边,大的放右边。
- 此时求出了这个元素的排名,若大于 \(k\),则取出小的部分,递归进行,小于 \(k\) 同理。
或者说:快速排序,但是只对可能有答案的那一半递归。
每次随机需要遍历剩余的所有元素,由于每次随机后元素个数减半,期望时间复杂度可以写成 \(T(n) = T(\frac{n}{2}) + O(n) = O(n)\)。
连续子序列第 k 大
定义一个数组 \(A\) 小于数组 \(B\) 当且仅当对于两个可重集中出现次数不同的最小元素 \(x\),元素 \(x\) 在 \(A\) 中出现次数更多,给定一个长为 \(n\) 的序列 \(\{S\}\),求 \(\{S\}\) 的连续子序列第 \(k\) 大。
首先可以转化为 \(n\) 进制数的大小比较,此时值域来到了 \(n^{n}\),直接对值域二分就相当于暴力。不过注意到 \(k \le \dfrac{n(n-1)}{2}\),对 \(k\) 二分是可以接受的。原问题就是求第 \(k\) 大,于是不能直接按照排名去找 \(mid\)。不过注意到如果已知一个数组,它在所有连续子序列中的排名是好求的(固定右端点,左端点越大排名越大,双指针可以 \(O(n \log n)\) 求排名),同时也容易把其它连续子序列按 大于/等于/小于 分成三类,于是利用随机二分,每次在当前区间内随机选择一个连续子序列作为 \(mid\),求出排名检查并将其余可行连续子序列分类,选择一边继续二分即可。时间复杂度期望 \(O(n \log^2 n)\),常数很大。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
constexpr int MAXN=1e5+10;
int n,a[MAXN],cnt[MAXN],liml[MAXN],limr[MAXN],pos[MAXN];
bool iseq[MAXN];
ll k;
auto calcRank(int l,int r,mt19937_64 &rng) {
set<int> st;
memset(cnt,0,sizeof(int)*(n+1));
memset(pos,0,sizeof(int)*(n+1));
memset(iseq,0,sizeof(bool)*(n+1));
for(int i=l;i<=r;++i) --cnt[a[i]],st.insert(a[i]);
ll ls=0,eq=0,grt=0,sum=0;
pair<int,int> rl,rg;
for(int i=1,j=1;i<=n;++i) {
if(!cnt[a[i]]) st.insert(a[i]);
++cnt[a[i]];
if(!cnt[a[i]]) st.erase(a[i]);
while(j<=i&&!st.empty()&&cnt[*st.begin()]>0) {
if(!cnt[a[j]]) st.insert(a[j]);
--cnt[a[j]];
if(!cnt[a[j]]) st.erase(a[j]);
++j;
}
sum+=limr[i]-liml[i]+1;
if(j>liml[i]) {
int rnd=uniform_int_distribution<ll>(1,ls+j-liml[i])(rng);
if(rnd>ls) rl={liml[i]+rnd-ls-1,i};
}
if(sum-ls-eq>grt) {
int rnd=uniform_int_distribution<ll>(1,sum-ls-eq)(rng);
if(rnd>grt) rg={limr[i]-(rnd-grt)+1,i};
}
pos[i]=j; ls+=max(0,j-liml[i]);
if(st.empty()) ++eq,iseq[i]=true;
grt=sum-ls-eq;
}
return make_tuple(ls,eq,rl,rg);
}
auto kth(ll k) {
static mt19937_64 rng(random_device{}());
fill(liml+1,liml+n+1,1); iota(limr+1,limr+n+1,1);
for(int l=uniform_int_distribution<int>(1,n)(rng),r=uniform_int_distribution<int>(1,n)(rng);;) {
if(l>r) swap(l,r);
auto [ls,eq,rl,rg]=calcRank(l,r,rng);
if(ls<k&&k<=ls+eq) {
vector<int> res;
for(int i=l;i<=r;++i) res.emplace_back(a[i]);
sort(res.begin(),res.end());
return res;
}
if(k<=ls) {
tie(l,r)=rl;
for(int i=1;i<=n;++i) limr[i]=min(limr[i],pos[i]-1);
} else {
tie(l,r)=rg; k-=ls;
for(int i=1;i<=n;++i) liml[i]=max(liml[i],pos[i]);
}
}
abort();
}
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
cin>>n>>k;
for(int i=1;i<=n;++i) cin>>a[i];
auto ans=kth(k);
for(auto v:ans) cout<<v<<" ";
cout<<endl;
return 0;
}
不过这题好像有不依赖随机化的做法(利用 Segment Tree Beats),而且跑得很快。但是我不会
例题
[PA2019] Podatki drogowe
仍然是 \(n\) 进制比大小,考虑类似上面的随机二分。这题在树上,权值是路径权值和,于是考虑先做一次点分治,用可持久化线段树维护每个分治中心到这一层所有点的路径上的权值(以及后缀哈希,比较要用),于是所有路径可以通过合并同一分治中心且不在同一方向上的两个线段树得到,合并是简单的权值相加。先把每个分治中心的所有线段树排个序,之后容易通过双指针在 \(O(n \log^2 n)\) 的时间内求出任意 \(n\) 进制数的排名,小改一下容易加上左右端点的限制求排名。套上随机二分,时间复杂度为 \(O(n \log^3 n)\),有点难写。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
using ull=unsigned long long;
constexpr int MAXN=2.5e4+10,inf=0x3f3f3f3f,mod=1e9+7;
struct Hasher {
static constexpr ull b=1e9+7;
ull pw[MAXN];
Hasher(int n=MAXN-1) {
pw[0]=1;
for(int i=1;i<=n;++i) pw[i]=pw[i-1]*b;
}
ull append(ull lval,int v) {return lval*b+v;}
ull merge(ull lval,ull rval,int rlen) {
return lval*pw[rlen]+rval;
}
}hs;
struct Node {
int l,r,val; ull hash;
Node() {l=r=val=0; hash=0ull;}
};
struct PersistentSegTree {
static_assert(__cplusplus>=201703L);
vector<Node> t; int tot;
PersistentSegTree() {tot=0; t.clear(); t.emplace_back();}
int create() {t.emplace_back(); return t.size()-1;}
int clone(int k) {t.emplace_back(t[k]); return t.size()-1;}
void pushup(int k,int l,int r,int mid) {
int ls=t[k].l,rs=t[k].r;
t[k].hash=hs.merge(t[ls].hash,t[rs].hash,r-mid);
}
void modify(int k,int l,int r,int x,int v) {
if(l==r) {t[k].val+=v; t[k].hash=t[k].val; return;}
int mid=(l+r)>>1;
if(x<=mid) modify(t[k].l=clone(t[k].l),l,mid,x,v);
else modify(t[k].r=clone(t[k].r),mid+1,r,x,v);
pushup(k,l,r,mid);
}
int cmp(int k1,int k2,int k3,int l,int r) {
if(t[k1].hash+t[k2].hash==t[k3].hash) return 0;
if(l==r) {
if(t[k1].val+t[k2].val>t[k3].val) return 1;
else if(t[k1].val+t[k2].val==t[k3].val) return 0;
else return -1;
}
int mid=(l+r)>>1,r1=t[k1].r,r2=t[k2].r,r3=t[k3].r;
if(t[r1].hash+t[r2].hash==t[r3].hash) return cmp(t[k1].l,t[k2].l,t[k3].l,l,mid);
return cmp(t[k1].r,t[k2].r,t[k3].r,mid+1,r);
}
void queryAll(int k,int l,int r,vector<int> &res) {
if(l==r) {res.emplace_back(t[k].val); return;}
int mid=(l+r)>>1;
queryAll(t[k].l,l,mid,res); queryAll(t[k].r,mid+1,r,res);
}
}t;
vector<pair<int,int>> g[MAXN],rt[MAXN];
vector<int> posj[MAXN],posk[MAXN],posl[MAXN],posr[MAXN],cj[MAXN],ck[MAXN],cl[MAXN],cr[MAXN];
int siz[MAXN],minv,drt,n,maxp;
bool vis[MAXN];
ll k;
void getRoot(int u,int fa,int tot) {
siz[u]=1; int cur=0;
for(auto [v,w]:g[u]) {
if(v==fa||vis[v]) continue;
getRoot(v,u,tot); siz[u]+=siz[v]; cur=max(cur,siz[v]);
}
cur=max(cur,tot-siz[u]);
if(cur<minv) minv=cur,drt=u;
}
void addP(int u,int fa,int lstr,int lstw,int drt,int dir) {
int p=t.clone(lstr);
t.modify(p,1,n,lstw,1); rt[drt].emplace_back(p,dir);
for(auto [v,w]:g[u]) if(v!=fa&&!vis[v]) addP(v,u,p,w,drt,dir);
}
void divide(int u) {
vis[u]=true;
for(auto [v,w]:g[u]) if(!vis[v]) addP(v,u,0,w,u,v);
for(auto [v,w]:g[u]) {
if(vis[v]) continue;
minv=inf; drt=0; getRoot(v,u,siz[v]); getRoot(drt,u,siz[v]); divide(drt);
}
}
void initD() {
minv=inf; getRoot(1,0,n); getRoot(drt,0,n);
divide(drt);
for(int u=1;u<=n;++u) {
sort(rt[u].begin(),rt[u].end(),[&](pair<int,int> x,pair<int,int> y) {
int res=t.cmp(x.first,0,y.first,1,n);
if(res<0) return true;
if(res>0) return false;
return x.second<y.second;
});
}
}
int mergeRoot(int p1,int p2) {
vector<int> v1,v2; t.queryAll(p1,1,n,v1); t.queryAll(p2,1,n,v2);
int res=t.create();
for(int i=0;i<int(v1.size());++i) if(v1[i]+v2[i]) t.modify(res,1,n,i+1,v1[i]+v2[i]);
return res;
}
auto getRank(int p,int pl,int pr) {
static int cntj[MAXN],cntk[MAXN],cntl[MAXN],cntr[MAXN];
ll ls=0,eq=0,grt=0;
for(int u=1;u<=n;++u) {
posj[u].assign(rt[u].size(),n); posk[u].assign(rt[u].size(),n); posl[u].assign(rt[u].size(),n); posr[u].assign(rt[u].size(),n);
cj[u].assign(rt[u].size(),n); ck[u].assign(rt[u].size(),n); cl[u].assign(rt[u].size(),n); cr[u].assign(rt[u].size(),n);
for(int i=int(rt[u].size())-1,j=0,k=0,l=0,r=0;i>=0;--i) {
while(j>i) --j,--cntj[rt[u][j].second];
while(k>i) --k,--cntk[rt[u][k].second];
while(l>i) --l,--cntl[rt[u][l].second];
while(r>i) --r,--cntr[rt[u][r].second];
while(r<i&&t.cmp(rt[u][i].first,rt[u][r].first,pr,1,n)<0) ++cntr[rt[u][r].second],++r;
while(l<r&&t.cmp(rt[u][i].first,rt[u][l].first,pl,1,n)<=0) ++cntl[rt[u][l].second],++l;
while(j<r&&t.cmp(rt[u][i].first,rt[u][j].first,p,1,n)<0) ++cntj[rt[u][j].second],++j;
while(k<r&&t.cmp(rt[u][i].first,rt[u][k].first,p,1,n)<=0) ++cntk[rt[u][k].second],++k;
posj[u][i]=j; posk[u][i]=k; posl[u][i]=l; posr[u][i]=r;
cj[u][i]=cntj[rt[u][i].second]; ck[u][i]=cntk[rt[u][i].second];
cl[u][i]=cntl[rt[u][i].second]; cr[u][i]=cntr[rt[u][i].second];
ls+=max(0,(j-cntj[rt[u][i].second])-(l-cntl[rt[u][i].second]));
eq+=max(0,(k-cntk[rt[u][i].second])-(j-cntj[rt[u][i].second]));
grt+=max(0,(r-cntr[rt[u][i].second])-(k-cntk[rt[u][i].second]));
}
for(auto [r,d]:rt[u]) {
int val=t.cmp(r,0,p,1,n);
if(t.cmp(r,0,pl,1,n)<=0||t.cmp(r,0,pr,1,n)>=0) continue;
if(val<0) ++ls;
else if(val==0) ++eq;
else ++grt;
}
for(auto [r,d]:rt[u]) cntj[d]=cntk[d]=cntl[d]=cntr[d]=0;
}
return make_tuple(ls,eq,grt);
}
int kth(ll k) {
static mt19937 rng(random_device{}());
int tot=0,p=-1;
for(int u=0;u<=n;++u) tot+=rt[u].size();
tot=uniform_int_distribution<int>(0,tot-1)(rng);
for(int u=0;u<=n;++u) {
if(tot<int(rt[u].size())) {p=rt[u][tot].first; break;}
tot-=rt[u].size();
}
int pl=0,pr=maxp;
while(true) {
auto [ls,eq,grt]=getRank(p,pl,pr);
if(ls<k&&k<=ls+eq) {
vector<int> res; t.queryAll(p,1,n,res);
int pw=n,ans=0;
for(auto v:res) ans=(ans+1ll*pw*v%mod)%mod,pw=1ll*pw*n%mod;
return ans;
}
if(k<=ls) {
int nxt=uniform_int_distribution<int>(1,ls)(rng),nxtp=-1;
for(int u=1;u<=n;++u) {
for(int i=int(rt[u].size())-1;i>=0;--i) {
int cur=max(0,(posj[u][i]-cj[u][i])-(posl[u][i]-cl[u][i]));
if(nxt<=cur) {
vector<int> val;
for(int j=0;j<i;++j) {
if(t.cmp(rt[u][i].first,rt[u][j].first,pl,1,n)<=0) continue;
if(t.cmp(rt[u][i].first,rt[u][j].first,pr,1,n)>=0) continue;
if(t.cmp(rt[u][i].first,rt[u][j].first,p,1,n)<0) val.emplace_back(rt[u][j].first);
}
nxtp=mergeRoot(rt[u][i].first,val[nxt-1]);
break;
}
nxt-=cur;
}
if(nxtp!=-1) break;
for(auto [r,d]:rt[u]) {
int val=t.cmp(r,0,p,1,n);
if(t.cmp(r,0,pl,1,n)<=0||t.cmp(r,0,pr,1,n)>=0) continue;
if(val<0) {
if(!--nxt) {nxtp=r; break;}
}
}
if(nxtp!=-1) break;
}
pr=p; p=nxtp;
} else {
int nxt=uniform_int_distribution<int>(1,grt)(rng),nxtp=-1;
for(int u=1;u<=n;++u) {
for(int i=int(rt[u].size())-1;i>=0;--i) {
int cur=max(0,(posr[u][i]-cr[u][i])-(posk[u][i]-ck[u][i]));
if(nxt<=cur) {
vector<int> val;
for(int j=0;j<i;++j) {
if(rt[u][j].second==rt[u][i].second) continue;
if(t.cmp(rt[u][i].first,rt[u][j].first,pl,1,n)<=0) continue;
if(t.cmp(rt[u][i].first,rt[u][j].first,pr,1,n)>=0) continue;
if(t.cmp(rt[u][i].first,rt[u][j].first,p,1,n)>0) val.emplace_back(rt[u][j].first);
}
nxtp=mergeRoot(rt[u][i].first,val[nxt-1]);
break;
}
nxt-=cur;
}
if(nxtp!=-1) break;
for(auto [r,d]:rt[u]) {
int val=t.cmp(r,0,p,1,n);
if(t.cmp(r,0,pl,1,n)<=0||t.cmp(r,0,pr,1,n)>=0) continue;
if(val>0) {
if(!--nxt) {nxtp=r; break;}
}
}
if(nxtp!=-1) break;
}
pl=p; p=nxtp; k-=ls+eq;
}
}
abort();
return -1;
}
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
cin>>n>>k;
for(int i=1;i<n;++i) {
int u,v,w; cin>>u>>v>>w; g[u].emplace_back(v,w); g[v].emplace_back(u,w);
}
initD(); maxp=t.create(); t.modify(maxp,1,n,n,n+1);
cout<<kth(k)<<endl;
return 0;
}