其实是套路题,但是为什么做不出来啊
第一步就是经典套路。枚举 \(k\),统计中位数 \(> k\) 的方案数,加起来就是中位数的总和。
那么现在 \(x_{1 \sim n}, y_{1 \sim m}\) 就变成了 \(0/1\) 序列,考虑一次操作,如果 \((x,y) = (0,0)\),那么 \(a\) 会变成 \(0\);如果 \((x,y) = (1,1)\),那么 \(a\) 会变成 \(1\);否则 \(a\) 不变。
到这里我就卡住了。想着枚举最后一次 \((0,0)\) 或 \((1,1)\) 的操作,然后发现根本算不了。
其实,发现如果出现了一次 \(x = y\) 的操作,最后 \(a\) 的取值就跟 \(a\) 原来的值无关了。进一步发现,由于对称性,\(k = p\) 时最后一次 \(x = y\) 的操作是 \((0,0)\) 的方案数和 \(k = V - p\) 时最后一次 \(x = y\) 的操作是 \((1,1)\) 的方案数相等。
现在考虑统计只出现 \(x \ne y\) 的操作的方案数。
考虑一些特殊情况,例如 \(\gcd(n, m) = 1\),那么每个数对 \((p,q), p \in [0,n), q \in [0,m)\) 在 \(\{(i \bmod n, i \bmod m) | i \in [0,nm)\}\) 中出现且仅出现一次。
那么对于一般性的情况,考虑计算 \(g = \gcd(n, m)\),那么 \(i + gt, i \in [0, g), t \in [0, \frac{n}{g})\) 只有可能跟 \(i + gt, i \in [0, g), t \in [0, \frac{m}{g})\) 配对。那么就是要求,\(\forall i \in [0, g)\):
- \(x_i = x_{i + g} = x_{i + 2g} = \cdots = x_{i + (\frac{n}{g} - 1) g}\);
- \(y_i = y_{i + g} = y_{i + 2g} = \cdots = y_{i + (\frac{m}{g} - 1) g}\);
- \(x_i \ne y_i\)。
这个的方案数容易统计,\([1, V]\) 中 \(\le k\) 的数有 \(k\) 个,\(> k\) 的数有 \(V - k\) 个,那么分别讨论 \((x, y)\) 取 \((0, 0)\) 或 \((1, 1)\) 的情况,方案数即为:
\[(k^{\frac{n}{g}} (V - k)^{\frac{m}{g}} + (V - k)^{\frac{n}{g}} k^{\frac{m}{g}})^g \]那么这种情况时 \(a\) 不变,只有原来 \(a\) 为 \(1\) 时才产生贡献。还要计算出现过 \(x = y\) 操作的情况。总方案数 \(V^{n + m}\) 减去上面的式子,得出的就是出现过 \(x = y\) 操作的方案数总和,除以 \(2\) 就是 \((1, 1)\) 的情况。
然后我们就以 \(O(V \log (n + m))\) 的时间复杂度做完了。
code
// Problem: E - Cyclic Medians
// Contest: AtCoder - AtCoder Regular Contest 133
// URL: https://atcoder.jp/contests/arc133/tasks/arc133_e
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const ll mod = 998244353;
const ll inv2 = (mod + 1) / 2;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, V, A;
void solve() {
scanf("%lld%lld%lld%lld", &n, &m, &V, &A);
ll ans = qpow(V, n + m), g = __gcd(n, m);
for (int i = 1; i < V; ++i) {
ll t = qpow((qpow(i, n / g) * qpow(V - i, m / g) % mod + qpow(V - i, n / g) * qpow(i, m / g) % mod) % mod, g);
if (A > i) {
ans = (ans + t) % mod;
}
ll all = qpow(V, n + m);
t = (all - t + mod) % mod;
ans = (ans + t * inv2 % mod) % mod;
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}