整体二分
引入
对于一堆询问,如果每个单独的询问都可以二分解决的话,时间复杂度为 \(O(NM\log N)\),但实际上每次二分都会有一些残留信息被我们扔掉,如果我们将所有询问一起二分,就可以最大时间的减小复杂度。
讲解
经典例题:区间第k大
给定一个序列 a 和一个整数 S,有 2 种操作:
1. 将 a 序列的第 k 个数变为 w
2. 查询区间[l, r]中有多少数小于等于 S
这个题可以用一个树状数组来维护,对于 \(a\) 序列,我们将所有小于等于 \(S\) 的数的位置在树状数组中 \(+1\),表示这个位置有一个小于等于 \(S\) 的数。
对于 1 操作,我们可以看作删除一个数再添加一个数,先看 \(a_k\) 是否大于 S,如果不大于则让这个位置 \(-1\),再看 \(w\) 是否大于 \(S\),如果小于等于就在这个位置 \(+1\)。
对于 2 操作,可以直接执行 ask(r) - ask(l - 1)
,即为这个范围内有多少数小于等于 \(S\)。
给定 a 序列,求第 k 小的数是几
这个问题也可以用二分来解决。每次二分一个 \(mid\),查询值域 \([l, r]\) 内有多少小于 \(mid\) 的数,记为 \(cnt\)。如果 \(cnt >= k\),可以在值域 \([l, mid]\) 中接着二分。如果 \(cnt < k\),可以令 k -= cnt,在值域 \([mid + 1, r]\) 中继续二分。
给定 a 序列,有 m 个询问,每次询问区间[l, r]中的第 k 小数
我们可以按照上面第二题的思路,每个询问进行一次二分,时间复杂度为 \(O(NM\log N)\),不能承受。
考虑每次二分值域,都会有一大部分信息被扔掉。所以我们应该对所有的询问全部进行二分。
算法流程如下:
- 值域达到边界,直接将当前的这些达到边界的询问的答案记录,返回即可。
- 二分出 mid,按照上面第一题的方法,用树状数组查找 \([l, r]\) 中不大于 \(mid\) 的数的个数,记为 \(cnt\)。
- 再应用第二题的方法,将 \(k <= cnt\) 的询问放到 lq 序列中, 将 \(k > cnt\) 的询问的 k -= cnt,再放到 rq 序列中。
- 递归二分求解 lq 和 rq 序列。
同时,为了简化代码,将序列转化为 \(n\) 个插入操作,具体实现可以看上面的第一题。
代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10, INF = 0x3f3f3f3f;
struct Q
{
int op, x, y, k;
}q[N], lq[N], rq[N];
int ans[N];
int tt;
int n, m;
vector<int> nums;
struct tree_array
{
int c[N];
#define lowbit(x) x & -x
inline void add(int x, int val)
{
for(; x <= n; x += lowbit(x)) c[x] += val;
}
inline int query(int x)
{
int res = 0;
for(; x; x -= lowbit(x)) res += c[x];
return res;
}
}bit;
void solve(int lval, int rval, int st, int ed)
{
if(st > ed) return;
if(lval == rval)
{
for(int i = st; i <= ed; i ++ )
if(q[i].op > 0) ans[q[i].op] = lval;
return;
}
int mid = lval + rval >> 1;
int lt = 0, rt = 0;
for(int i = st; i <= ed; i ++ )
{
if(q[i].op == 0)
{
if(q[i].y <= mid) bit.add(q[i].x, 1), lq[++ lt] = q[i];
else rq[++ rt] = q[i];
}
else
{
int l = q[i].x, r = q[i].y;
int cnt = bit.query(r) - bit.query(l - 1);
if(cnt >= q[i].k) lq[++ lt] = q[i];
else q[i].k -= cnt, rq[++ rt] = q[i];
}
}
for(int i = ed; i >= st; i -- )
if(q[i].op == 0 && q[i].y <= mid)
bit.add(q[i].x, -1);
for(int i = 1; i <= lt; i ++ ) q[st + i - 1] = lq[i];
for(int i = 1; i <= rt; i ++ ) q[st + lt + i - 1] = rq[i];
solve(lval, mid, st, st + lt - 1);
solve(mid + 1, rval, st + lt, ed);
}
int main()
{
n = read(), m = read();
memset(bit.c, 0, sizeof bit.c);
for(int i = 1; i <= n; i ++ )
{
int val;
scanf("%d", &val);
q[++ tt] = {0, i, val, 0};
}
for(int i = 1; i <= m; i ++ )
{
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
q[++ tt] = {i, l, r, k};
}
solve(-INF, INF, 1, tt);
for(int i = 1; i <= m; i ++ )
printf("%d\n", ans[i]);
return 0;
}
扩展:带修区间第 k 大
当然可以用树套树在线做,但是也可以用整体二分,而且运行起来更加优秀。
将每个修改操作看作 2 种操作,和上面的第一题一样,小于 \(mid\) 的就 \(-1\),添加的数小于 \(mid\) 就 \(+1\)。
时间复杂度 \(O(N\log N)\)。
完整代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10, INF = 0x3f3f3f3f;
int n, m, tt, id;
struct tree_array
{
int c[N];
#define lowbit(x) x & -x
inline void add(int x, int val)
{
for(; x <= n; x += lowbit(x)) c[x] += val;
}
inline int query(int x)
{
int res = 0;
for(; x; x -= lowbit(x)) res += c[x];
return res;
}
} bit;
struct Q
{
int op, x, y, k;
}q[N], lq[N], rq[N];
int ans[N], a[N];
void solve(int lval, int rval, int st, int ed)
{
if(st > ed) return;
if(lval == rval)
{
for(int i = st; i <= ed; i ++ )
if(q[i].op > 0)
ans[q[i].op] = lval;
return;
}
int mid = lval + rval >> 1;
int lt = 0, rt = 0;
for(int i = st; i <= ed; i ++ )
{
if(q[i].op <= 0)
{
if(q[i].y <= mid) bit.add(q[i].x, q[i].k), lq[++ lt] = q[i];
else rq[++ rt] = q[i];
}
else
{
int l = q[i].x, r = q[i].y;
int cnt = bit.query(r) - bit.query(l - 1);
if(cnt >= q[i].k) lq[++ lt] = q[i];
else q[i].k -= cnt, rq[++ rt] = q[i];
}
}
for(int i = st; i <= ed; i ++ )
if(q[i].op <= 0 && q[i].y <= mid)
bit.add(q[i].x, -q[i].k);
for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
for(int i = 1; i <= rt; i ++ ) q[i + lt + st - 1] = rq[i];
solve(lval, mid, st, st + lt - 1);
solve(mid + 1, rval, st + lt, ed);
}
int main()
{
n = read(), m = read();
for(int i = 1; i <= n; i ++ )
{
a[i] = read();
q[++ tt] = {0, i, a[i], 1};
}
for(int i = 1; i <= m; i ++ )
{
char op[5];
scanf("%s", op);
if(op[0] == 'Q')
{
int l = read(), r = read(), k = read();
q[++ tt] = {++ id, l, r, k};
}
else
{
int x = read(), y = read();
q[++ tt] = {-1, x, a[x], -1};
q[++ tt] = {0, x, y, 1};
a[x] = y;
}
}
solve(-INF, INF, 1, tt);
for(int i = 1; i <= id; i ++ )
printf("%d\n", ans[i]);
return 0;
}
练习
P1527 [国家集训队]矩阵乘法
本题维护一个二维树状数组,然后和区间第 k 大没有一点不同。
struct tree_array
{
int c[N][N];
#define lowbit(x) x & -x
inline void add(int x, int y, int val)
{
for(int i = x; i <= n; i += lowbit(i))
for(int j = y; j <= n; j += lowbit(j))
c[i][j] += val;
}
inline int query(int x, int y)
{
if(!x || !y) return 0;
int res = 0;
for(int i = x; i; i -= lowbit(i))
for(int j = y; j; j -= lowbit(j))
res += c[i][j];
return res;
}
inline int ask(int x1, int y1, int x2, int y2)
{
return query(x2, y2) - query(x1 - 1, y2) - query(x2, y1 - 1) + query(x1 - 1, y1 - 1);
}
} bit;
struct Q
{
int op, x1, y1, x2, y2, k;
}q[M], lq[M], rq[M];
int ans[M];
int tt;
void solve(int lval, int rval, int st, int ed)
{
if(st > ed) return;
if(lval == rval)
{
for(int i = st; i <= ed; i ++ )
if(q[i].op != 0)
ans[q[i].op] = lval;
return;
}
int mid = lval + rval >> 1;
int lt = 0, rt = 0;
for(int i = st; i <= ed; i ++ )
{
if(q[i].op == 0)
{
if(q[i].k <= mid) bit.add(q[i].x1, q[i].y1, 1), lq[++ lt] = q[i];
else rq[++ rt] = q[i];
}
else
{
int cnt = bit.ask(q[i].x1, q[i].y1, q[i].x2, q[i].y2);
if(cnt >= q[i].k) lq[++ lt] = q[i];
else q[i].k -= cnt, rq[++ rt] = q[i];
}
}
for(int i = st; i <= ed; i ++ )
if(q[i].op == 0 && q[i].k <= mid)
bit.add(q[i].x1, q[i].y1, -1);
for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
for(int i = 1; i <= rt; i ++ ) q[st + lt + i - 1] = rq[i];
solve(lval, mid, st, st + lt - 1);
solve(mid + 1, rval, st + lt, ed);
}
P3527 [POI2011]MET-Meteors
很明显每个询问都能二分求解,时间长的肯定有更大概率收集全。
注意 \(l < r\) 的情况,这种情况可以将数组开成两倍,统计时加上长度即可。
统计每个国家收集多少的时候可以用图论的方式统计。
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e6 + 10;
int n, m, k;
int c[N], p[N], ans[N];
struct tree_array
{
int c[N];
#define lowbit(x) x & -x
inline void add(int x, int val)
{
for(; x <= m * 2; x += lowbit(x)) c[x] += val;
}
inline int query(int x)
{
int res = 0;
for(; x; x -= lowbit(x)) res += c[x];
return res;
}
} bit;
struct C
{
int x, y, val;
}ch[N];
struct Q
{
int c, k, h;
}q[N], lq[N], rq[N];
int e[N], ne[N], idx;
void add(int a, int b)
{
e[++ idx] = b, ne[idx] = q[a].h, q[a].h = idx;
}
void solve(int lval, int rval, int st, int ed)
{
if(st > ed) return;
if(lval == rval)
{
for(int i = st; i <= ed; i ++ )
ans[q[i].c] = lval;
return;
}
int mid = lval + rval >> 1, lt = 0, rt = 0;
for(int i = lval; i <= mid; i ++ )
{
bit.add(ch[i].x, ch[i].val), bit.add(ch[i].y + 1, -ch[i].val);
}
for(int i = st; i <= ed; i ++ )
{
int cnt = 0;
for(int k = q[i].h; k && cnt <= q[i].k; k = ne[k])
{
int j = e[k];
cnt += bit.query(j) + bit.query(j + m);
}
if(cnt >= q[i].k) lq[++ lt] = q[i];
else q[i].k -= cnt, rq[++ rt] = q[i];
}
for(int i = lval; i <= mid; i ++ )
bit.add(ch[i].x, -ch[i].val), bit.add(ch[i].y + 1, ch[i].val);
for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
for(int i = 1; i <= rt; i ++ ) q[i + st + lt - 1] = rq[i];
solve(lval, mid, st, st + lt - 1);
solve(mid + 1, rval, st + lt, ed);
}
signed main()
{
n = read(), m = read();
for(int i = 1; i <= m; i ++ )
{
c[i] = read();
add(c[i], i);
}
for(int i = 1; i <= n; i ++ )
{
q[i].k = read();
q[i].c = i;
}
k = read();
for(int i = 1; i <= k; i ++ )
{
int l = read(), r = read(), val = read();
if(r < l) r += m;
ch[i] = {l, r, val};
}
solve(1, k + 1, 1, n);
for(int i = 1; i <= n; i ++ )
if(ans[i] == k + 1) puts("NIE");
else printf("%d\n", ans[i]);
return 0;
}
P4602 [CTSC2018] 混合果汁
本题二分也很好想出来,唯一的难点在于怎么快速回答每个询问。
维护一个权值线段树,下标存储价格,内部存储物品的个数和总价格。
一开始先按美味值排序,维护一直到 mid 的前缀中的物品即可。
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e6 + 10;
int n, m, cur;
int ans[N];
struct segment
{
int v, sum;
}t[N << 2];
inline void pushup(int p)
{
t[p].v = t[p << 1].v + t[p << 1 | 1].v;
t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum;
}
void change(int p, int l, int r, int pos, int v)
{
if(l == r)
{
t[p].v += v;
t[p].sum = l * t[p].v;
return;
}
int mid = l + r >> 1;
if(pos <= mid) change(p << 1, l, mid, pos, v);
else change(p << 1 | 1, mid + 1, r, pos, v);
pushup(p);
}
int query(int p, int l, int r, int v)
{
if(!v) return 0;
if(l == r) return l * v;
int mid = l + r >> 1;
if(t[p << 1].v >= v) return query(p << 1, l, mid, v);
else return t[p << 1].sum + query(p << 1 | 1, mid + 1, r, v - t[p << 1].v);
}
int query1(int p, int l, int r, int pos)
{
if(l == r) return t[p].sum;
int mid = l + r >> 1;
if(pos <= mid) return query1(p << 1, l, mid, pos);
else return query1(p << 1 | 1, mid + 1, r, pos);
}
struct data
{
int d, p, l;
bool operator<(const data &D) const
{
return d > D.d;
}
} a[N];
struct Q
{
int id, g, l;
} q[N], lq[N], rq[N];
void solve(int lval, int rval, int st, int ed)
{
if (st > ed || lval > rval)
return;
if (lval == rval)
{
for (int i = st; i <= ed; i ++)
ans[q[i].id] = a[lval].d;
return;
}
int mid = lval + rval >> 1;
while(cur < mid)
cur ++, change(1, 1, N - 1, a[cur].p, a[cur].l);
while(cur > mid)
change(1, 1, N - 1, a[cur].p, -a[cur].l), cur --;
int lt = 0, rt = 0;
for(int i = st; i <= ed; i ++ )
{
if(q[i].l > t[1].v) rq[++ rt] = q[i];
else if(query(1, 1, N - 1, q[i].l) <= q[i].g) lq[++ lt] = q[i];
else rq[++ rt] = q[i];
}
for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
for(int i = 1; i <= rt; i ++ ) q[i + st + lt - 1] = rq[i];
solve(lval, mid, st, st + lt - 1);
solve(mid + 1, rval, st + lt, ed);
}
signed main()
{
n = read(), m = read();
for (int i = 1; i <= n; i++)
{
a[i].d = read(), a[i].p = read(), a[i].l = read();
}
a[++ n] = {-1, 0, 0x3f3f3f3f};
sort(a + 1, a + n + 1);
for (int i = 1; i <= m; i++)
{
int D = read(), L = read();
q[i] = {i, D, L};
}
solve(1, n, 1, m);
for (int i = 1; i <= m; i++)
printf("%d\n", ans[i]);
return 0;
}
标签:二分,cnt,int,mid,lval,笔记,st,学习
From: https://www.cnblogs.com/crimsonawa/p/17541929.html