考虑到 \(N\) 的字符组成其实是固定的。
所以可以把方案数拆为 \(A\) 的方案数 \(\times\) \(A, B\) 相匹配的方案数。
对于 \(A\) 的方案数,就是多重集组合数,为 \(\dfrac{n!}{\prod\limits_{i = 0}^{25} (cnt_{A, i}!)}\)。
接下来考虑求解 \(A, B\) 相匹配的方案数。
考虑到对于字符 \(c\),只能由 \(c, c + 1\) 来补。
那如果确定了 \(c + 1\) 的个数为 \(x\),就知道了还有 \(cnt_{A, c} - x\) 个需要用 \(c\) 来填。
那么剩下的又可以去选择部分填到 \(c - 1\)。
于是可以列出 \(\text{DP}\)。
设 \(f_{i, x}\) 为考虑了后 \(i\) 个字符,第 \(i + 1\) 填了 \(x\) 个字符给 \(i\) 的方案数。
那么令 \(s = cnt_{B, i} - (cnt_{A, i} - x)\),即填完 \(i\) 后 \(i\) 字符还剩的数量。
就有转移 \(f_{i - 1, y}\leftarrow f_{i, x}\times \binom{cnt_{A, i - 1}}{y}(0\le y\le \min\{s, cnt_{A, i - 1}\})\),指选 \(y\) 个填到 \(i - 1\),就有 \(\binom{cnt_{A, i - 1}}{y}\) 种选取方案。
发现转移到 \(f_{i, x}\) 一定乘了 \(\binom{cnt_{A, i}}{x}\),令 \(f_{i, x} = g_{i, x}\times\binom{cnt_{A, i}}{x}\)。
转移就变为了 \(g_{i - 1, y}\leftarrow f_{i, x}(0\le y\le \min\{s, cnt_{A, i - 1}\})\)。
那就可以差分优化 \(g\) 的转移了。
时间复杂度 \(O(n + m)\)。
代码
#include<bits/stdc++.h>
using ll = long long;
const ll mod = 998244353;
const int maxn = 2e5 + 10;
ll fac[maxn], inv[maxn], invf[maxn];
inline ll binom(int n, int m) {
return fac[n] * invf[n - m] % mod * invf[m] % mod;
}
int cntA[26], cntB[26];
char S[maxn];
int f[26][maxn], g[26][maxn];
int main() {
int n, m;
scanf("%d%d", &n, &m);
fac[0] = fac[1] = inv[0] = inv[1] = invf[0] = invf[1] = 1;
for (int i = 2; i <= n; i++) {
fac[i] = fac[i - 1] * i % mod;
inv[i] = inv[mod % i] * (mod - mod / i) % mod;
invf[i] = invf[i - 1] * inv[i] % mod;
}
scanf("%s", S + 1);
for (int i = 1; i <= n; i++)
cntA[S[i] - 'A']++;
scanf("%s", S + 1);
for (int i = 1; i <= m; i++)
cntB[S[i] - 'A']++;
g[25][0] = 1, g[25][1] = mod - 1;
ll sum = 0;
for (int i = 25; ~ i; i--) {
for (int j = 1; j <= cntA[i]; j++)
(g[i][j] += g[i][j - 1]) %= mod;
for (int j = 0; j <= cntA[i]; j++)
f[i][j] = g[i][j] * binom(cntA[i], j) % mod;
for (int j = 0; j <= cntA[i]; j++) {
int s = cntB[i] - (cntA[i] - j);
if (s < 0)
continue;
if (! i)
(sum += f[i][j]) %= mod;
else {
s = std::min(s, cntA[i - 1]);
(g[i - 1][0] += f[i][j]) %= mod, (g[i - 1][s + 1] += mod - f[i][j]) %= mod;
}
}
}
ll c = fac[n];
for (int i = 0; i < 26; i++)
(c *= invf[cntA[i]]) %= mod;
printf("%lld\n", c * sum % mod);
return 0;
}