套路生成函数。
写出答案的式子,设 \(f_i(x)=\sum_{j\ge 0} j^{c_i} x^j\),不难得到答案为:
\[[x^W]{1\over 1-x}\prod_{i=1}^n f_i(x) \]考虑求 \(f_i(x)\)。看到指数上有 \(c_i\),想到用斯特林数展开:
\[f_i(x)=\sum_{j=0}^{\infty} x^j \sum_{k=0}^{c_i} {c_i\brace k}\binom{j}{k}k! \]\[=\sum_{k=0}^{c_i} {c_i\brace k} k! \sum_{j=k}^{\infty}\binom{j}{k}x^j \]注意到后面的式子是组合数一列的生成函数,设其为 \(g_k(x)\),乘上 \(1\over 1-x\) 对其系数做前缀和,由组合数按列求和的公式可知: \({1\over 1-x}g_k(x)={1\over x}g_{k+1}(x),g_k(x)={x^k\over (1-x)^{k+1}}\)。将每个 \(x^k\over (1-x)^{k+1}\) 分母中的一个 \(1-x\) 提到外面,可以将答案写为:
\[[x^W]{1\over(1-x)^{n+1}}\prod_{i=1}^n\sum_{k=0}^{c_i}{c_i\brace k}k! \left({x\over 1-x}\right)^k \]将后面的 \({x\over 1-x}\) 看成一整个变量,后面的那坨东西是可以分治 NTT 快速计算的。设得到的多项式为 \(F(t)\),那么总答案可以写为:
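\[[x^W] \sum_k {F_kx^k\over (1-x)^{n+k+1}} \]虽然 \(W\) 很大,但是 \(F\) 有值的项数非常少。我们枚举 \(F\) 的项,由牛顿二项式定理可知 \([x^{W-k}]{1\over (1-x)^{n+k+1}}=\binom{W+n}{n+k}\)。由于 \(W\) 很大,这个组合数不能直接查阶乘表,而是写成 \(\binom{W+n}{n+k}={(W+n)^{\underline{n+k}}\over (n+k)!}\):先算出 \((W+n)^{\underline{n}}\),之后 \(k\) 每增大 1 就多乘一项 \((W-k)\) 来递推维护下降幂,即可通过。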
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. `pi` and `mp` are not referenced anywhere in the visible code.
#define pi pair<int,int>
#define mp make_pair
// A polynomial is a coefficient vector: element i is the coefficient of x^i.
#define poly vector<int>
typedef long long ll;
// N   : capacity of the fixed-size tables below (fac, ifac, c, cnt).
// mod : modulus for all arithmetic; mod - 1 = 2^23 * 119.
// G   : base from which FFT() derives its roots of unity for the forward transform.
// Gi  : (mod + 1) / 3 = 332748118; since 3 * Gi = mod + 1, Gi is the inverse of G
//       modulo mod. FFT() uses it for the inverse transform.
const int N = 3e5 + 5, mod = 998244353, G = 3, Gi = (mod + 1) / G;
// Modular addition: returns (a + b) reduced into [0, mod), assuming both
// operands are already in [0, mod) (then a + b < 2 * mod still fits in int).
// The return type is plain `int`: a top-level `const` on a by-value return is
// ignored by the language and only produces -Wignored-qualifiers noise.
int add(int a, int b) {
    const int sum = a + b;
    return sum >= mod ? sum - mod : sum;
}
// Modular subtraction: returns (a - b) reduced into [0, mod), assuming both
// operands are already in [0, mod).
// The return type is plain `int`: a top-level `const` on a by-value return is
// ignored by the language and only produces -Wignored-qualifiers noise.
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + mod : diff;
}
// Modular multiplication: returns a * b % mod, with the product formed in
// 64-bit arithmetic so it cannot overflow.
// The return type is plain `int`: a top-level `const` on a by-value return is
// ignored by the language and only produces -Wignored-qualifiers noise.
int mul(int a, int b) {
    return static_cast<int>(1LL * a * b % mod);
}
const int power(int a, int b) {
int t = 1, y = a, k = b;
while (k) {
if (k & 1) t = mul(t, y);
y = mul(y, y); k >>= 1;
} return t;
}
// Reads the next decimal integer from stdin.
//
// Skips every non-digit character. The result is negative only when the
// character immediately before the first digit is '-', because any other
// skipped character resets the sign. The character that terminates the digit
// run is consumed as well.
//
// Returns 0 if end of input is reached before any digit is seen.
//
// Changes from the original:
//  * `register` is gone: the keyword was removed in C++17.
//  * the result of getchar() is kept in an int, so EOF is recognised (the
//    original stored it in a char and looped forever at end of input) and
//    isdigit() never receives a negative char value.
inline int read() {
    int value = 0;
    int sign = 1;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        sign = (ch == '-') ? -1 : 1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}
// 64-bit counterpart of read(): reads the next decimal integer from stdin.
//
// Skips every non-digit character. The result is negative only when the
// character immediately before the first digit is '-'. The character that
// terminates the digit run is consumed as well.
//
// Returns 0 if end of input is reached before any digit is seen.
//
// Changes from the original:
//  * `register` is gone: the keyword was removed in C++17.
//  * the result of getchar() is kept in an int, so EOF is recognised (the
//    original stored it in a char and looped forever at end of input) and
//    isdigit() never receives a negative char value.
inline ll readll() {
    ll value = 0;
    ll sign = 1;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        sign = (ch == '-') ? -1 : 1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}
// In-place number-theoretic transform of a[0 .. len) modulo `mod`.
//
//   a   : buffer holding at least `len` coefficients; overwritten with the result.
//   len : transform length. The bit-reversal and butterfly loops assume a power
//         of two (mul(poly, poly) always passes one).
//   typ : 1 = forward transform (roots derived from G). Any other value derives
//         the roots from Gi; the final 1/len scaling is applied only when
//         typ == 0, which is what the caller passes for the inverse transform.
//
// The only change from the original is dropping `register`, which was removed
// from the language in C++17.
inline void FFT(int *a, int len, int typ) {
    // Bit-reversal permutation: `rev` always holds the bit-reversed index of i.
    for (int i = 0, rev = 0; i < len; ++i) {
        if (i < rev) std::swap(a[i], a[rev]);
        int bit = len >> 1;
        for (; bit & rev; bit >>= 1) rev ^= bit;  // carry propagation, mirrored
        rev ^= bit;
    }

    // Iterative butterflies over blocks of size 2, 4, ..., len.
    for (int half = 1; half < len; half <<= 1) {
        const int block = half << 1;
        const int wn = power(typ == 1 ? G : Gi, (mod - 1) / block);
        for (int start = 0; start < len; start += block) {
            int w = 1;
            for (int k = 0; k < half; ++k, w = ::mul(w, wn)) {
                const int lo = a[start + k];
                const int hi = ::mul(w, a[start + half + k]);
                a[start + k] = ::add(lo, hi);
                a[start + half + k] = ::sub(lo, hi);
            }
        }
    }

    // Inverse transform: divide every entry by len.
    if (!typ) {
        const int inv_len = power(len, mod - 2);
        for (int i = 0; i < len; ++i) a[i] = ::mul(a[i], inv_len);
    }
}
// Returns the larger of two ints (b when they compare equal).
inline int max_(int a, int b) {
    if (a > b) {
        return a;
    }
    return b;
}
// Returns the smaller of two ints (b when they compare equal).
inline int min_(int a, int b) {
    if (a < b) {
        return a;
    }
    return b;
}
// Returns a copy of `a` with the scalar b added (mod `mod`) to every coefficient.
// Iterates with range-for: this drops the `register` keyword (removed in
// C++17) and the signed/unsigned comparison of the original index loop.
inline poly add(poly a, int b) {
    for (int &coef : a) coef = ::add(coef, b);
    return a;
}
// Returns a copy of `a` with the scalar b subtracted (mod `mod`) from every coefficient.
// Iterates with range-for: this drops the `register` keyword (removed in
// C++17) and the signed/unsigned comparison of the original index loop.
inline poly sub(poly a, int b) {
    for (int &coef : a) coef = ::sub(coef, b);
    return a;
}
// Returns a copy of `a` with every coefficient multiplied by the scalar b (mod `mod`).
// Iterates with range-for: this drops the `register` keyword (removed in
// C++17) and the signed/unsigned comparison of the original index loop.
inline poly mul(poly a, int b) {
    for (int &coef : a) coef = ::mul(coef, b);
    return a;
}
// Returns a copy of `a` with every coefficient divided by the scalar b modulo
// `mod`, i.e. multiplied by b^(mod-2). If b is 0 modulo `mod` there is no
// inverse and every coefficient becomes 0.
// Iterates with range-for: this drops the `register` keyword (removed in
// C++17) and the signed/unsigned comparison of the original index loop.
inline poly div(poly a, int b) {
    const int inv_b = ::power(b, mod - 2);
    for (int &coef : a) coef = ::mul(coef, inv_b);
    return a;
}
// Coefficient-wise sum of two polynomials modulo `mod`; the result has the
// length of the longer operand.
// Sizes are compared as size_t directly instead of round-tripping through the
// int-based max_(), and the `register` keyword (removed in C++17) is gone.
inline poly add(poly a, poly b) {
    if (a.size() < b.size()) a.resize(b.size());  // new slots are zero-filled
    for (std::size_t i = 0; i < b.size(); ++i) a[i] = ::add(a[i], b[i]);
    return a;
}
// Coefficient-wise difference a - b of two polynomials modulo `mod`; the
// result has the length of the longer operand.
// Sizes are compared as size_t directly instead of round-tripping through the
// int-based max_(), and the `register` keyword (removed in C++17) is gone.
inline poly sub(poly a, poly b) {
    if (a.size() < b.size()) a.resize(b.size());  // new slots are zero-filled
    for (std::size_t i = 0; i < b.size(); ++i) a[i] = ::sub(a[i], b[i]);
    return a;
}
// Polynomial product a * b modulo `mod`, computed with the NTT in FFT().
//
// Returns a polynomial with a.size() + b.size() - 1 coefficients, or an empty
// polynomial if either operand is empty.
//
// Changes from the original:
//  * an empty operand used to make the result length wrap to -1 and feed a
//    negative value to log2(); it now returns an empty polynomial.
//  * the transform length is found with an integer doubling loop instead of
//    ceil(log2(p)), so it no longer depends on floating-point rounding.
//  * `register` is gone (removed in C++17).
//
// NOTE(review): mod - 1 = 2^23 * 119, so the transform length must not exceed
// 2^23; this is not checked here.
inline poly mul(poly a, poly b) {
    if (a.empty() || b.empty()) return poly();

    const int p = static_cast<int>(a.size() + b.size() - 1);
    int len = 1;
    while (len < p) len <<= 1;  // smallest power of two >= p

    a.resize(len);
    b.resize(len);
    FFT(a.data(), len, 1);
    FFT(b.data(), len, 1);
    for (int i = 0; i < len; ++i) a[i] = ::mul(a[i], b[i]);
    FFT(a.data(), len, 0);

    a.resize(p);
    return a;
}
// Factorial tables filled by init(): fac[i] = i! mod `mod`, and ifac[i] is the
// modular inverse of fac[i]. Valid indices are 0 .. N - 1.
int fac[N], ifac[N];
// Builds one row of Stirling numbers of the second kind via a single
// convolution: entry k of the result is
//     sum_{i=0..k}  (i^n / i!) * ((-1)^(k-i) / (k-i)!)   (mod `mod`),
// for k = 0 .. n. Requires init() to have filled ifac[].
inline poly stir2R(int n) {
    poly powers;  // powers[i] = i^n / i!
    poly signs;   // signs[i]  = (-1)^i / i!
    for (int i = 0; i <= n; ++i) {
        const long long scaled = 1ll * power(i, n) * ifac[i];
        powers.push_back(scaled % mod);

        signs.push_back(ifac[i]);
        if (i & 1) signs[i] = mod - signs[i];
        if (signs[i] >= mod) signs[i] -= mod;
    }
    poly row = mul(powers, signs);
    row.resize(n + 1);  // higher-degree terms of the convolution are not needed
    return row;
}
// Binomial coefficient C(n, m) modulo `mod`, from the tables filled by init().
// Returns 0 when m < 0 or n < m. The m < 0 guard is new: the original indexed
// ifac[] with a negative subscript in that case.
// NOTE(review): n must stay below N (the table size); this is not checked.
inline int C(int n, int m) {
    if (m < 0 || n < m) return 0;
    const long long inv_denominator = 1ll * ifac[m] * ifac[n - m] % mod;
    return static_cast<int>(1ll * fac[n] * inv_denominator % mod);
}
// Fills fac[0 .. N-1] with factorials modulo `mod` and ifac[0 .. N-1] with
// their modular inverses. Only the last inverse is computed with power(); the
// rest follow from 1/(i-1)! = i * 1/i!.
// The only change from the original is dropping `register`, which was removed
// from the language in C++17.
inline void init() {
    fac[0] = 1;
    for (int i = 1; i < N; ++i) {
        fac[i] = static_cast<int>(1ll * i * fac[i - 1] % mod);
    }
    ifac[N - 1] = power(fac[N - 1], mod - 2);
    for (int i = N - 1; i > 0; --i) {
        ifac[i - 1] = static_cast<int>(1ll * i * ifac[i] % mod);
    }
}
// Writes x to stdout in decimal, one character at a time, with no trailing
// newline. Negative values are printed with a leading '-'; the original
// emitted garbage characters for them.
inline void otp(int x) {
    if (x < 0) {
        std::putchar('-');
        // Work on the (negative) quotient and remainder so that the most
        // negative int never has to be negated as a whole.
        if (x <= -10) otp(-(x / 10));
        std::putchar('0' - x % 10);
        return;
    }
    if (x >= 10) otp(x / 10);
    std::putchar('0' + x % 10);
}
// Problem state, filled in by main():
//   n      : number of exponents read from the input.
//   c[i]   : the i-th exponent (1-based).
//   cnt[v] : how many of the c[i] equal v (consumed while F is being built).
//   mx     : largest exponent seen.
int n, c[N], cnt[N], mx = 0;
// Second input value; the final loop in main() combines it with n in the
// terms (W + n - i) and (W - k).
ll W;
// One polynomial per input exponent: the Stirling row from stir2R() with entry
// j multiplied by j!.
vector<poly> F;
// Product of all polynomials in F, computed by calc().
poly g;
// Product of F[l], F[l+1], ..., F[r] (inclusive), computed by halving the
// range so that the operands of each mul() have comparable sizes.
// An empty range (l > r) yields the constant polynomial 1. The original never
// terminated on an empty range, which is what main() requests when F is empty.
inline poly calc(int l, int r) {
    if (l > r) return poly(1, 1);  // empty product
    if (l == r) return F[l];
    if (r - l == 1) return mul(F[l], F[r]);
    const int mid = l + (r - l) / 2;
    return mul(calc(l, mid), calc(mid + 1, r));
}
// Reads n, W and the n exponents c[1..n], then prints
//     sum_k  g[k] * (W+n)(W+n-1)...(W-k+1) / (n+k)!     (mod `mod`),
// where g is the product of the per-exponent polynomials stored in F.
//
// Change from the original: when n == 0, F is empty and `F.size() - 1` used to
// wrap around, sending calc() into endless recursion. The empty product is now
// the constant polynomial 1.
//
// NOTE(review): n, every c[i] and n + deg(g) are used as indices into tables
// of size N without a range check; this relies on the input limits.
int main() {
    init();
    n = read();
    W = readll();
    for (int i = 1; i <= n; ++i) {
        c[i] = read();
        ++cnt[c[i]];
        mx = max(mx, c[i]);
    }

    // Equal exponents share one polynomial, so each distinct exponent is built
    // once and pushed cnt[e] times.
    for (int e = 0; e <= mx; ++e) {
        if (!cnt[e]) continue;
        poly s = stir2R(e);
        for (int j = 0; j <= e; ++j) s[j] = 1ll * s[j] * fac[j] % mod;
        while (cnt[e]--) F.push_back(s);
    }

    if (F.empty()) {
        g = poly(1, 1);  // empty product
    } else {
        g = calc(0, static_cast<int>(F.size()) - 1);
    }

    // x holds the running product (W+n)(W+n-1)...; it starts with n factors and
    // gains the factor (W - k) after term k has been added.
    int x = 1, res = 0;
    for (int i = 0; i < n; ++i) x = 1ll * ((W + n - i) % mod) * x % mod;

    const int terms = static_cast<int>(g.size());
    // Once x is 0 modulo `mod` every later term is 0 as well, so stop early.
    // This also keeps (W - k) from ever being negative when it is multiplied in.
    for (int k = 0; k < terms && x; ++k) {
        res += 1ll * g[k] * (1ll * x * ifac[n + k] % mod) % mod;
        if (res >= mod) res -= mod;
        x = 1ll * ((W - k) % mod) * x % mod;
    }
    printf("%d\n", res);
    return 0;
}
标签:Ball,return,int,poly,gym103415A,inline,size,Math,mod
From: https://www.cnblogs.com/wwlwakioi/p/17929388.html