首页 > 其他分享 >随机二分

随机二分

时间:2024-01-31 10:35:13浏览次数:20  
标签:二分 return val int -- MAXN 随机 ls

思想

随机二分即随机在当前的二分区间内找出一个元素作为 \(mid\),并和普通二分一样收缩左右端点。

由于每次合法区间长度期望折半,于是复杂度仍然正确,\(O(\log n)\) 次收缩即可使区间中只有一个元素。

在元素容易比较,容易求排名,而难以根据排名求元素时可以考虑随机二分。

简单应用

数组找第 k 大

最简单的应用是从无序数组中找出第 \(k\) 大(即 nth_element,不过 std::nth_element 用的好像是 Introselect)。

显然的想法是排序,但是随机二分可以做到更优的期望复杂度:

  • 每次随机一个元素,并遍历当前数组,按照比当前元素 小/大/相等 分类,小的放左边,大的放右边。
  • 此时求出了这个元素的排名,若大于 \(k\),则取出小的部分,递归进行,小于 \(k\) 同理。

或者说:快速排序,但是只对可能有答案的那一半递归。

每次随机需要遍历剩余的所有元素,由于每次随机后元素个数减半,期望时间复杂度可以写成 \(T(n) = T(\frac{n}{2}) + O(n) = O(n)\)。

连续子序列第 k 大

定义一个数组 \(A\) 小于数组 \(B\) 当且仅当对于两个可重集中出现次数不同的最小元素 \(x\),元素 \(x\) 在 \(A\) 中出现次数更多,给定一个长为 \(n\) 的序列 \(\{S\}\),求 \(\{S\}\) 的连续子序列第 \(k\) 大。

首先可以转化为 \(n\) 进制数的大小比较,此时值域来到了 \(n^{n}\),直接对值域二分就相当于暴力。不过注意到 \(k \le \dfrac{n(n-1)}{2}\),对 \(k\) 二分是可以接受的。原问题就是求第 \(k\) 大,于是不能直接按照排名去找 \(mid\)。不过注意到如果已知一个数组,它在所有连续子序列中的排名是好求的(固定右端点,左端点越大排名越大,双指针可以 \(O(n \log n)\) 求排名),同时也容易把其它连续子序列按 大于/等于/小于 分成三类,于是利用随机二分,每次在当前区间内随机选择一个连续子序列作为 \(mid\),求出排名检查并将其余可行连续子序列分类,选择一边继续二分即可。时间复杂度期望 \(O(n \log^2 n)\),常数很大。

点击查看代码
#include <bits/stdc++.h>

using namespace std;
using ll=long long;

constexpr int MAXN=1e5+10;

int n,a[MAXN],cnt[MAXN],liml[MAXN],limr[MAXN],pos[MAXN];
bool iseq[MAXN];
ll k;

auto calcRank(int l,int r,mt19937_64 &rng) {
    set<int> st;
    memset(cnt,0,sizeof(int)*(n+1));
    memset(pos,0,sizeof(int)*(n+1));
    memset(iseq,0,sizeof(bool)*(n+1));
    for(int i=l;i<=r;++i) --cnt[a[i]],st.insert(a[i]);
    ll ls=0,eq=0,grt=0,sum=0;
    pair<int,int> rl,rg;
    for(int i=1,j=1;i<=n;++i) {
        if(!cnt[a[i]]) st.insert(a[i]);
        ++cnt[a[i]];
        if(!cnt[a[i]]) st.erase(a[i]);
        while(j<=i&&!st.empty()&&cnt[*st.begin()]>0) {
            if(!cnt[a[j]]) st.insert(a[j]);
            --cnt[a[j]];
            if(!cnt[a[j]]) st.erase(a[j]);
            ++j;
        }
        sum+=limr[i]-liml[i]+1;
        if(j>liml[i]) {
            int rnd=uniform_int_distribution<ll>(1,ls+j-liml[i])(rng);
            if(rnd>ls) rl={liml[i]+rnd-ls-1,i};
        }
        if(sum-ls-eq>grt) {
            int rnd=uniform_int_distribution<ll>(1,sum-ls-eq)(rng);
            if(rnd>grt) rg={limr[i]-(rnd-grt)+1,i};
        }
        pos[i]=j; ls+=max(0,j-liml[i]);
        if(st.empty()) ++eq,iseq[i]=true;
        grt=sum-ls-eq;
    }
    return make_tuple(ls,eq,rl,rg);
}

