Tokitsukaze and Min-Max XOR
题目描述
Tokitsukaze 有一个长度为 $n$ 的序列 $a_1,a_2,\ldots,a_n$ 和一个整数 $k$。
她想知道有多少种序列 $b_1,b_2,\ldots,b_m$,满足:
- $1 \leq b_i \leq n$
- $b_{i−1}<b_i$ $(2 \leq i \leq m)$
- $\min(a_{b_1} \, , a_{b_2} \, , \ldots, a_{b_m}) \oplus \max(a_{b_1} \, , a_{b_2} \, , \ldots, a_{b_m}) \leq k$
其中 $\oplus$ 为按位异或,具体参见 百度百科:异或
答案可能很大,请输出 $\bmod 10^9+7$ 后的结果。
输入描述:
第一行包含一个整数 $T$ ($1 \leq T \leq 2 \cdot 10^5$),表示 $T$ 组测试数据。
对于每组测试数据:
第一行包含两个整数 $n$, $k$ ($1 \leq n \leq 2 \cdot 10^5$; $0 \leq k \leq 10^9$)。
第二行包含 $n$ 个整数 $a_1,a_2,\ldots,a_n$ ($0 \leq a_i \leq 10^9$)。
保证 $\sum{n}$ 不超过 $2 \cdot 10^5$。
输出描述:
对于每组测试数据,输出一个整数,表示答案 $\bmod 10^9+7$ 后的结果。
示例1
输入
3
3 2
1 3 2
5 3
1 3 5 2 4
5 0
0 0 0 0 0
输出
6
10
31
说明
第一组测试数据,$k$ 为 $2$:
- 选择的序列 $b$ 为 $[1]$,$\min(a_1) \oplus \max(a_1)=1 \oplus 1=0 \leq 2$;
- 选择的序列 $b$ 为 $[2]$,$\min(a_2) \oplus \max(a_2)=3 \oplus 3=0 \leq 2$;
- 选择的序列 $b$ 为 $[3]$,$\min(a_3) \oplus \max(a_3)=2 \oplus 2=0 \leq 2$;
- 选择的序列 $b$ 为 $[1,2]$,$\min(a_1,a_2) \oplus \max(a_1,a_2)=1 \oplus 3=2 \leq 2$;
- 选择的序列 $b$ 为 $[2,3]$,$\min(a_2,a_3) \oplus \max(a_2,a_3)=2 \oplus 3=1 \leq 2$;
- 选择的序列 $b$ 为 $[1,2,3]$,$\min(a_1,a_2,a_3) \oplus \max(a_1,a_2,a_3)=1 \oplus 3=2 \leq 2$;
所以第一组测试数据的答案为 $6$ 。
解题思路
看了点提示就做出来了,解法和昨天想到的思路差不多一样,不过最后没多少时间写了。
容易知道 $b_1, \ldots, b_m$ 实际上是 $a$ 的一个子序列,并且由于我们只关注子序列中的最大值和最小值,因此可以先对 $a$ 从小到大排序,再选择子序列。接着对子序列中的最大值进行分类,可以分成 $n$ 类。即从左到右依次枚举 $a_i$ 作为子序列中的最大值,那么最小值就会在 $a_j, \, j \in [0, i]$ 中选。当满足 $a_i \oplus a_j \leq k$,那么以 $a_i$ 为最大值,$a_j$ 为最小值的子序列的数量就是 $2^{\max\{ 0,i-j-1 \}}$,特别的当 $i=j$ 时答案为 $1$。
暴力的做法就是逐个枚举 $a_j$ 判断是否满足条件,时间复杂度是 $O(n^2)$ 的。由于涉及到异或运算所以尝试能不能用 trie 来维护 $a_j$ 的信息。如果 $a_j$ 满足条件,那么对答案的贡献是 $2^{i-j-1}$,也就是 $\frac{1}{2^{j+1}} \cdot 2^i$,因此在把 $a_j$ 按位插入 trie 中时,同时在对应节点加上 $\frac{1}{2^{j+1}}$。
枚举到 $a_i$ 时,此时已经往 trie 中插入了 $a_0 \sim a_{i-1}$,枚举 $a_i$ 的每一位,用 $x_i$ 和 $m_i$ 分别表示 $a_i$ 和 $m$ 在二进制下第 $i$ 位上的值。如果 $x_i \oplus 0 < m_i$,说明此时 $0$ 的分支剩余的 $a_j$ 都满足条件,把该分支节点上的关于 $\frac{1}{2^{j+1}}$ 的和累加到答案 $s$。同理如果 $x_i \oplus 1 < m_i$,说明此时 $1$ 的分支剩余的 $a_j$ 都满足条件。然后走到下一个分支节点,如果 $x_i \oplus 0 = m_i$ 则走到 $0$ 的分支节点,否则走到 $1$ 的分支节点。最后以 $a_i$ 为最大值的子序列的数量就是 $1 + s \cdot 2^{i}$。
AC 代码如下,时间复杂度为$O\left(n (\log{A} + \log{n})\right)$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2e5 + 10, mod = 1e9 + 7;
int n, m;
int a[N];
int tr[N * 30][2], idx, s[N * 30];
int qmi(int a, int k) {
int ret = 1;
while (k) {
if (k & 1) ret = 1ll * ret * a % mod;
a = 1ll * a * a % mod;
k >>= 1;
}
return ret;
}
void add(int x, int c) {
int p = 0;
for (int i = 29; i >= 0; i--) {
int t = x >> i & 1;
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
s[p] = (s[p] + c) % mod;
}
}
int query(int x, int c) {
int p = 0, ret = 0;
for (int i = 29; i >= 0; i--) {
int t = x >> i & 1;
if (tr[p][0] && t < (m >> i & 1)) ret = (ret + 1ll * s[tr[p][0]] * c) % mod;
if (tr[p][1] && (t ^ 1) < (m >> i & 1)) ret = (ret + 1ll * s[tr[p][1]] * c) % mod;
if (t == (m >> i & 1)) {
if (!tr[p][0]) return ret;
else p = tr[p][0];
}
else {
if (!tr[p][1]) return ret;
else p = tr[p][1];
}
}
ret = (ret + 1ll * s[p] * c) % mod;
return ret;
}
void solve(){
scanf("%d %d", &n, &m);
for (int i = 0; i < n; i++) {
scanf("%d", a + i);
}
sort(a, a + n);
idx = 0;
for (int i = 0; i <= n * 30; i++) {
tr[i][0] = tr[i][1] = s[i] = 0;
}
int ret = 0;
for (int i = 0; i < n; i++) {
ret = (ret + 1 + query(a[i], qmi(2, i))) % mod;
add(a[i], qmi(qmi(2, i + 1), mod - 2));
}
printf("%d\n", ret);
}
int main() {
int t;
scanf("%d", &t);
while (t--) {
solve();
}
return 0;
}
参考资料
【题解】2024牛客寒假算法基础集训营2:https://ac.nowcoder.com/discuss/1251379/
标签:XOR,leq,int,Max,tr,ret,Tokitsukaze,序列,oplus From: https://www.cnblogs.com/onlyblues/p/18010174