容易发现一个变色龙是红色当且仅当,设 \(R\) 为红球数量,\(B\) 为蓝球数量,那么 \(R \ge B\) 或 \(R = B\) 且最后一个球是蓝球。
考虑如何判定一个颜色序列是否可行。
考虑贪心。
- 若 \(R < B\) 显然不行。
- 若 \(R \ge B + n\),每个变色龙都可以分到比蓝球数多 \(1\) 的红球,答案为 \(\binom{R + B}{R}\)。
- 若 \(R = B\),考虑删除颜色序列中的最后一位(这一位必然是 \(\texttt{B}\)),转化为 \(B < R < B + n\) 的情况。
- 若 \(B < R < B + n\),考虑分配得尽量平均,每个变色龙分到的蓝球数恰好等于红球数或等于红球数减 \(1\)。那么每个蓝球数等于红球数的变色龙,我们贪心地给它分配 \(\texttt{RB}\) 即可,剩下的全部给蓝球数等于红球数或等于红球数减 \(1\) 的变色龙。
于是现在问题转化为了,求满足可以取出 \(n - (R - B)\) 个 \(\texttt{RB}\) 的 \(\texttt{RB}\) 序列。
考虑转化成,每个前缀的红球数减蓝球数不小于 \(n - (R - B) - B = n - R\),也就是说不能匹配的蓝球数 \(\le R - n\)。
考虑画出折线图,\(\texttt{R}\) 就是 \((x, y) \to (x + 1, y + 1)\),\(\texttt{B}\) 就是 \((x, y) \to (x + 1, y - 1)\)。那么就是要求从 \((0, 0)\) 走到 \((R + B, R - B)\) 且与 \(y = n - R - 1\) 没有交点的路径数。
考虑容斥,总路径数 \(\binom{R + B}{R}\) 减去与 \(y = n - R - 1\) 有交点的路径数。把在 \(y = n - R - 1\) 下方的折线翻上去,等价于从 \((0, 2n - 2R - 2)\) 到 \((R + B, R - B)\)。设 \(x\) 为向右上走的步数,\(y\) 为向右下走的步数,那么 \(x + y = R + B, x - y = 3R - B - 2n + 2\),可得 \(x = 2R - n + 1\)。所以,与 \(y = n - R - 1\) 有交点的路径数就是 \(\binom{x + y}{x} = \binom{R + B}{2R - n + 1}\)。
时间复杂度线性。
code
// Problem: E - Ball Eat Chameleons
// Contest: AtCoder - AtCoder Grand Contest 021
// URL: https://atcoder.jp/contests/agc021/tasks/agc021_e
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 2000100;
const int N = 2000000;
const ll mod = 998244353;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, fac[maxn], ifac[maxn];
inline void init() {
fac[0] = 1;
for (int i = 1; i <= N; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[N] = qpow(fac[N], mod - 2);
for (int i = N - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
}
inline ll C(ll n, ll m) {
if (n < m || n < 0 || m < 0) {
return 0;
} else {
return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
}
void solve() {
scanf("%lld%lld", &n, &m);
ll ans = 0;
for (int i = n; i <= m; ++i) {
int j = m - i;
if (i < j) {
continue;
}
if (i >= j + n) {
ans = (ans + C(i + j, i)) % mod;
} else {
if (i == j) {
--j;
}
ans = (ans + C(i + j, i) - C(i + j, i * 2 - n + 1) + mod) % mod;
}
}
printf("%lld\n", ans);
}
int main() {
init();
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}