auto kth(ll k) {
    static mt19937_64 rng(random_device{}());
    fill(liml+1,liml+n+1,1); iota(limr+1,limr+n+1,1);
    for(int l=uniform_int_distribution<int>(1,n)(rng),r=uniform_int_distribution<int>(1,n)(rng);;) {
        if(l>r) swap(l,r);
        auto [ls,eq,rl,rg]=calcRank(l,r,rng);
        if(ls<k&&k<=ls+eq) {
            vector<int> res;
            for(int i=l;i<=r;++i) res.emplace_back(a[i]);
            sort(res.begin(),res.end());
            return res;
        }
        if(k<=ls) {
            tie(l,r)=rl;
            for(int i=1;i<=n;++i) limr[i]=min(limr[i],pos[i]-1);
        } else {
            tie(l,r)=rg; k-=ls;
            for(int i=1;i<=n;++i) liml[i]=max(liml[i],pos[i]);
        }
    }
    abort();
}

int main() {
    ios::sync_with_stdio(false); cin.tie(nullptr);

    cin>>n>>k;
    for(int i=1;i<=n;++i) cin>>a[i];
    auto ans=kth(k);
    for(auto v:ans) cout<<v<<" ";
    cout<<endl;

    return 0;
}

不过这题好像有不依赖随机化的做法(利用 Segment Tree Beats),而且跑得很快。但是我不会

例题

[PA2019] Podatki drogowe

仍然是 \(n\) 进制比大小,考虑类似上面的随机二分。这题在树上,权值是路径权值和,于是考虑先做一次点分治,用可持久化线段树维护每个分治中心到这一层所有点的路径上的权值(以及后缀哈希,比较要用),于是所有路径可以通过合并同一分治中心且不在同一方向上的两个线段树得到,合并是简单的权值相加。先把每个分治中心的所有线段树排个序,之后容易通过双指针在 \(O(n \log^2 n)\) 的时间内求出任意 \(n\) 进制数的排名,小改一下容易加上左右端点的限制求排名。套上随机二分,时间复杂度为 \(O(n \log^3 n)\),有点难写。

点击查看代码
#include <bits/stdc++.h>

using namespace std;
using ll=long long;
using ull=unsigned long long;

constexpr int MAXN=2.5e4+10,inf=0x3f3f3f3f,mod=1e9+7;

struct Hasher {
    static constexpr ull b=1e9+7;
    ull pw[MAXN];
    Hasher(int n=MAXN-1) {
        pw[0]=1;
        for(int i=1;i<=n;++i) pw[i]=pw[i-1]*b;
    }
    ull append(ull lval,int v) {return lval*b+v;}
    ull merge(ull lval,ull rval,int rlen) {
        return lval*pw[rlen]+rval;
    }
}hs;

struct Node {
    int l,r,val; ull hash;
    Node() {l=r=val=0; hash=0ull;}
};

struct PersistentSegTree {
    static_assert(__cplusplus>=201703L);
    vector<Node> t; int tot;
    PersistentSegTree() {tot=0; t.clear(); t.emplace_back();}
    int create() {t.emplace_back(); return t.size()-1;}
    int clone(int k) {t.emplace_back(t[k]); return t.size()-1;}
    void pushup(int k,int l,int r,int mid) {
        int ls=t[k].l,rs=t[k].r;
        t[k].hash=hs.merge(t[ls].hash,t[rs].hash,r-mid);
    }
    void modify(int k,int l,int r,int x,int v) {
        if(l==r) {t[k].val+=v; t[k].hash=t[k].val; return;}
        int mid=(l+r)>>1;
        if(x<=mid) modify(t[k].l=clone(t[k].l),l,mid,x,v);
        else modify(t[k].r=clone(t[k].r),mid+1,r,x,v);
        pushup(k,l,r,mid);
    }
    int cmp(int k1,int k2,int k3,int l,int r) {
        if(t[k1].hash+t[k2].hash==t[k3].hash) return 0;
        if(l==r) {
            if(t[k1].val+t[k2].val>t[k3].val) return 1;
            else if(t[k1].val+t[k2].val==t[k3].val) return 0;
            else return -1;
        }
        int mid=(l+r)>>1,r1=t[k1].r,r2=t[k2].r,r3=t[k3].r;
        if(t[r1].hash+t[r2].hash==t[r3].hash) return cmp(t[k1].l,t[k2].l,t[k3].l,l,mid);
        return cmp(t[k1].r,t[k2].r,t[k3].r,mid+1,r);
    }
    void queryAll(int k,int l,int r,vector<int> &res) {
        if(l==r) {res.emplace_back(t[k].val); return;}
        int mid=(l+r)>>1;
        queryAll(t[k].l,l,mid,res); queryAll(t[k].r,mid+1,r,res);
    }
}t;

