神题。
设第 \(i\) 个箱子有 \(x_i\) 个红球,\(y_i\) 个蓝球,那么要求找到最大的 \(K\) 使得 \(\sum\limits_{i = 1}^K x_i \le R, \sum\limits_{i = 1}^K y_i \le B\),且 \((x_i, y_i)\) 两两不等。
显然我们都希望 \(x_i, y_i\) 尽量小。但是当 \(R, B\) 有一定差距时,我们希望贡献一些 \(x_i\) 让一些 \(y_i\) 相等,不好考虑。
人类智慧地,考虑类似 wqs 二分地给每一维赋一个权值 \((p, q)\),那么当 \(\sum x_i \le R \land \sum y_i \le B\) 时就有 \(p \sum x_i + q \sum y_i \le pR + qB\),所以某种程度我们希望 \(p \sum x_i + q \sum y_i\) 尽量小。
先二分一个 \(K\),把对于每个 \((p, q)\),\(p x_i + q y_i\) 最小的 \(K\) 个点的 \((\sum x_i, \sum y_i)\) 扔到平面上,发现它们会形成一个下凸包,我们需要判断 \((R, B)\) 是不是在这个凸包的右上方。
无法直接求出这个凸包,考虑直接钦定 \(p + q\) 等于一个大常数(比如 \(10^9 + 7\)),感性理解一下因为 \(p + q\) 足够大所以有用的情况都能考虑到。然后我们二分 \(p\),用这组 \((p, q)\) 去搞出凸包上的点。
现在要求使得 \(p x_i + q y_i\) 最小的一组 \((x_i, y_i)\) 的 \((\sum x_i, \sum y_i)\),考虑二分一个 \(z\),求出 \(px + qy \le z\) 的点数和对应方案的 \((\sum x_i, \sum y_i)\)。注意到如果选了一组 \((x, y)\),它左下角的点一定都被选了,所以 \(\min(x, y)\) 是 \(O(V^{\frac{1}{3}})\) 级别的。可以直接枚举其中一维计算答案。
求出来这组 \((\sum x_i, \sum y_i)\) 后,如果已经满足 \(\sum x_i \le R \land \sum y_i \le B\) 了就直接知道这个 \(K\) 可行了;否则若 \(\sum x_i > R \land \sum y_i > B\) 就一定不可行;否则根据 \(\sum x_i\) 与 \(R\) 的大小判断增大 \(p\) 还是减小即可。
同时我们还要记录左边界和右边界的凸包上的点,方便最后判断 \((R, B)\) 是否在这个凸包的右上方,这个用叉积判一下即可。
总共需要三层二分。时间复杂度 \(O(V^{\frac{1}{3}} \log^3 V)\)。
code
// Problem: D - Distinct Boxes
// Contest: AtCoder - World Tour Finals 2019
// URL: https://atcoder.jp/contests/wtf19/tasks/wtf19_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const ll M = 1000000007;
ll n, m;
struct node {
ll x, y;
node(ll a = 0, ll b = 0) : x(a), y(b) {}
};
inline node operator + (const node &a, const node &b) {
return node(a.x + b.x, a.y + b.y);
}
inline node operator - (const node &a, const node &b) {
return node(a.x - b.x, a.y - b.y);
}
inline ll operator * (const node &a, const node &b) {
return a.x * b.y - a.y * b.x;
}
inline node calc(ll p, ll q, ll k) {
bool flag = 0;
if (p < q) {
swap(p, q);
flag = 1;
}
ll l = 0, r = 1e18, z = 0;
while (l <= r) {
ll mid = (l + r) >> 1, t = 0;
for (ll i = 0; i * p <= mid; ++i) {
t += (mid - i * p) / q + 1;
if (t > k) {
break;
}
}
if (t <= k) {
z = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
ll sx = 0, sy = 0;
for (ll i = 0; i * p <= z; ++i) {
ll t = (z - i * p) / q;
sx += (t + 1) * i;
sy += t * (t + 1) / 2;
}
if (flag) {
swap(sx, sy);
}
return node(sx, sy);
}
inline bool check(ll k) {
ll l = 1, r = M - 1;
node L(0, 0), R(0, 0);
while (l <= r) {
ll p = (l + r) >> 1;
ll q = M - p;
node t = calc(p, q, k);
if (t.x <= n && t.y <= m) {
return 1;
} else if (t.x > n && t.y > m) {
return 0;
} else if (t.x <= n) {
L = t;
r = p - 1;
} else {
R = t;
l = p + 1;
}
}
return (R - L) * (node(n, m) - L) >= 0;
}
void solve() {
scanf("%lld%lld", &n, &m);
ll l = 1, r = 1e9, ans = -1;
while (l <= r) {
ll mid = (l + r) >> 1;
if (check(mid)) {
ans = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
printf("%lld\n", ans - 1);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}