A. 云影密码

题解 按照题目的要求模拟即可.
C++ code by me
#include <iostream>
#include <string>
using namespace std;
using LL = long long;
// Problem A: each test string consists of digit groups separated by '0'.
// The digits of one group are summed; a sum of k is decoded as the k-th
// lowercase letter ('a' + k - 1).
int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	int T;
	cin >> T;
	while(T--){
		string s;
		cin >> s;
		string ans;
		for(size_t i = 0; i < s.size(); i++){
			// Extend j to the last digit of the current group.
			size_t j = i;
			int sum = s[i] - '0';
			while(j + 1 < s.size() && s[j + 1] != '0')
				sum += s[++j] - '0';
			ans += char('a' + sum - 1);
			// Jump onto the '0' separator; the loop's i++ then steps past it.
			i = j + 1;
		}
		cout << ans << '\n';
	}
	return 0;
}
Java code by ChatGPT
import java.util.*;
/**
 * Problem A: every test line is a digit string whose groups are separated by
 * '0'. A group whose digits sum to k is decoded as the k-th lowercase letter.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = Integer.parseInt(sc.nextLine().trim());
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < t; i++) {
            String[] nums = sc.nextLine().trim().split("0");
            StringBuilder sb = new StringBuilder();
            for (String num : nums) {
                int sum = 0;
                for (int j = 0; j < num.length(); j++) {
                    sum += num.charAt(j) - '0';
                }
                char c = (char) (sum + 'a' - 1);
                sb.append(c);
            }
            out.append(sb).append('\n');
        }
        System.out.print(out);
    }
}
Python code by me
import sys
input = sys.stdin.readline


def _decode(line):
    """Decode one encoded line.

    The line is cut at every '0'; a piece whose digits add up to k becomes
    the k-th lowercase letter.
    """
    letters = []
    for group in line.strip().split("0"):
        letters.append(chr(ord("a") - 1 + sum(map(int, group))))
    return "".join(letters)


for _ in range(int(input())):
    print(_decode(input()))




根据实现方式的不同,算法的时间复杂度为$O(\sum n\cdot n!)$或者$O(\sum n!)$.

Bonus: 当$n \le 20$时如何解决?

C++ code by me
#include <algorithm>
#include <array>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;
using LL = long long;
// Problem B: try every order of the n problems (a, b, c). Problems are
// finished at cumulative time t (t += b); one finished by time m scores
// max(a - t * (a / 250) - 50 * c, 3a / 10). Print the best total.
int main(){
	int T;
	cin >> T;
	while(T--){
		int n, m;
		cin >> n >> m;
		vector<array<int, 3> > p(n);
		for(int i = 0; i < n; i++){
			int a, b, c;
			cin >> a >> b >> c;
			p[i] = {a, b, c};
		}
		vector<int> id(n);
		iota(id.begin(), id.end(), 0);
		int ans = 0;
		do{
			int sum = 0, t = 0;
			for(int i = 0; i < n; i++){
				const auto &[a, b, c] = p[id[i]];
				t += b;
				if (t <= m){
					int delta = a / 250;
					sum += max(a - t * delta - 50 * c, 3 * a / 10);
				}
			}
			ans = max(ans, sum);
		}while(next_permutation(id.begin(), id.end()));
		cout << ans << '\n';
	}
	return 0;
}
Java code by ChatGPT
import java.util.*;
/**
 * Problem B: enumerate every order of solving the n problems and keep the
 * best total score (same scoring formula as the C++ solution).
 */
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int T = in.nextInt();
        while (T-- > 0) {
            int n = in.nextInt(), m = in.nextInt();
            ArrayList<int[]> p = new ArrayList<int[]>();
            for (int i = 0; i < n; i++) {
                int a = in.nextInt(), b = in.nextInt(), c = in.nextInt();
                p.add(new int[]{a, b, c});
            }
            int[] id = new int[n];
            for (int i = 0; i < n; i++) id[i] = i;
            int ans = 0;
            do {
                int sum = 0, t = 0;
                for (int i = 0; i < n; i++) {
                    int[] item = p.get(id[i]);
                    int a = item[0], b = item[1], c = item[2];
                    t += b;
                    if (t <= m) {
                        int delta = a / 250;
                        sum += Math.max(a - t * delta - 50 * c, 3 * a / 10);
                    }
                }
                ans = Math.max(ans, sum);
            } while (nextPermutation(id));
            System.out.println(ans);
        }
    }

    /**
     * Rearrange nums into the next lexicographically greater permutation.
     * Returns false (leaving nums unchanged) when nums is already the last one.
     */
    public static boolean nextPermutation(int[] nums) {
        int n = nums.length, i = n - 2, j = n - 1;
        // Rightmost position with nums[i] < nums[i + 1].
        while (i >= 0 && nums[i] >= nums[i + 1]) i--;
        if (i < 0) return false;
        // Rightmost element greater than nums[i]; swap, then reverse the suffix.
        while (j >= 0 && nums[j] <= nums[i]) j--;
        swap(nums, i, j);
        reverse(nums, i + 1, n - 1);
        return true;
    }

    /** Swap nums[i] and nums[j]. */
    public static void swap(int[] nums, int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    /** Reverse nums[i..j] (inclusive) in place. */
    public static void reverse(int[] nums, int i, int j) {
        while (i < j) {
            swap(nums, i, j);
            i++;
            j--;
        }
    }
}
Python code by me
from itertools import permutations
import sys
input = sys.stdin.readline
for _ in range(int(input())):
    n, m = map(int, input().split())
    p = []
    for i in range(n):
        p.append(tuple(map(int, input().split())))
    ans = 0
    # Try every solving order; problems finished after time m score nothing.
    for order in permutations(range(n)):
        total, t = 0, 0
        for idx in order:
            a, b, c = p[idx]
            t += b
            if t <= m:
                # a // 250 first, matching the C++/Java `t * (a / 250)`.
                total += max(a - t * (a // 250) - 50 * c, 3 * a // 10)
        ans = max(ans, total)
    print(ans)



对$a_{1\sim n}$分解质因子,考虑计算每个质因子对答案的贡献.

不妨设质因子$p$出现的位置有$b_1,b_2,b_3,\cdots ,b_k$,对于任意一个区间只要包含任意的$b_i$就会对答案产生$1$的贡献.



对于每个质因子,我们先给答案加上$\frac{n(n+1)}{2}$(即全部区间的个数),然后考虑所有不包含任何$b_i$的极长区间,也就是相邻$b$之间的区间,以及$b_1$之前和$b_k$之后的区间:假设这样一段区间的长度为$len$,完全落在其中的$\frac{len(len + 1)}{2}$个区间都不包含质因子$p$,我们让答案减去这个值.




C++ code by me
#include <iostream>
#include <map>
#include <vector>
using namespace std;
using LL = long long;
// Problem C: sum over all subarrays of the number of distinct primes that
// divide at least one element. Each prime p contributes
// (all subarrays) - (subarrays lying entirely in a gap without p).
int main(){
    // Number of subarrays of a segment of length x.
    auto g = [](int x){
        return 1LL * x * (x + 1) / 2;
    };
    int T;
    cin >> T;
    while(T--){
        int n;
        cin >> n;
        // prime -> increasing 1-based positions of the elements it divides
        map<int, vector<int> > mp;
        for(int i = 1; i <= n; i++){
            int x;
            cin >> x;
            for(int j = 2; j * j <= x; j++){
                if (x % j == 0){
                    mp[j].push_back(i);
                    while(x % j == 0) x /= j;
                }
            }
            if (x > 1) mp[x].push_back(i);
        }
        LL ans = 0;
        for(auto &entry : mp){
            vector<int> &v = entry.second;
            ans += g(n);
            // Positions 0 and n + 1 are sentinels closing the first and last gap.
            v.push_back(n + 1);
            int last = 0;
            for(int u : v){
                ans -= g(u - last - 1);
                last = u;
            }
        }
        cout << ans << '\n';
    }
    return 0;
}

Java code by ChatGPT
import java.util.*;
/**
 * Problem C: for every prime p, add the number of subarrays containing at
 * least one multiple of p, computed as all subarrays minus those inside the
 * gaps between consecutive occurrences.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while(T-- > 0){
            int n = sc.nextInt();
            // prime -> increasing 1-based positions of the elements it divides
            Map<Integer, List<Integer>> mp = new HashMap<>();
            for(int i = 1; i <= n; i++){
                int x = sc.nextInt();
                for(int j = 2; j * j <= x; j++){
                    if (x % j == 0){
                        mp.computeIfAbsent(j, key -> new ArrayList<>()).add(i);
                        while(x % j == 0) x /= j;
                    }
                }
                if (x > 1){
                    mp.computeIfAbsent(x, key -> new ArrayList<>()).add(i);
                }
            }
            long ans = 0;
            for(List<Integer> v : mp.values()){
                ans += g(n);
                // Positions 0 and n + 1 are sentinels closing the first and last gap.
                v.add(n + 1);
                int last = 0;
                for(int u : v){
                    ans -= g(u - last - 1);
                    last = u;
                }
            }
            System.out.println(ans);
        }
    }

    /** Number of subarrays of a segment of length x. */
    static long g(int x){
        return 1L * x * (x + 1) / 2;
    }
}
Python code by me
import sys
from collections import defaultdict
input = sys.stdin.readline
# primes[v] lists the distinct prime factors of v (v <= 10**6). Sieve: i is
# prime exactly when no smaller prime has been recorded for it yet.
primes = [[] for i in range(1000001)]
for i in range(2, 1000001):
    if len(primes[i]) == 0:
        for j in range(i, 1000001, i):
            primes[j].append(i)
for _ in range(int(input())):
    n = int(input())
    a = tuple(map(int, input().split()))
    # prime -> increasing 0-based positions of the elements it divides
    mp = defaultdict(list)
    for i in range(n):
        for p in primes[a[i]]:
            mp[p].append(i)
    ans = 0
    for lst in mp.values():
        # All subarrays count p once, except those lying entirely inside a gap
        # between occurrences; positions -1 and n act as sentinels.
        ans += n * (n + 1) // 2
        lst.append(n)
        last = -1
        for pos in lst:
            ans -= (pos - last - 1) * (pos - last) // 2
            last = pos
    print(ans)


题解 格式化输出即可.
C++ code by me
#include <cstdio>
#include <iostream>
#include <string>
using namespace std;
using LL = long long;
// Problem D: read one word and print it twice inside the fixed greeting.
int main(){
	// std::string instead of char[40]: scanf("%s") into a 40-byte buffer
	// overflowed on words of 40 or more characters.
	string s;
	cin >> s;
	printf("guan zhu %s miao, guan zhu %s xie xie miao!", s.c_str(), s.c_str());
	return 0;
}
Java code by ChatGPT
import java.util.Scanner;
/** Problem D: read one word and print it twice inside the fixed greeting. */
public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String s = scanner.next();
        System.out.printf("guan zhu %s miao, guan zhu %s xie xie miao!", s, s);
    }
}
Python code by me
# Read the word once, then echo it twice inside the fixed greeting.
s = input()
greeting = f"guan zhu {s} miao, guan zhu {s} xie xie miao!"
print(greeting)

E. Construction Complete!



对于任意一个矩形,首先要满足矩形内所有点都是空地,其次要满足矩形内存在一个$(x, y)$点满足$dist_{x,y} \le d$.


对于第二个条件,我们令所有$dist_{x,y} \le d$的点$(x, y)$为$1$,其他位置为$0$,同样处理出二维前缀和,查询时只需判断矩形和是否大于$0$即可.


C++ code by me
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>
using namespace std;
using LL = long long;
// Problem E: count r x c placements that cover only '.' cells and contain at
// least one cell within grid distance dist of some 'x'. Distances come from a
// multi-source BFS; both conditions are checked in O(1) with 2D prefix sums.
int main(){
	const int dx[4] = {-1, 0, 1, 0};
	const int dy[4] = {0, -1, 0, 1};
	int T;
	cin >> T;
	while(T--){
		int n, m, r, c, dist;
		cin >> n >> m >> r >> c >> dist;
		vector<string> g(n + 1);
		for(int i = 1; i <= n; i++){
			g[i].resize(m + 1);
			for(int j = 1; j <= m; j++)
				cin >> g[i][j];
		}
		// n + m exceeds every reachable grid distance, so it marks "not reached".
		const int UNSEEN = n + m;
		vector<vector<int> > d(n + 1, vector<int>(m + 1, UNSEEN));
		queue<pair<int, int> > q;
		for(int i = 1; i <= n; i++)
			for(int j = 1; j <= m; j++)
				if (g[i][j] == 'x'){
					d[i][j] = 0;
					q.push({i, j});
				}
		while(!q.empty()){
			auto [x, y] = q.front(); q.pop();
			for(int u = 0; u < 4; u++){
				int nx = x + dx[u], ny = y + dy[u];
				if (nx < 1 || nx > n || ny < 1 || ny > m) continue;
				if (d[nx][ny] != UNSEEN) continue;
				d[nx][ny] = d[x][y] + 1;
				q.push({nx, ny});
			}
		}
		vector<vector<int> > s1(n + 1, vector<int>(m + 1, 0)), s2(n + 1, vector<int>(m + 1, 0));
		for(int i = 1; i <= n; i++)
			for(int j = 1; j <= m; j++){
				// Unreached cells (grid without any 'x') never count as close,
				// even when dist >= n + m.
				bool close = d[i][j] != UNSEEN && d[i][j] <= dist;
				s1[i][j] = close + s1[i - 1][j] + s1[i][j - 1] - s1[i - 1][j - 1];
				s2[i][j] = (g[i][j] == '.') + s2[i - 1][j] + s2[i][j - 1] - s2[i - 1][j - 1];
			}
		// Sum of prefix table s over rows x1..x2, columns y1..y2 (inclusive).
		auto query = [&](const vector<vector<int> > &s, int x1, int y1, int x2, int y2){
			return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1];
		};
		int ans = 0;
		for(int i = 1; i + r - 1 <= n; i++)
			for(int j = 1; j + c - 1 <= m; j++)
				if (query(s1, i, j, i + r - 1, j + c - 1) > 0 && query(s2, i, j, i + r - 1, j + c - 1) == r * c)
					ans++;
		cout << ans << '\n';
	}
	return 0;
}
Java code by ChatGPT
import java.util.*;
import java.io.*;
/**
 * Problem E: count r x c placements that cover only '.' cells and contain at
 * least one cell within grid distance dist of some 'x' (multi-source BFS plus
 * 2D prefix sums).
 */
public class Main {
    static int n, m, r, c, dist;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder out = new StringBuilder();
        int t = Integer.parseInt(br.readLine().trim());
        while (t-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            n = Integer.parseInt(st.nextToken());
            m = Integer.parseInt(st.nextToken());
            r = Integer.parseInt(st.nextToken());
            c = Integer.parseInt(st.nextToken());
            dist = Integer.parseInt(st.nextToken());
            // n + m exceeds every reachable grid distance, so it marks "not reached".
            int unseen = n + m;
            int[][] d = new int[n + 1][m + 1];
            int[][] s1 = new int[n + 1][m + 1];
            int[][] s2 = new int[n + 1][m + 1];
            for (int i = 0; i <= n; i++) {
                Arrays.fill(d[i], unseen);
            }
            char[][] g = new char[n + 1][];
            ArrayDeque<int[]> q = new ArrayDeque<>();
            for (int i = 1; i <= n; i++) {
                // Leading space makes g[i][j] 1-based.
                g[i] = (" " + br.readLine()).toCharArray();
                for (int j = 1; j <= m; j++) {
                    if (g[i][j] == 'x') {
                        q.add(new int[]{i, j});
                        d[i][j] = 0;
                    }
                }
            }
            int[][] directions = {{-1, 0}, {0, -1}, {1, 0}, {0, 1}};
            while (!q.isEmpty()) {
                int[] curr = q.poll();
                int x = curr[0], y = curr[1];
                for (int[] dir : directions) {
                    int nx = x + dir[0], ny = y + dir[1];
                    if (nx < 1 || nx > n || ny < 1 || ny > m || d[nx][ny] != unseen) continue;
                    d[nx][ny] = d[x][y] + 1;
                    q.add(new int[]{nx, ny});
                }
            }
            for (int i = 1; i <= n; i++) {
                for (int j = 1; j <= m; j++) {
                    // Unreached cells (grid without any 'x') never count as close.
                    boolean close = d[i][j] != unseen && d[i][j] <= dist;
                    s1[i][j] = (close ? 1 : 0) + s1[i - 1][j] + s1[i][j - 1] - s1[i - 1][j - 1];
                    s2[i][j] = (g[i][j] == '.' ? 1 : 0) + s2[i - 1][j] + s2[i][j - 1] - s2[i - 1][j - 1];
                }
            }
            int ans = 0;
            for (int i = 1; i <= n - r + 1; i++) {
                for (int j = 1; j <= m - c + 1; j++) {
                    if (query(s1, i, j, i + r - 1, j + c - 1) > 0 && query(s2, i, j, i + r - 1, j + c - 1) == r * c) {
                        ans += 1;
                    }
                }
            }
            out.append(ans).append('\n');
        }
        System.out.print(out);
    }

    /** Sum of prefix table s over rows x1..x2 and columns y1..y2 (1-based, inclusive). */
    static int query(int[][] s, int x1, int y1, int x2, int y2) {
        return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1];
    }
}
Python code by me
import sys
from collections import deque
input = sys.stdin.readline


def query(s, x1, y1, x2, y2):
    """Sum of the 1-based prefix table s over rows x1..x2 and columns y1..y2."""
    return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1]


for _ in range(int(input())):
    n, m, r, c, dist = map(int, input().split())
    # n + m exceeds every reachable grid distance, so it marks "not reached".
    unseen = n + m
    d = [[unseen for j in range(m + 1)] for i in range(n + 1)]
    s1 = [[0 for j in range(m + 1)] for i in range(n + 1)]
    s2 = [[0 for j in range(m + 1)] for i in range(n + 1)]
    q = deque()
    g = [[]]
    for i in range(1, n + 1):
        # Leading space makes g[i][j] 1-based; the trailing newline is never indexed.
        g.append(" " + input())
        for j in range(1, m + 1):
            if g[i][j] == 'x':
                q.append((i, j))
                d[i][j] = 0
    while q:
        x, y = q.popleft()
        for dx, dy in ((-1, 0), (0, -1), (1, 0), (0, 1)):
            nx, ny = x + dx, y + dy
            if nx < 1 or nx > n or ny < 1 or ny > m or d[nx][ny] != unseen: continue
            d[nx][ny] = d[x][y] + 1
            q.append((nx, ny))
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Unreached cells (grid without any 'x') never count as close.
            close = d[i][j] != unseen and d[i][j] <= dist
            s1[i][j] = close + s1[i - 1][j] + s1[i][j - 1] - s1[i - 1][j - 1]
            s2[i][j] = (g[i][j] == '.') + s2[i - 1][j] + s2[i][j - 1] - s2[i - 1][j - 1]
    ans = 0
    for i in range(1, n - r + 2):
        for j in range(1, m - c + 2):
            if query(s1, i, j, i + r - 1, j + c - 1) > 0 and query(s2, i, j, i + r - 1, j + c - 1) == r * c:
                ans += 1
    print(ans)




我们只需要证明对于任意的$a < b < c$,都满足最小值一定是$a\oplus b$或者$b \oplus c$,而不可能是$a \oplus c$.



对于$a\oplus c$,找到其二进制表示中最高的为$1$的位为第$x$位,则$a$和$c$高于$x$的位一定相同,第$x$位一定不同,因为$a < c$,一定是$a$第$x$位为$0$,$c$第$x$位为$1$.

因为$a < b < c$,高于$x$的位$b$一定也和$a,c$都相同,$b$的第$x$位可能为$0$或者为$1$.

所以$a\oplus b$和$b \oplus c$中一定有一个数第$x$位为$0$,而$a\oplus c$第$x$位为$1$,即一定有一个数是小于$a \oplus c$的,由此结论得证.


每次加入一个数$x$时,找到有序集合中$x$的前驱$pre$和后继$nxt$(若存在),在维护异或值的数据结构中删除$nxt\oplus pre$(两者都存在时),加入$x \oplus pre$和$x \oplus nxt$,最后把$x$插入有序集合.

删除一个数$x$时类似,删除$x$后同样找到$x$的前驱$pre$和后继$nxt$,加入$nxt\oplus pre$,删除$x \oplus pre$和$x \oplus nxt$即可.









  1. 当前子树内能凑出的最小异或对的答案
  2. 子树内的元素个数
  3. 如果子树内只有一个元素,记录这个元素的值.









C++ code with multiset by me
#include <iostream>
#include <iterator>
#include <set>
#include <string>
using namespace std;
using LL = long long;
// Problem F (multiset version): the minimum XOR over all pairs is reached by
// two values that are adjacent in sorted order. s holds the values and v
// holds the XOR of every adjacent pair in s.
int main(){
    multiset<int> s, v;
    int n;
    cin >> n;
    while(n--){
        string op;
        cin >> op;
        if (op[0] == 'A'){
            int x;
            cin >> x;
            // x is inserted between pre = *prev(it) and nxt = *it.
            auto it = s.lower_bound(x);
            if (it != s.end())
                v.insert(*it ^ x);
            if (it != s.begin()){
                v.insert(*prev(it) ^ x);
            }
            if (it != s.begin() && it != s.end()){
                // pre and nxt are no longer adjacent.
                v.erase(v.find(*it ^ *prev(it)));
            }
            s.insert(x);
        }
        else if (op[0] == 'D'){
            int x;
            cin >> x;
            // Remove one copy of x first (assumes x is present), then
            // reconnect its former neighbours.
            s.erase(s.find(x));
            auto it = s.lower_bound(x);
            if (it != s.end())
                v.erase(v.find(*it ^ x));
            if (it != s.begin()){
                v.erase(v.find(*prev(it) ^ x));
            }
            if (it != s.begin() && it != s.end()){
                v.insert(*it ^ *prev(it));
            }
        }
        // NOTE(review): assumes at least two values are stored at query time.
        else cout << *v.begin() << '\n';
    }
    return 0;
}
C++ code with 01Trie by me
#include <algorithm>
#include <iostream>
#include <string>
using namespace std;
using LL = long long;
// Node pool is sized maxn * 32 (presumably up to maxn insertions of 30-bit
// values -- confirm against the problem constraints).
const int maxn = 2e5 + 5, INF = 1 << 30;
// Summary stored for every 01-trie node:
//   dp  - minimum XOR of two values stored in the subtree (INF if fewer than two),
//   cnt - number of values stored in the subtree,
//   val - the stored value when cnt == 1 (-1 for an empty node).
struct Info {
    int dp, cnt, val;
    Info() : dp(INF), cnt(0), val(-1) {}
    explicit Info(int x) : dp(INF), cnt(1), val(x) {}
    Info(int dp, int cnt, int val) : dp(dp), cnt(cnt), val(val) {}
}tr[maxn * 32];
// son[u][bit]: child of node u for the given bit; 0 means "no child", and
// tr[0] stays an empty Info.
int son[maxn * 32][2];
int root, idx;
// Merge the summaries of a node's two children. Values in different children
// differ in this node's bit, so any cross pair has a larger XOR than any pair
// inside one child; a cross pair only matters when each child holds exactly
// one value.
Info operator+(const Info &a, const Info &b){
    Info ret = Info();
    ret.cnt = a.cnt + b.cnt;
    ret.dp = min(a.dp, b.dp);
    if (a.cnt == 1 && b.cnt == 1) ret.dp = a.val ^ b.val;
    if (ret.cnt == 1) ret.val = a.cnt ? a.val : b.val;
    return ret;
}
// Insert one copy of x under node u, which branches on bit `dep`; dep == -1 is
// the leaf for x. Allocates u if it is 0 and rebuilds summaries on the way up.
void insert(int &u, int x, int dep){
    if (!u) u = ++idx;
    if (dep == -1){
        // Leaf: cnt counts copies of x; two or more copies give XOR 0.
        tr[u].cnt++;
        if (tr[u].cnt == 1) tr[u] = Info(x);
        else if (tr[u].cnt == 2) tr[u].dp = 0;
        return;
    }
    int bit = (x >> dep & 1);
    insert(son[u][bit], x, dep - 1);
    tr[u] = tr[son[u][0]] + tr[son[u][1]];
}
// Remove one copy of x under node u (bit `dep`; dep == -1 is x's leaf) and
// rebuild summaries on the way up. A value that is not stored is ignored
// instead of corrupting the shared empty node 0 or driving cnt negative.
void del(int &u, int x, int dep){
    if (!u) return;
    if (dep == -1){
        if (tr[u].cnt == 0) return;
        tr[u].cnt--;
        if (tr[u].cnt == 0) tr[u] = Info();
        else if (tr[u].cnt == 1) tr[u] = Info(x);
        return;
    }
    int bit = (x >> dep & 1);
    del(son[u][bit], x, dep - 1);
    tr[u] = tr[son[u][0]] + tr[son[u][1]];
}
int main(){
    int n;
    cin >> n;
    int root = ++idx;
        string op;
        cin >> op;
        if (op[0] == 'A'){
            int x;
            cin >> x;
            insert(root, x, 29);
        else if (op[0] == 'D'){
            int x;
            cin >> x;
            del(root, x, 29);
        else cout << tr[root].dp << '\n';
Java code by ChatGPT
import java.util.*;
/**
 * Problem F (TreeMap version): maintain a multiset under "A x" / "D x" and
 * answer "minimum XOR of two stored values". The minimum is always reached by
 * two values adjacent in sorted order, so {@code v} counts the XOR of every
 * adjacent pair of {@code s}.
 */
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        TreeMap<Integer, Integer> s = new TreeMap<>();  // value -> multiplicity
        TreeMap<Integer, Integer> v = new TreeMap<>();  // adjacent-pair XOR -> multiplicity
        StringBuilder out = new StringBuilder();
        while (n-- > 0) {
            String op = in.next();
            if (op.charAt(0) == 'A') {
                int x = in.nextInt();
                // An existing copy of x acts as the successor (its XOR with x is 0).
                Integer nxt = s.ceilingKey(x);
                Integer pre = s.lowerKey(x);
                if (nxt != null) increment(v, nxt ^ x);
                if (pre != null) increment(v, pre ^ x);
                if (nxt != null && pre != null) decrement(v, nxt ^ pre);
                increment(s, x);
            } else if (op.charAt(0) == 'D') {
                int x = in.nextInt();
                if (!s.containsKey(x)) continue;
                // Remove one copy first, then reconnect its former neighbours.
                decrement(s, x);
                Integer nxt = s.ceilingKey(x);
                Integer pre = s.lowerKey(x);
                if (nxt != null) decrement(v, nxt ^ x);
                if (pre != null) decrement(v, pre ^ x);
                if (nxt != null && pre != null) increment(v, nxt ^ pre);
            } else {
                out.append(v.isEmpty() ? 0 : v.firstKey()).append('\n');
            }
        }
        System.out.print(out);
    }

    /** Add one occurrence of key to the counting map m. */
    private static void increment(TreeMap<Integer, Integer> m, int key) {
        m.merge(key, 1, Integer::sum);
    }

    /** Remove one occurrence of key from the counting map m (no-op if absent). */
    private static void decrement(TreeMap<Integer, Integer> m, int key) {
        Integer count = m.get(key);
        if (count == null) return;
        if (count == 1) m.remove(key);
        else m.put(key, count - 1);
    }
}
Python code by me
import sys
input = sys.stdin.readline
import random


class TreapMultiSet(object):
    """Sorted multiset backed by the module-level treap node arrays.

    An instance only stores its root index and element count; the nodes live
    in left_child / right_child / treap_keys / treap_prior and are manipulated
    by the treap_* helper functions.
    """
    root = 0
    size = 0

    def __init__(self, data=None):
        if data:
            data = sorted(data)
            self.root = treap_builder(data)
            self.size = len(data)

    def add(self, key):
        self.root = treap_insert(self.root, key)
        self.size += 1

    def remove(self, key):
        # treap_erase raises KeyError for a missing key before size changes.
        self.root = treap_erase(self.root, key)
        self.size -= 1

    def discard(self, key):
        # Deliberately best-effort: removing a missing key is a no-op.
        try:
            self.remove(key)
        except KeyError:
            pass

    def ceiling(self, key):
        """Smallest stored key >= key, or None."""
        x = treap_ceiling(self.root, key)
        return treap_keys[x] if x else None

    def higher(self, key):
        """Smallest stored key > key, or None."""
        x = treap_higher(self.root, key)
        return treap_keys[x] if x else None

    def floor(self, key):
        """Largest stored key <= key, or None."""
        x = treap_floor(self.root, key)
        return treap_keys[x] if x else None

    def lower(self, key):
        """Largest stored key < key, or None."""
        x = treap_lower(self.root, key)
        return treap_keys[x] if x else None

    def max(self):
        return treap_keys[treap_max(self.root)]

    def min(self):
        return treap_keys[treap_min(self.root)]

    def __len__(self):
        return self.size

    def __nonzero__(self):
        return bool(self.root)

    __bool__ = __nonzero__

    def __contains__(self, key):
        return self.floor(key) == key

    def __repr__(self):
        return "TreapMultiSet({})".format(list(self))

    def __iter__(self):
        """Iterate the keys in sorted order without recursion."""
        if not self.root:
            return iter([])
        out = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            if node > 0:
                # Push right subtree, a visit marker (~node < 0), then left
                # subtree, so the left subtree is emitted first.
                if right_child[node]:
                    stack.append(right_child[node])
                stack.append(~node)
                if left_child[node]:
                    stack.append(left_child[node])
            else:
                out.append(treap_keys[~node])
        return iter(out)
class TreapSet(TreapMultiSet):
    """TreapMultiSet variant whose add() goes through treap_insert_unique.

    The size only grows when treap_insert_unique reports that the key was not
    already present.
    """

    def add(self, key):
        new_root, found = treap_insert_unique(self.root, key)
        self.root = new_root
        if found:
            return
        self.size += 1

    def __repr__(self):
        return "TreapSet({})".format(list(self))
class TreapHashSet(TreapMultiSet):
    """Ordered set: a Python set answers membership, the treap keeps order."""

    def __init__(self, data=None):
        if data:
            self.keys = set(data)
            super(TreapHashSet, self).__init__(self.keys)
        else:
            self.keys = set()

    def add(self, key):
        if key not in self.keys:
            self.keys.add(key)
            super(TreapHashSet, self).add(key)

    def remove(self, key):
        # set.remove raises KeyError for a missing key, matching the base class.
        self.keys.remove(key)
        super(TreapHashSet, self).remove(key)

    def discard(self, key):
        if key in self.keys:
            self.remove(key)

    def __contains__(self, key):
        return key in self.keys

    def __repr__(self):
        return "TreapHashSet({})".format(list(self))
class TreapHashMap(TreapMultiSet):
    """Ordered map: values live in a dict, keys are also kept in the treap."""

    def __init__(self, data=None):
        if data:
            self.map = dict(data)
            super(TreapHashMap, self).__init__(self.map.keys())
        else:
            self.map = {}

    def __setitem__(self, key, value):
        if key not in self.map:
            super(TreapHashMap, self).add(key)
        self.map[key] = value

    def __getitem__(self, key):
        return self.map[key]

    def add(self, key):
        raise TypeError("add on TreapHashMap")

    def get(self, key, default=None):
        # dict.get takes its default positionally; `default=` raises TypeError.
        return self.map.get(key, default)

    def remove(self, key):
        # dict.pop raises KeyError for a missing key, matching the base class.
        self.map.pop(key)
        super(TreapHashMap, self).remove(key)

    def discard(self, key):
        if key in self.map:
            self.remove(key)

    def __contains__(self, key):
        return key in self.map

    def __repr__(self):
        return "TreapHashMap({})".format(list(self))
# Node storage shared by every treap: node i has key treap_keys[i], random
# priority treap_prior[i] and children left_child[i] / right_child[i].
# Index 0 is the "no node" sentinel; treap_split and treap_merge also borrow
# its child slots as temporary list heads.
left_child, right_child = [0], [0]
treap_keys, treap_prior = [0], [0.0]
def treap_builder(sorted_data):
    """Build a treap in O(n) time using sorted data"""
    def build(begin, end):
        # Root sorted_data[begin:end] at its middle element, build both halves,
        # then sift the root's priority down so the heap order holds. Only
        # priorities are swapped, so the key (BST) order is untouched.
        if begin == end:
            return 0
        mid = (begin + end) // 2
        root = treap_create_node(sorted_data[mid])
        left_child[root] = build(begin, mid)
        right_child[root] = build(mid + 1, end)

        # sift down the priorities
        ind = root
        while True:
            lc = left_child[ind]
            rc = right_child[ind]

            if lc and treap_prior[lc] > treap_prior[ind]:
                if rc and treap_prior[rc] > treap_prior[lc]:
                    treap_prior[ind], treap_prior[rc] = treap_prior[rc], treap_prior[ind]
                    ind = rc
                else:
                    treap_prior[ind], treap_prior[lc] = treap_prior[lc], treap_prior[ind]
                    ind = lc
            elif rc and treap_prior[rc] > treap_prior[ind]:
                treap_prior[ind], treap_prior[rc] = treap_prior[rc], treap_prior[ind]
                ind = rc
            else:
                break
        return root

    return build(0, len(sorted_data))
def treap_create_node(key):
    """Append a childless node with the given key and a random priority; return its index."""
    treap_keys.append(key)
    treap_prior.append(random.random())
    left_child.append(0)
    right_child.append(0)
    return len(treap_keys) - 1
def treap_split(root, key):
    """Split the treap at root into (left, right): keys <= key go left, keys > key go right.

    The sentinel node 0 serves as the head of both output chains
    (right_child[0] for the left part, left_child[0] for the right part) and
    is reset before returning.
    """
    left_pos = right_pos = 0
    while root:
        if key < treap_keys[root]:
            # root and its right subtree belong to the right part.
            left_child[right_pos] = right_pos = root
            root = left_child[root]
        else:
            # root and its left subtree belong to the left part.
            right_child[left_pos] = left_pos = root
            root = right_child[root]
    left, right = right_child[0], left_child[0]
    right_child[left_pos] = left_child[right_pos] = right_child[0] = left_child[0] = 0
    return left, right
def treap_merge(left, right):
    """Merge two treaps where every key of left precedes every key of right.

    The node with the higher priority becomes the parent at each step. The
    result is threaded through left_child[0], which is cleared before the new
    root is returned.
    """
    where, pos = left_child, 0
    while left and right:
        if treap_prior[left] > treap_prior[right]:
            where[pos] = pos = left
            where = right_child
            left = right_child[left]
        else:
            where[pos] = pos = right
            where = left_child
            right = left_child[right]
    where[pos] = left or right
    node = left_child[0]
    left_child[0] = 0
    return node
def treap_insert(root, key):
    """Insert key (duplicates allowed) into the treap at root; return the new root index."""
    if root:
        smaller, larger = treap_split(root, key)
        fresh = treap_create_node(key)
        return treap_merge(treap_merge(smaller, fresh), larger)
    return treap_create_node(key)
def treap_insert_unique(root, key):
    """Insert key unless it is already present.

    Returns (new_root, duplicate) where duplicate is True when key was found
    and nothing was inserted.
    """
    if not root:
        return treap_create_node(key), False
    left, right = treap_split(root, key)
    # left holds every key <= key; an existing copy of key must be its maximum.
    # (The root of left is not necessarily the maximum, so checking it alone
    # misses duplicates.)
    if left and treap_keys[treap_max(left)] == key:
        return treap_merge(left, right), True
    return treap_merge(treap_merge(left, treap_create_node(key)), right), False
def treap_erase(root, key):
    """Remove one node with the given key and return the (possibly new) root.

    Raises KeyError if key is not stored.
    """
    if not root:
        raise KeyError(key)
    if treap_keys[root] == key:
        return treap_merge(left_child[root], right_child[root])
    node = root
    while root and treap_keys[root] != key:
        parent = root
        root = left_child[root] if key < treap_keys[root] else right_child[root]
    if not root:
        raise KeyError(key)
    # Replace the found node by the merge of its two subtrees.
    if root == left_child[parent]:
        left_child[parent] = treap_merge(left_child[root], right_child[root])
    else:
        right_child[parent] = treap_merge(left_child[root], right_child[root])
    return node
def treap_ceiling(root, key):
    """Index of a node with the smallest key >= key, or 0 if there is none."""
    while root and treap_keys[root] < key:
        root = right_child[root]
    if not root:
        return 0
    min_node = root
    min_key = treap_keys[root]
    while root:
        if treap_keys[root] < key:
            root = right_child[root]
        else:
            if treap_keys[root] < min_key:
                min_key = treap_keys[root]
                min_node = root
            root = left_child[root]
    return min_node
def treap_higher(root, key):
    """Index of a node with the smallest key > key, or 0 if there is none."""
    while root and treap_keys[root] <= key:
        root = right_child[root]
    if not root:
        return 0
    min_node = root
    min_key = treap_keys[root]
    while root:
        if treap_keys[root] <= key:
            root = right_child[root]
        else:
            if treap_keys[root] < min_key:
                min_key = treap_keys[root]
                min_node = root
            root = left_child[root]
    return min_node
def treap_floor(root, key):
    """Index of a node with the largest key <= key, or 0 if there is none."""
    while root and treap_keys[root] > key:
        root = left_child[root]
    if not root:
        return 0
    max_node = root
    max_key = treap_keys[root]
    while root:
        if treap_keys[root] > key:
            root = left_child[root]
        else:
            if treap_keys[root] > max_key:
                max_key = treap_keys[root]
                max_node = root
            root = right_child[root]
    return max_node
def treap_lower(root, key):
    """Index of a node with the largest key < key, or 0 if there is none."""
    while root and treap_keys[root] >= key:
        root = left_child[root]
    if not root:
        return 0
    max_node = root
    max_key = treap_keys[root]
    while root:
        if treap_keys[root] >= key:
            root = left_child[root]
        else:
            if treap_keys[root] > max_key:
                max_key = treap_keys[root]
                max_node = root
            root = right_child[root]
    return max_node
def treap_min(root):
    """Index of the leftmost (smallest-key) node; ValueError on an empty treap."""
    if not root:
        raise ValueError("min on empty treap")
    node = root
    while True:
        child = left_child[node]
        if not child:
            return node
        node = child
def treap_max(root):
    """Index of the rightmost (largest-key) node; ValueError on an empty treap."""
    if not root:
        raise ValueError("max on empty treap")
    node = root
    while True:
        child = right_child[node]
        if not child:
            return node
        node = child
# s: the current multiset of values; v: XORs of every pair of values adjacent
# in sorted order (its minimum answers a query).
s, v = TreapMultiSet(), TreapMultiSet()
out = []
for _ in range(int(input())):
    op = input()
    if op[0] == 'Q':
        # NOTE(review): assumes at least two values are stored at query time.
        out.append(v.min())
    else:
        # A three-letter command and a space (presumably "ADD x" / "DEL x").
        x = int(op[4:].strip())
        if op[0] == 'A':
            # x goes between pre (<= x) and nxt (> x), which stop being adjacent.
            nxt = s.higher(x)
            pre = s.floor(x)
            if nxt is not None and pre is not None:
                v.remove(nxt ^ pre)
            if pre is not None:
                v.add(x ^ pre)
            if nxt is not None:
                v.add(nxt ^ x)
            s.add(x)
        else:
            # Remove one copy of x first, then reconnect its former neighbours.
            s.remove(x)
            nxt = s.higher(x)
            pre = s.floor(x)
            if nxt is not None and pre is not None:
                v.add(nxt ^ pre)
            if pre is not None:
                v.remove(x ^ pre)
            if nxt is not None:
                v.remove(nxt ^ x)
sys.stdout.write("".join(str(val) + "\n" for val in out))



对于$1 \sim n$我们都用一个$8$位的二进制数表示其状态,第$i$位取$0/1$表示第$i$张卡片上是否有这个数字,然后我们考虑这些数之间需要具有怎样的性质.

对于任意数字$x$,假设其没有出现在第$c_1, c_2, c_3,\cdots$张卡片上,那么除了$x$以外的所有数字都至少在第$c_1,c_2,c_3,\cdots$张卡片中的某一张上出现过,即它们的二进制表示在第$c_1,c_2,c_3,\cdots$位中至少有一位为$1$.所以除了$x$以外的任何数字,其二进制表示都一定不是$x$的二进制表示的子集.


一种简单的构造方法是在所有$8$位的二进制数字里选取二进制表示中$1$的数量为$4$的数来构造,这些数显然不会有一个数是另一个数的子集.并且这样的数有$C_8^4 = 70$个,而$n \le 60$.

C++ code by me
#include <bitset>
#include <iostream>
#include <vector>
using namespace std;
using LL = long long;
// Problem G: number cnt (1-based) gets the cnt-th 8-bit mask with exactly four
// set bits and is written on card j for every set bit j. No such mask is a
// subset of another, and there are C(8,4) = 70 of them (n <= 60).
int main(){
    int T;
    cin >> T;
    while(T--){
        const int N = 8;
        int n;
        cin >> n;
        int cnt = 0;
        vector<vector<int> > ans(N);
        for(int i = 0; i < 1 << N; i++){
            if (bitset<N>(i).count() != 4) continue;
            if (++cnt > n) break;
            for(int j = 0; j < N; j++)
                if (i >> j & 1) ans[j].push_back(cnt);
        }
        cout << N << '\n';
        for(int i = 0; i < N; i++){
            cout << ans[i].size();
            for(int x : ans[i]) cout << ' ' << x;
            cout << '\n';
        }
    }
    return 0;
}
Java code by ChatGPT
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
/**
 * Problem G: number cnt gets the cnt-th 8-bit mask with exactly four set bits
 * and is written on card j for every set bit j of that mask.
 */
public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int T = scanner.nextInt();
        StringBuilder out = new StringBuilder();
        while (T-- > 0) {
            final int N = 8;
            int n = scanner.nextInt();
            int cnt = 0;
            List<List<Integer>> ans = new ArrayList<>();
            for (int i = 0; i < N; i++) {
                ans.add(new ArrayList<>());
            }
            for (int i = 0; i < (1 << N); i++) {
                if (Integer.bitCount(i) != 4) continue;
                if (++cnt > n) break;
                for (int j = 0; j < N; j++) {
                    if ((i >> j & 1) == 1) {
                        ans.get(j).add(cnt);
                    }
                }
            }
            out.append(N).append('\n');
            for (int i = 0; i < N; i++) {
                out.append(ans.get(i).size());
                for (int x : ans.get(i)) {
                    out.append(' ').append(x);
                }
                out.append('\n');
            }
        }
        System.out.print(out);
    }
}
Python code by me
# All 8-bit masks with exactly four set bits, in increasing order
# (C(8,4) = 70 of them, enough for n <= 60).
cand = []
for i in range(1 << 8):
    if bin(i).count("1") == 4:
        cand.append(i)
for _ in range(int(input())):
    n = int(input())
    # ans[j] lists the numbers written on card j.
    ans = [[] for i in range(8)]
    for i in range(n):
        for j in range(8):
            if cand[i] >> j & 1:
                ans[j].append(i + 1)
    print(8)
    for i in range(8):
        print(len(ans[i]), *ans[i])



我们用$dp_{i, a, b, c}$表示击退对手$i$只精灵,攻击等级为$a$,防御等级为$b$,速度等级为$c$时能剩余的最大体力,如果无法达到该状态则为$-1$.








C++ code by me
#include <algorithm>
#include <iostream>
using namespace std;
using LL = long long;
const int maxn = 1e4 + 5;
// f[i][a][b][c]: most HP left after beating the first i opponents with
// attack/defense/speed levels a/b/c, or -1 if the state is unreachable.
// The level tables have 4 slots, so this assumes k <= 3.
int f[maxn][4][4][4];
int main(){
    int T;
    cin >> T;
    while(T--){
        int n, k;
        cin >> n >> k;
        int H, A[4], D[4], S[4];
        cin >> H;
        for(int i = 0; i <= k; i++) cin >> A[i];
        for(int i = 0; i <= k; i++) cin >> D[i];
        for(int i = 0; i <= k; i++) cin >> S[i];
        for(int i = 0; i <= n; i++)
            for(int la = 0; la <= k; la++)
                for(int lb = 0; lb <= k; lb++)
                    for(int lc = 0; lc <= k; lc++)
                        f[i][la][lb][lc] = -1;
        f[0][0][0][0] = H;
        for(int i = 0; i < n; i++){
            int h, a, d, s;
            cin >> h >> a >> d >> s;
            // Relax every transition out of state (l1, l2, l3) against opponent i.
            auto update = [&](int l1, int l2, int l3){
                int cur = f[i][l1][l2][l3];
                // HP left after spending c2 turns raising defense, then c1 + c3
                // turns raising attack/speed, then attacking until the
                // opponent's h HP is gone; -1 if we would not survive.
                auto get = [&](int c1, int c2, int c3){
                    int remain = cur;
                    for(int t = 1; t <= c2; t++){
                        // If we are at least as fast, this turn's defense raise
                        // already applies to the opponent's hit.
                        int def = (S[l3] >= s) ? D[l2 + t] : D[l2 + t - 1];
                        remain -= max(0, a - def);
                        if (remain <= 0) return -1;
                    }
                    if (A[l1 + c1] <= d) return -1;       // we could never damage them
                    int dec1 = max(0, a - D[l2 + c2]);    // damage taken per opponent hit
                    int dec2 = A[l1 + c1] - d;            // damage dealt per own hit
                    if (1LL * (c1 + c3) * dec1 >= remain) return -1;
                    if (dec1 == 0) return remain;
                    remain -= (c1 + c3) * dec1;
                    // Hits needed; when we are at least as fast our last hit
                    // lands first, so we take one hit fewer.
                    int cnt = (h + dec2 - 1) / dec2 - (S[l3 + c3] >= s);
                    if (1LL * cnt * dec1 >= remain) return -1;
                    remain -= cnt * dec1;
                    return remain;
                };
                for(int c1 = 0; c1 <= k - l1; c1++)
                    for(int c2 = 0; c2 <= k - l2; c2++)
                        for(int c3 = 0; c3 <= k - l3; c3++)
                            f[i + 1][l1 + c1][l2 + c2][l3 + c3] = max(f[i + 1][l1 + c1][l2 + c2][l3 + c3], get(c1, c2, c3));
            };
            // Loop indices are not named a/b/c so they do not shadow the
            // opponent's attack `a` read inside get().
            for(int la = 0; la <= k; la++)
                for(int lb = 0; lb <= k; lb++)
                    for(int lc = 0; lc <= k; lc++)
                        if (f[i][la][lb][lc] > 0)
                            update(la, lb, lc);
        }
        int ans = -1;
        for(int la = 0; la <= k; la++)
            for(int lb = 0; lb <= k; lb++)
                for(int lc = 0; lc <= k; lc++)
                    ans = max(ans, f[n][la][lb][lc]);
        cout << ans << '\n';
    }
    return 0;
}
Java code by ChatGPT
import java.util.*;
public class Main {
    /*
     * Level DP over the opponents.
     * f[i][l1][l2][l3] is the largest remaining health (starting from H) after the first i
     * opponents with the three stat levels raised to l1 / l2 / l3 (tables A, D, S are indexed
     * by level 0..k); -1 marks an unreachable state. Prints the best final value, or -1.
     * NOTE(review): A/D/S presumably are attack/defense/speed per level — confirm with statement.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = Integer.parseInt(sc.nextLine().trim());
        while (T-- > 0) {
            int n = sc.nextInt();
            int k = sc.nextInt();
            int H = sc.nextInt();
            int[] A = new int[k + 1];
            int[] D = new int[k + 1];
            int[] S = new int[k + 1];
            for (int i = 0; i <= k; i++) {
                A[i] = sc.nextInt();
            }
            for (int i = 0; i <= k; i++) {
                D[i] = sc.nextInt();
            }
            for (int i = 0; i <= k; i++) {
                S[i] = sc.nextInt();
            }
            int[][][][] f = new int[n + 1][k + 1][k + 1][k + 1];
            for (int i = 0; i <= n; i++) {
                for (int l1 = 0; l1 <= k; l1++) {
                    for (int l2 = 0; l2 <= k; l2++) {
                        Arrays.fill(f[i][l1][l2], -1);
                    }
                }
            }
            f[0][0][0][0] = H;
            for (int i = 0; i < n; i++) {
                int h = sc.nextInt();
                int a = sc.nextInt();
                int d = sc.nextInt();
                int s = sc.nextInt();
                for (int l1 = 0; l1 <= k; l1++) {
                    for (int l2 = 0; l2 <= k; l2++) {
                        for (int l3 = 0; l3 <= k; l3++) {
                            if (f[i][l1][l2][l3] <= 0) continue; // unreachable or dead state
                            // Try raising the three levels by c1 / c2 / c3 before this opponent.
                            for (int c1 = 0; c1 <= k - l1; c1++) {
                                for (int c2 = 0; c2 <= k - l2; c2++) {
                                    for (int c3 = 0; c3 <= k - l3; c3++) {
                                        int remain = f[i][l1][l2][l3];
                                        // Each of the c2 D-upgrades costs one loss; it uses the new
                                        // level D[l2 + t] if S[l3] >= s, otherwise the old one.
                                        for (int t = 1; t <= c2; t++) {
                                            if (S[l3] >= s) {
                                                remain -= Math.max(0, a - D[l2 + t]);
                                            } else {
                                                remain -= Math.max(0, a - D[l2 + t - 1]);
                                            }
                                            if (remain <= 0) break;
                                        }
                                        if (remain <= 0 || A[l1 + c1] <= d) continue;
                                        int dec1 = Math.max(0, a - D[l2 + c2]); // loss per step
                                        int dec2 = A[l1 + c1] - d;             // progress on h per step
                                        // Long products guard against int overflow before subtracting.
                                        if (1L * (c1 + c3) * dec1 >= remain) continue;
                                        remain -= (c1 + c3) * dec1;
                                        // ceil(h / dec2) steps, one fewer when S[l3 + c3] >= s.
                                        int cnt = (h + dec2 - 1) / dec2 - (S[l3 + c3] >= s ? 1 : 0);
                                        if (1L * cnt * dec1 >= remain) continue;
                                        remain -= cnt * dec1;
                                        f[i + 1][l1 + c1][l2 + c2][l3 + c3] = Math.max(f[i + 1][l1 + c1][l2 + c2][l3 + c3], remain);
                                    }
                                }
                            }
                        }
                    }
                }
            }
            int ans = -1;
            for (int a = 0; a <= k; a++) {
                for (int b = 0; b <= k; b++) {
                    for (int c = 0; c <= k; c++) {
                        ans = Math.max(ans, f[n][a][b][c]);
                    }
                }
            }
            System.out.println(ans);
        }
    }
}
Python code by me
import sys
input = sys.stdin.readline
# Level DP over the opponents: f[i][l1][l2][l3] is the largest remaining health
# (starting from H) after the first i opponents with the three stat levels at
# l1 / l2 / l3 (A, D, S are indexed by level 0..k); -1 marks an unreachable state.
# Prints the best final value, or -1 if no state survives.
for _ in range(int(input())):
    n, k = map(int, input().split())
    H = int(input())
    A = tuple(map(int, input().split()))
    D = tuple(map(int, input().split()))
    S = tuple(map(int, input().split()))
    f = [[[[-1 for a in range(k + 1)] for b in range(k + 1)] for c in range(k + 1)] for d in range(n + 1)]
    f[0][0][0][0] = H
    for i in range(n):
        h, a, d, s = map(int, input().split())
        for l1 in range(k + 1):
            for l2 in range(k + 1):
                for l3 in range(k + 1):
                    if f[i][l1][l2][l3] <= 0: continue
                    # Try raising the three levels by c1 / c2 / c3 before this opponent.
                    for c1 in range(k - l1 + 1):
                        for c2 in range(k - l2 + 1):
                            for c3 in range(k - l3 + 1):
                                remain = f[i][l1][l2][l3]
                                # Each of the c2 D-upgrades costs one loss; it uses the new
                                # level D[l2 + t] if S[l3] >= s, otherwise the old one.
                                for t in range(1, c2 + 1):
                                    if S[l3] >= s:
                                        remain -= max(0, a - D[l2 + t])
                                    else:
                                        remain -= max(0, a - D[l2 + t - 1])
                                    if remain <= 0:
                                        break
                                if remain <= 0 or A[l1 + c1] <= d: continue
                                # dec1: loss per step; dec2: progress on h per step.
                                dec1, dec2 = max(0, a - D[l2 + c2]), A[l1 + c1] - d
                                if (c1 + c3) * dec1 >= remain: continue
                                remain -= (c1 + c3) * dec1
                                # ceil(h / dec2) steps, one fewer when S[l3 + c3] >= s.
                                cnt = (h + dec2 - 1) // dec2 - (S[l3 + c3] >= s)
                                if cnt * dec1 >= remain: continue
                                remain -= cnt * dec1
                                f[i + 1][l1 + c1][l2 + c2][l3 + c3] = max(f[i + 1][l1 + c1][l2 + c2][l3 + c3], remain)
    print(max(f[n][a][b][c] for a in range(k + 1) for b in range(k + 1) for c in range(k + 1)))



$f(x) = x^k$与$g(x) = \gcd(x, k)$都是积性函数,可以在线性筛的过程中计算.



$f(x) = x^k$是完全积性函数,因此有$f(x) = f(p) \cdot f(\frac{x}{p})$.

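对于$g(x) = \gcd(x, k)$,设$p$是$x$的一个质因子,只需判断是否有$p \mid \frac{k}{g(\frac{x}{p})}$:如果是则$g(x) = p \cdot g(\frac{x}{p})$,否则$g(x) = g(\frac{x}{p})$.这是因为$g(x)$中$p$的幂次为$\min(v_p(x), v_p(k))$,乘上$p$之后它只在$v_p(\frac{x}{p}) < v_p(k)$(即$p \mid \frac{k}{g(\frac{x}{p})}$)时增加$1$,而其余质因子的幂次保持不变.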

$[1, n]$中的质数个数$\pi(n) \approx \frac{n}{\ln{n}}$.只有在质数处才需要做一次快速幂和一次$\gcd$(均为$O(\log k)$),其余每个数只需$O(1)$次运算,所以总时间复杂度为$O(\frac{\log_2{k}}{\ln{n}} \cdot n + n) \approx O(n)$.


$f(x) = x^k$很容易发现是积性函数,但是观察到$g(x) = \gcd(x, k)$是积性函数还是有一定难度的.

$f(x) = x^k$通过解法$1$中的方式计算.

而对于$g(x) = \gcd(x, k)$,我们发现$g(x)$的取值是非常有限的,只可能是$k$小于等于$n$的因子.因此我们可以枚举所有因子的倍数更新$g(x)$的值.

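枚举$1 \sim n$所有倍数的时间复杂度是$O(n\log{n})$,但因为$k$的因子在$1 \sim n$中的分布是非常稀疏的,所以实际复杂度远小于这个上界.总运算量为$\sum_{d \mid k,\, d \le n} \frac{n}{d} \le n \cdot \frac{\sigma(k)}{k} = O(n \log\log k)$(其中$\sigma(k)$为$k$的因子和),另加$O(\sqrt{k})$求出$k$的所有因子,足以通过本题.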

C++ code by me
#include <bits/stdc++.h>
using namespace std;
// Array capacity (indices up to 2e7 + 4) and the prime modulus for all sums.
const int maxn = 2e7 + 5, mod = 1e9 + 7;
using LL = long long;
// primes[0..cnt) holds the primes found so far by the linear sieve.
int primes[maxn], cnt;
// x[i] = i^k mod `mod`, g[i] = gcd(i, k); both are filled by the sieve in f().
int x[maxn], g[maxn];
// Sieve marker: despite its name, isPrime[v] == true means v is composite.
bool isPrime[maxn];
// Binary exponentiation: returns a^b mod `mod` for b >= 0 (a >= 0, mod >= 1).
int qpow(int a, int b, int mod){
    int res = 1 % mod;   // so that mod == 1 yields 0 even when b == 0
    a %= mod;
    while (b){
        if (b & 1) res = 1LL * res * a % mod;
        a = 1LL * a * a % mod;
        b >>= 1;
    }
    return res;
}
// Returns (sum over i = 1..n of gcd(i, k) * i^k) mod `mod`, for n >= 1.
// A linear sieve fills x[i] = i^k (completely multiplicative) and
// g[i] = gcd(i, k) (multiplicative) in O(1) per composite. Uses the global
// arrays and counter, so it is meant to be called once per run.
int f(int n, LL k){
    int res = 1;                  // i = 1 contributes gcd(1, k) * 1^k = 1
    // Fermat: i^k == i^(k mod (mod-1)) since every i <= n < mod is coprime to the prime mod.
    int up = k % (mod - 1);
    for(int i = 2; i <= n; i++){
        if (!isPrime[i]){
            primes[cnt++] = i;
            x[i] = qpow(i, up, mod);
            g[i] = __gcd(1LL * i, k);
        }
        for(int j = 0; i * primes[j] <= n; j++){
            isPrime[i * primes[j]] = 1;
            x[i * primes[j]] = 1LL * x[i] * x[primes[j]] % mod;
            if (i % primes[j] == 0){
                // p already divides i: gcd gains another factor p only while p still divides k / g[i].
                g[i * primes[j]] = g[i];
                if ((k / g[i]) % primes[j] == 0) g[i * primes[j]] *= primes[j];
                break;            // each composite is visited once, via its smallest prime factor
            }
            g[i * primes[j]] = g[i] * g[primes[j]];   // i and p coprime: multiplicative split
        }
        res = (res + 1LL * g[i] * x[i]) % mod;
    }
    return res;
}
// Reads n and k and prints f(n, k).
int main(){
    int n; LL k;
    cin >> n >> k;
    cout << f(n, k) << '\n';
    return 0;
}
Java code by ChatGPT
import java.util.*;
public class Main {
    static final int maxn = 20000005;
    static final int mod = 1000000007;
    // primes[0..cnt) holds the primes found so far by the linear sieve.
    static int[] primes = new int[maxn];
    // x[i] = i^k mod `mod`; g[i] = gcd(i, k).
    static int[] x = new int[maxn];
    static int[] g = new int[maxn];
    // Despite its name, isPrime[v] == true marks v as composite.
    static boolean[] isPrime = new boolean[maxn];
    static int cnt;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long k = sc.nextLong();
        int ans = f(n, k);
        System.out.println(ans);
    }

    /** Binary exponentiation: returns a^b mod {@code mod} for b >= 0 (a >= 0, mod >= 1). */
    public static int qpow(int a, int b, int mod){
        int res = 1 % mod; // so that mod == 1 yields 0 even when b == 0
        a %= mod;
        while (b > 0){
            if ((b & 1) != 0) {
                res = (int) ((1L * res * a) % mod);
            }
            a = (int) ((1L * a * a) % mod);
            b >>= 1;
        }
        return res;
    }

    /**
     * Returns (sum over i = 1..n of gcd(i, k) * i^k) mod {@code mod}, for n >= 1.
     * A linear sieve fills x[i] = i^k (completely multiplicative) and g[i] = gcd(i, k)
     * (multiplicative). Uses the static arrays, so it is meant to be called once per run.
     */
    public static int f(int n, long k){
        int res = 1; // i = 1 contributes gcd(1, k) * 1^k = 1
        // Fermat: i^k == i^(k mod (mod-1)) since every i <= n < mod is coprime to the prime mod.
        int up = (int) (k % (mod - 1));
        for(int i = 2; i <= n; i++){
            if (!isPrime[i]){
                primes[cnt++] = i;
                x[i] = qpow(i, up, mod);
                g[i] = gcd(1L * i, k);
            }
            for(int j = 0; i * primes[j] <= n; j++){
                int m = i * primes[j];
                isPrime[m] = true;
                x[m] = (int) ((1L * x[i] * x[primes[j]]) % mod);
                if (i % primes[j] == 0){
                    // p already divides i: gcd gains another p only while p still divides k / g[i].
                    g[m] = g[i];
                    if ((k / g[i]) % primes[j] == 0) g[m] *= primes[j];
                    break; // each composite is visited once, via its smallest prime factor
                }
                g[m] = g[i] * g[primes[j]]; // i and p coprime: multiplicative split
            }
            res = (int) ((res + 1L * g[i] * x[i]) % mod);
        }
        return res;
    }

    /** Euclid's algorithm; the result is narrowed to int (callers pass gcds bounded by n). */
    public static int gcd(long a, long b) {
        return b == 0 ? (int) a : gcd(b, a % b);
    }
}
Python code by me
from math import gcd
from array import array
MOD = 1000000007
n, k = map(int, input().split())
# Linear sieve over [2, n]: x[i] = i^k mod MOD (completely multiplicative) and
# g[i] = gcd(i, k) (multiplicative); composite[i] is set for every non-prime i.
primes = array('i')
x = array('i', [0] * (n + 1))
g = array('i', [0] * (n + 1))
composite = bytearray(n + 1)
res = 1  # i = 1 contributes gcd(1, k) * 1^k = 1
# Fermat: i^k == i^(k mod (MOD-1)) for i coprime to MOD (assumes n < MOD).
up = k % (MOD - 1)
for i in range(2, n + 1):
    if not composite[i]:
        primes.append(i)
        x[i] = pow(i, up, MOD)
        g[i] = gcd(i, k)
    for p in primes:
        if i * p > n:
            break
        composite[i * p] = 1
        x[i * p] = x[i] * x[p] % MOD
        if i % p == 0:
            # p already divides i: gcd gains another p only while p still divides k // g[i].
            g[i * p] = g[i]
            if (k // g[i]) % p == 0:
                g[i * p] *= p
            break  # each composite is visited once, via its smallest prime factor
        else:
            g[i * p] = g[i] * g[p]  # i and p coprime: multiplicative split
    res = (res + g[i] * x[i]) % MOD
print(res)





C++ code by me
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// For each n, prints the pair (n - 1, n * (n - 2)). They are coprime because
// n * (n - 2) = (n - 1)^2 - 1, i.e. it is -1 modulo n - 1.
int main(){
    int T;
    cin >> T;
    while (T--){
        int n;
        cin >> n;
        // 64-bit product: n * (n - 2) overflows int once n exceeds ~46341.
        LL a = n - 1, b = 1LL * n * (n - 2);
        assert(__gcd(a, b) == 1);
        cout << a << ' ' << b << '\n';
    }
    return 0;
}
Java code by ChatGPT
import java.util.*;
import java.lang.*;
public class Main {
    /**
     * For each n, prints the pair (n - 1, n * (n - 2)). They are coprime because
     * n * (n - 2) = (n - 1)^2 - 1, i.e. it is -1 modulo n - 1.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while(T-- > 0){
            int n = sc.nextInt();
            // long product: n * (n - 2) overflows int once n exceeds ~46341.
            long a = n - 1;
            long b = (long) n * (n - 2);
            assert gcd(a, b) == 1;
            System.out.println(a + " " + b);
        }
    }

    /** Euclid's algorithm on longs (int arguments widen automatically). */
    static long gcd(long a, long b){
        if (b == 0) return a;
        return gcd(b, a % b);
    }
}
Python code by me
# For each query n, print the pair (n - 1, n * (n - 2)); since
# n * (n - 2) = (n - 1) ** 2 - 1, the two numbers share no common factor.
test_count = int(input())
while test_count > 0:
    test_count -= 1
    value = int(input())
    first, second = value - 1, value * (value - 2)
    print(first, second)