vector<pair<int,int>> g[MAXN],rt[MAXN];
vector<int> posj[MAXN],posk[MAXN],posl[MAXN],posr[MAXN],cj[MAXN],ck[MAXN],cl[MAXN],cr[MAXN];
int siz[MAXN],minv,drt,n,maxp;
bool vis[MAXN];
ll k;

void getRoot(int u,int fa,int tot) {
    siz[u]=1; int cur=0;
    for(auto [v,w]:g[u]) {
        if(v==fa||vis[v]) continue;
        getRoot(v,u,tot); siz[u]+=siz[v]; cur=max(cur,siz[v]);
    }
    cur=max(cur,tot-siz[u]);
    if(cur<minv) minv=cur,drt=u;
}

void addP(int u,int fa,int lstr,int lstw,int drt,int dir) {
    int p=t.clone(lstr);
    t.modify(p,1,n,lstw,1); rt[drt].emplace_back(p,dir);
    for(auto [v,w]:g[u]) if(v!=fa&&!vis[v]) addP(v,u,p,w,drt,dir);
}

void divide(int u) {
    vis[u]=true;
    for(auto [v,w]:g[u]) if(!vis[v]) addP(v,u,0,w,u,v);
    for(auto [v,w]:g[u]) {
        if(vis[v]) continue;
        minv=inf; drt=0; getRoot(v,u,siz[v]); getRoot(drt,u,siz[v]); divide(drt);
    }
}

void initD() {
    minv=inf; getRoot(1,0,n); getRoot(drt,0,n);
    divide(drt);

    for(int u=1;u<=n;++u) {
        sort(rt[u].begin(),rt[u].end(),[&](pair<int,int> x,pair<int,int> y) {
            int res=t.cmp(x.first,0,y.first,1,n);
            if(res<0) return true;
            if(res>0) return false;
            return x.second<y.second;
        });
    }
}

int mergeRoot(int p1,int p2) {
    vector<int> v1,v2; t.queryAll(p1,1,n,v1); t.queryAll(p2,1,n,v2);
    int res=t.create();
    for(int i=0;i<int(v1.size());++i) if(v1[i]+v2[i]) t.modify(res,1,n,i+1,v1[i]+v2[i]);
    return res;
}

auto getRank(int p,int pl,int pr) {
    static int cntj[MAXN],cntk[MAXN],cntl[MAXN],cntr[MAXN];
    ll ls=0,eq=0,grt=0;
    for(int u=1;u<=n;++u) {
        posj[u].assign(rt[u].size(),n); posk[u].assign(rt[u].size(),n); posl[u].assign(rt[u].size(),n); posr[u].assign(rt[u].size(),n);
        cj[u].assign(rt[u].size(),n); ck[u].assign(rt[u].size(),n); cl[u].assign(rt[u].size(),n); cr[u].assign(rt[u].size(),n);
        for(int i=int(rt[u].size())-1,j=0,k=0,l=0,r=0;i>=0;--i) {
            while(j>i) --j,--cntj[rt[u][j].second];
            while(k>i) --k,--cntk[rt[u][k].second];
            while(l>i) --l,--cntl[rt[u][l].second];
            while(r>i) --r,--cntr[rt[u][r].second];
            while(r<i&&t.cmp(rt[u][i].first,rt[u][r].first,pr,1,n)<0) ++cntr[rt[u][r].second],++r;
            while(l<r&&t.cmp(rt[u][i].first,rt[u][l].first,pl,1,n)<=0) ++cntl[rt[u][l].second],++l;
            while(j<r&&t.cmp(rt[u][i].first,rt[u][j].first,p,1,n)<0) ++cntj[rt[u][j].second],++j;
            while(k<r&&t.cmp(rt[u][i].first,rt[u][k].first,p,1,n)<=0) ++cntk[rt[u][k].second],++k;
            posj[u][i]=j; posk[u][i]=k; posl[u][i]=l; posr[u][i]=r;
            cj[u][i]=cntj[rt[u][i].second]; ck[u][i]=cntk[rt[u][i].second];
            cl[u][i]=cntl[rt[u][i].second]; cr[u][i]=cntr[rt[u][i].second];
            ls+=max(0,(j-cntj[rt[u][i].second])-(l-cntl[rt[u][i].second]));
            eq+=max(0,(k-cntk[rt[u][i].second])-(j-cntj[rt[u][i].second]));
            grt+=max(0,(r-cntr[rt[u][i].second])-(k-cntk[rt[u][i].second]));
        }
        for(auto [r,d]:rt[u]) {
            int val=t.cmp(r,0,p,1,n);
            if(t.cmp(r,0,pl,1,n)<=0||t.cmp(r,0,pr,1,n)>=0) continue;
            if(val<0) ++ls;
            else if(val==0) ++eq;
            else ++grt;
        }
        for(auto [r,d]:rt[u]) cntj[d]=cntk[d]=cntl[d]=cntr[d]=0;
    }
    return make_tuple(ls,eq,grt);
}

