每次开枪时,概率的分母(当前存活者的权值之和)在不断变化,处理起来很麻烦。我们不妨令子弹也可以打到已死的人(打到已死的人就视为重新开枪),这样每一枪打中第 \(i\) 个人的概率恒为 \(w_i/\sum_j w_j\)。由于还活着的人之间被打中的概率之比没有变,这显然不会影响答案。
考虑容斥。对 \(S\subseteq\{2,\dots,n\}\),设 \(p(S)\) 表示集合 \(S\) 中的人全都在 \(1\) 之后被打中的概率,那么答案(\(1\) 最后一个死的概率)就是 \(\sum_{S}(-1)^{|S|}p(S)\)。\(p(S)\) 实际上就是无限开枪时,在打到 \(1\) 之前从来没有打到 \(S\) 中的人的概率,即前若干枪都不打 \(S\cup \{1\}\),然后打中 \(1\)。枚举打到 \(1\) 之前打了多少次,令 \(sum=\sum w_i\)、\(\operatorname{sum}(S)=\sum_{j\in S}w_j\),则容易得到 \(p(S)=\sum_{i\ge 0} \left({sum-w_1-\operatorname{sum}(S)\over sum}\right)^i\cdot {w_1\over sum}\)。这是个等比数列,求和得到 \(p(S)={w_1\over w_1+\operatorname{sum}(S)}\)。
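再看答案的式子,直接枚举 \(S\) 是指数级的,不好算。观察数据范围我们发现 \(sum\) 只有 \(10^5\) 的级别,不妨直接枚举 \(\operatorname{sum}(S)\)。令 \(f(i)=\sum\limits_{\operatorname{sum}(S)=i}(-1)^{|S|}\),则答案为 \(\sum\limits_{i} f(i)\cdot{w_1\over w_1+i}\)。关键在于求 \(f(i)\),考虑生成函数。显然 \(f(i)=[x^i] \prod\limits_{j=2}^{n} (1-x^{w_j})\)。先取 \(\ln\):由 \(\ln(1-x^{w})=-\sum\limits_{k\ge 1}{x^{wk}\over k}\),把相同的 \(w\) 合并后按调和级数枚举,即可在 \(O(sum\log sum)\) 内求出 \(\ln\);再 \(\exp\) 回去即可。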
#include <bits/stdc++.h>
using namespace std;
// Type aliases used throughout the file. Aliases (rather than macros) obey
// scoping rules and cannot silently rewrite unrelated identifiers.
using pi = std::pair<int, int>;
// make_pair is a function template, so it cannot be aliased with `using`.
#define mp make_pair
// A polynomial / power series: coefficient i is stored at index i (mod `mod`).
using poly = std::vector<int>;
// Capacity of the global arrays used by main().
const int N = 4e5 + 5;
// NTT-friendly prime 119 * 2^23 + 1.
const int mod = 998244353;
// 3 is a primitive root of mod; Gi is its inverse (3 * 332748118 == mod + 1).
const int G = 3, Gi = 332748118;
// A square root of -1 modulo mod (I * I == mod - 1).
const int I = 86583718;

// (a + b) mod `mod`, for a, b already in [0, mod).
int add(int a, int b) {
    int s = a + b;
    if (s >= mod) s -= mod;
    return s;
}

// (a - b) mod `mod`, for a, b already in [0, mod).
int sub(int a, int b) {
    int d = a - b;
    if (d < 0) d += mod;
    return d;
}

// (a * b) mod `mod`, computed in 64-bit to avoid overflow.
int mul(int a, int b) {
    return static_cast<int>(static_cast<long long>(a) * b % mod);
}

// a^b mod `mod` by binary exponentiation; b is expected to be non-negative.
int power(int a, int b) {
    int result = 1;
    for (int base = a; b; b >>= 1, base = mul(base, base))
        if (b & 1) result = mul(result, base);
    return result;
}

// Modular inverses of 2 and of 2*I (via Fermat's little theorem).
const int inv2 = power(2, mod - 2), inv2I = power(mul(2, I), mod - 2);
// Reads the next (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; the sign is taken from
// the character immediately preceding the first digit. Returns 0 if EOF is
// reached before any digit (previously this looped forever).
inline int read() {
    int s = 0, f = 1;
    // Keep the getchar() result as int so EOF stays distinguishable and
    // isdigit() only ever sees unsigned-char values or EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        f = (ch == '-' ? -1 : 1);
        ch = getchar();
    }
    while (isdigit(ch)) {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * f;
}
// In-place number-theoretic transform of a[0..len) modulo `mod`.
// typ == 1: forward transform with root G.
// any other typ: uses the inverse root Gi; additionally, when typ == 0 the
// result is divided by len (the inverse transform).
// len must be a power of two (the bit-reversal step relies on it) and must
// divide mod - 1 so that the per-stage root exponent is exact.
inline void FFT(int *a, int len, int typ) {
    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for (int i = 1, j = 0; i < len; ++i) {
        int bit = len >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    const int root = (typ == 1) ? G : Gi;
    for (int half = 1; half < len; half <<= 1) {
        // Primitive (2 * half)-th root of unity for this stage.
        const int step = power(root, (mod - 1) / (half << 1));
        for (int start = 0; start < len; start += half << 1) {
            int w = 1;
            for (int k = start; k < start + half; ++k) {
                const int u = a[k];
                const int v = ::mul(w, a[k + half]);
                a[k] = ::add(u, v);
                a[k + half] = ::sub(u, v);
                w = ::mul(w, step);
            }
        }
    }
    if (typ == 0) {
        const int scale = power(len, mod - 2);
        for (int *p = a; p != a + len; ++p) *p = ::mul(*p, scale);
    }
}
// Larger of two ints.
inline int max_(int a, int b) {
    if (b < a) return a;
    return b;
}
// Smaller of two ints.
inline int min_(int a, int b) {
    if (b < a) return b;
    return a;
}
// Adds the scalar b to every coefficient of a (mod `mod`).
inline poly add(poly a, int b) {
    for (auto &c : a) c = ::add(c, b);
    return a;
}
// Subtracts the scalar b from every coefficient of a (mod `mod`).
inline poly sub(poly a, int b) {
    for (auto &c : a) c = ::sub(c, b);
    return a;
}
// Multiplies every coefficient of a by the scalar b (mod `mod`).
inline poly mul(poly a, int b) {
    for (auto &c : a) c = ::mul(c, b);
    return a;
}
// Divides every coefficient of a by the scalar b, i.e. multiplies by b^{-1}
// mod `mod`. NOTE(review): b == 0 silently yields an all-zero result.
inline poly div(poly a, int b) {
    const int inverse_b = ::power(b, mod - 2);
    for (auto &c : a) c = ::mul(c, inverse_b);
    return a;
}
// Coefficient-wise sum a + b; the result has max(|a|, |b|) coefficients.
inline poly add(poly a, poly b) {
    if (a.size() < b.size()) a.resize(b.size());
    for (size_t i = 0; i != b.size(); ++i) a[i] = ::add(a[i], b[i]);
    return a;
}
// Coefficient-wise difference a - b; the result has max(|a|, |b|) coefficients.
inline poly sub(poly a, poly b) {
    if (a.size() < b.size()) a.resize(b.size());
    for (size_t i = 0; i != b.size(); ++i) a[i] = ::sub(a[i], b[i]);
    return a;
}
// Polynomial product a * b via NTT; the result has |a| + |b| - 1 coefficients.
// An empty operand denotes the zero polynomial and yields an empty result
// (previously p could be 0 or negative, making log2(p) -inf/NaN and its
// conversion to int undefined behaviour).
inline poly mul(poly a, poly b) {
    if (a.empty() || b.empty()) return poly();
    const int p = a.size() + b.size() - 1;
    // Smallest power of two >= p, computed exactly in integers.
    int len = 1;
    while (len < p) len <<= 1;
    a.resize(len);
    b.resize(len);
    FFT(a.data(), len, 1);
    FFT(b.data(), len, 1);
    for (int i = 0; i < len; ++i) a[i] = ::mul(a[i], b[i]);
    FFT(a.data(), len, 0);
    a.resize(p);
    return a;
}
// Inverse of the power series a modulo x^len, by Newton iteration
// x <- x * (2 - a * x), doubling the precision at each recursion level.
// Coefficients of a beyond a.size() are treated as zero; len <= 0 yields an
// empty series (previously len == 0 recursed forever).
// Precondition: a[0] != 0, otherwise no inverse exists and the result is
// meaningless. Taking a by const reference avoids copying it at every level.
inline poly inv(const poly &a, int len) {
    if (len <= 0) return poly();
    if (len == 1) return poly(1, power(a.empty() ? 0 : a[0], mod - 2));
    // Twice the smallest power of two >= len, so the product x * x * a
    // (degree < 2 * len) never wraps around into the coefficients we keep.
    int n = 1;
    while (n < len) n <<= 1;
    n <<= 1;
    poly x = inv(a, (len + 1) >> 1), y(n, 0);
    x.resize(n);
    const int m = min_(len, static_cast<int>(a.size()));
    for (int i = 0; i < m; ++i) y[i] = a[i];
    FFT(x.data(), n, 1);
    FFT(y.data(), n, 1);
    for (int i = 0; i < n; ++i) x[i] = ::mul(x[i], ::sub(2, ::mul(x[i], y[i])));
    FFT(x.data(), n, 0);
    x.resize(len);
    return x;
}
// Inverse of the power series a, to the same number of coefficients as a.
inline poly inv(poly a) {
    const int len = a.size();
    return inv(a, len);
}
// Returns the coefficients of a in reverse order.
inline poly rev(poly a) {
    return poly(a.rbegin(), a.rend());
}
// Quotient of polynomial division a / b, using the reversed-coefficient trick:
// rev(q) = rev(a) * rev(b)^{-1} mod x^{|a|-|b|+1}.
// Returns {0} when deg a < deg b. Assumes b's last stored coefficient is
// non-zero, so that rev(b) is invertible.
inline poly div(poly a, poly b) {
    if (a.size() < b.size()) return poly(1, 0);
    const int q_len = a.size() - b.size() + 1;
    std::reverse(a.begin(), a.end());
    std::reverse(b.begin(), b.end());
    a.resize(q_len);
    b.resize(q_len);
    poly quotient = mul(a, inv(b));
    quotient.resize(q_len);
    std::reverse(quotient.begin(), quotient.end());
    return quotient;
}
// Remainder of polynomial division a mod b, with trailing zero coefficients
// stripped; the zero remainder is returned as {0}. When deg a < deg b, a is
// returned unchanged.
inline poly remainder(poly a, poly b) {
    if (a.size() < b.size()) return a;
    poly quotient = div(a, b);
    poly r = sub(a, mul(b, quotient));
    while (!r.empty() && r.back() == 0) r.pop_back();
    if (r.empty()) r.push_back(0);
    return r;
}
// Formal derivative a'(x): coefficient i of the result is (i + 1) * a[i + 1],
// so the result has one fewer coefficient than a. An empty input (the zero
// polynomial) is returned as-is; previously it called resize(size_t(-1)),
// which throws.
inline poly det(poly a) {
    if (a.empty()) return a;
    const int n = a.size();
    for (int i = 1; i < n; ++i) a[i - 1] = ::mul(a[i], i);
    a.pop_back();
    return a;
}
// Formal integral with zero constant term: coefficient i of the result is
// a[i - 1] / i for i >= 1, so the result has one more coefficient than a.
inline poly inter(poly a) {
    a.insert(a.begin(), 0);
    for (int i = 1; i < static_cast<int>(a.size()); ++i)
        a[i] = ::mul(a[i], power(i, mod - 2));
    return a;
}
// Logarithm of the power series a modulo x^n (n = a.size()), computed as
// ln a = integral(a' / a). The constant term of the result is always 0, so
// this is meaningful when a[0] == 1.
// For n <= 1 an all-zero series of length n is returned. Previously, n == 1
// fed an empty derivative into mul (undefined behaviour) and n == 0 threw in det.
inline poly ln(poly a) {
    const int n = a.size();
    if (n <= 1) return poly(n, 0);
    poly quotient = mul(det(a), inv(a));
    // Only the first n - 1 terms survive integration + truncation to n terms,
    // so drop the rest before integrating.
    quotient.resize(n - 1);
    return inter(quotient);
}
// Exponential of the power series a modulo x^len, by Newton iteration
// x <- x * (1 - ln x + a), doubling the precision at each recursion level.
// Assumes a[0] == 0 (the base case returns exp(0) = 1). Coefficients of a
// beyond a.size() are treated as zero; len <= 0 yields an empty series
// (previously len == 0 recursed forever).
inline poly exp(const poly &a, int len) {
    if (len <= 0) return poly();
    if (len == 1) return poly(1, 1);
    poly x = exp(a, (len + 1) >> 1);
    x.resize(len);
    poly y = ln(x);
    for (int i = 0; i < len; ++i) {
        const int ai = i < static_cast<int>(a.size()) ? a[i] : 0;
        y[i] = ::sub(ai, y[i]);
    }
    // Add 1 with reduction; a bare ++ could leave y[0] == mod.
    y[0] = ::add(y[0], 1);
    x = mul(x, y);
    x.resize(len);
    return x;
}
// Exponential of the power series a, to the same number of coefficients as a.
inline poly exp(poly a) {
    const int len = a.size();
    return exp(a, len);
}
int n, inver[N], tim[N], sum = 0, w[N];
poly a;
int main() {
n = read();
for (int i = 1; i <= n; ++i)
sum += w[i] = read();
a.resize(sum + 1);
for (int i = 2; i <= n; ++i) ++tim[w[i]];
for (int i = 1; i <= sum; ++i) inver[i] = power(i, mod - 2);
for (int i = 1; i <= sum; ++i)
if (tim[i])
for (int j = 1; j <= sum / i; ++j)
a[j * i] = sub(a[j * i], ::mul(inver[j], tim[i]));
a = exp(a); int res = 0;
for (int i = 0; i <= sum; ++i) {
res += 1ll * a[i] * (1ll * w[1] * power(w[1] + i, mod - 2) % mod) % mod;
if (res >= mod) res -= mod;
} printf("%d\n", res);
return 0;
}
标签:return,int,poly,猎人,mul,inline,PKUWC2018,size
From: https://www.cnblogs.com/wwlwakioi/p/17368088.html