int kth(ll k) {
    static mt19937 rng(random_device{}());
    int tot=0,p=-1;
    for(int u=0;u<=n;++u) tot+=rt[u].size();
    tot=uniform_int_distribution<int>(0,tot-1)(rng);
    for(int u=0;u<=n;++u) {
        if(tot<int(rt[u].size())) {p=rt[u][tot].first; break;}
        tot-=rt[u].size();
    }
    int pl=0,pr=maxp;
    while(true) {
        auto [ls,eq,grt]=getRank(p,pl,pr);

        if(ls<k&&k<=ls+eq) {
            vector<int> res; t.queryAll(p,1,n,res);
            int pw=n,ans=0;
            for(auto v:res) ans=(ans+1ll*pw*v%mod)%mod,pw=1ll*pw*n%mod;
            return ans;
        }
        if(k<=ls) {
            int nxt=uniform_int_distribution<int>(1,ls)(rng),nxtp=-1;
            for(int u=1;u<=n;++u) {
                for(int i=int(rt[u].size())-1;i>=0;--i) {
                    int cur=max(0,(posj[u][i]-cj[u][i])-(posl[u][i]-cl[u][i]));
                    if(nxt<=cur) {
                        vector<int> val;
                        for(int j=0;j<i;++j) {
                            if(t.cmp(rt[u][i].first,rt[u][j].first,pl,1,n)<=0) continue;
                            if(t.cmp(rt[u][i].first,rt[u][j].first,pr,1,n)>=0) continue;
                            if(t.cmp(rt[u][i].first,rt[u][j].first,p,1,n)<0) val.emplace_back(rt[u][j].first);
                        }
                        nxtp=mergeRoot(rt[u][i].first,val[nxt-1]);
                        break;
                    }
                    nxt-=cur;
                }
                if(nxtp!=-1) break;
                for(auto [r,d]:rt[u]) {
                    int val=t.cmp(r,0,p,1,n);
                    if(t.cmp(r,0,pl,1,n)<=0||t.cmp(r,0,pr,1,n)>=0) continue;
                    if(val<0) {
                        if(!--nxt) {nxtp=r; break;}
                    }
                }
                if(nxtp!=-1) break;
            }
            pr=p; p=nxtp;
        } else {
            int nxt=uniform_int_distribution<int>(1,grt)(rng),nxtp=-1;
            for(int u=1;u<=n;++u) {
                for(int i=int(rt[u].size())-1;i>=0;--i) {
                    int cur=max(0,(posr[u][i]-cr[u][i])-(posk[u][i]-ck[u][i]));
                    if(nxt<=cur) {
                        vector<int> val;
                        for(int j=0;j<i;++j) {
                            if(rt[u][j].second==rt[u][i].second) continue;
                            if(t.cmp(rt[u][i].first,rt[u][j].first,pl,1,n)<=0) continue;
                            if(t.cmp(rt[u][i].first,rt[u][j].first,pr,1,n)>=0) continue;
                            if(t.cmp(rt[u][i].first,rt[u][j].first,p,1,n)>0) val.emplace_back(rt[u][j].first);
                        }
                        nxtp=mergeRoot(rt[u][i].first,val[nxt-1]);
                        break;
                    }
                    nxt-=cur;
                }
                if(nxtp!=-1) break;
                for(auto [r,d]:rt[u]) {
                    int val=t.cmp(r,0,p,1,n);
                    if(t.cmp(r,0,pl,1,n)<=0||t.cmp(r,0,pr,1,n)>=0) continue;
                    if(val>0) {
                        if(!--nxt) {nxtp=r; break;}
                    }
                }
                if(nxtp!=-1) break;
            }
            pl=p; p=nxtp; k-=ls+eq;
        }
    }
    abort();
    return -1;
}

int main() {
    ios::sync_with_stdio(false); cin.tie(nullptr);

    cin>>n>>k;
    for(int i=1;i<n;++i) {
        int u,v,w; cin>>u>>v>>w; g[u].emplace_back(v,w); g[v].emplace_back(u,w);
    }
    initD(); maxp=t.create(); t.modify(maxp,1,n,n,n+1);

    cout<<kth(k)<<endl;

    return 0;
}

标签:二分,return,val,int,--,MAXN,随机,ls
From: https://www.cnblogs.com/xzmxzm/p/17998679/random_binary_search

相关文章

  • 二分查找(折半查找)
    二分查找的前提:数组中的数据必须是有序的核心思想:每次排除一半的数据,查询数据的性能明显提高很多实现步骤1.定义两个变量,一个代表左边位置,一个代表右边位置2.定义一个循环控制折半3.每次折半,都算出中间位置处的索引4.判断当前要找的元素值,与中间位置处的元素值的大小情况往......
  • AtCoder Beginner Contest 338 c题二分思路
    观察题目可知,会有一个最大的x(两个菜的最大制作数),大于这个x就不能做任何一盘菜,小于这个x那么一定可以做出来,这样分析就是显而易见的递归。实现递归的check函数,那么我们就可以把两个菜的总制作数传进去。那么什么时候true什么时候false呢,就是判断每种材料的制作量有没有超过原材料......
  • Leetcode刷题第五天-二分法-回溯
    215:第k个最大元素链接:215.数组中的第K个最大元素-力扣(LeetCode)em~~怎么说呢,快速选择,随机定一个目标值,开始找,左边比目标小,右边比目标大,左右同时不满足时,交换左右位置,直到左指针比右指针大,交换目标和右指针位置的值,此时右指针位置即时目标值的在排序好数组中的位置,如果k在右......
  • 二分算法
    二分算法个人感想洛谷二分题单基本完成,发现二分确实是比较模板的方式解答题目,难点往往是寻找出答案的单调性和如何高效验证答案的正确性。二分个人感觉就是枚举的优化,在时间复杂度上的极大优化,有一种暴力的美.目前发现的不足对题目的理解太浅,有时很难看懂题目的意思,理解有问......
  • P5400 [CTS2019] 随机立方体 题解
    题目链接点击打开链接题目解法参考cmd的博客好复杂的推式子题,而且三维的对我来说好难想象/ll首先二项式反演,把恰好\(k\)个变成求至少\(i\)个的方案数令极大格子有至少\(i\)个的方案数为\(f_i\),\(R=\min\{n,m,k\}\)特判掉\(k>R\)答案为\(0\)根据二项式反演,答案......
  • 【学习笔记】二分图
    1.定义一个二分图满足有一种划分方案使得它节点的被分为两部分,且所有边的端点所在的部分不相同。即每条边都连接两个部分。变量说明:没有特殊说明时,\(n\)表示a部分点数,\(m\)表示b部分点数,\(e\)表示边数。2.判定显然我们给二分图染色,确定一个点所有点都确定。如果在染的......
  • java中二分查找前提必须是升序吗?
    二分查找不必须是升序,降序排列的数组也可以执行二分查找。二分查找算法是一种高效的搜索方法,它要求数据集是有序的,无论是升序还是降序都可以。在升序排列的情况下,算法会将目标值与中间值比较,如果目标值较小,则在左半部分继续查找;如果目标值较大,则在右半部分继续查找。在降序排列的......
  • net8 随机数类Random GetItems() 、Shuffle()方法
    1、在8中对随机数类Random提供了GetItems()方法,可以根据指定的数量在提供的一个集合中随机抽取数据项生成一个新的集合:ReadOnlySpan<string>colors=new[]{"Red","Green","Blue","Black"};string[]t1=Random.Shared.GetItems(colors,10);Console.WriteLine(......
  • 洛谷题单指南-排序-P1059 [NOIP2006 普及组] 明明的随机数
    原题链接:https://www.luogu.com.cn/problem/P1059题意解读:此题主要做两件事:排序+去重,用计数排序即可解决,直接给出代码。100分代码:#include<bits/stdc++.h>usingnamespacestd;constintN=1005;inta[N];intn;intmain(){cin>>n;intx;intcnt......
  • Java二分查找
    二分查找\789.数的范围给定一个按照升序排列的长度为n的整数数组,以及q个查询。对于每个查询,返回一个元素k的起始位置和终止位置(位置从0开始计数)。如果数组中不存在该元素,则返回-1-1。输入格式第一行包含整数n和q,表示数组长度和询问个数。第二行包含n个整数(均在1∼1......