分治NTT
题意
给一个长度为 \(n\;(1\le n\le 10^5)\) 的数组 \(A\;(1\le A[i]\le 10)\),给定 \(M\;(1\le M\le 10^6)\),求在 \(A\) 中选奇数个数、使它们的和为 \(M\) 的方案数(答案对 \(998244353\) 取模)
思路
-
先不考虑要选奇数个数,根据生成函数,\(F=\prod_{i=1}^{n} (1+x^{A[i]})\)
\(ans=[x^M]F\)
-
若要求奇数个,就是 \(\prod (1+x^{A[i]})\) 中只能有奇数个选择 \(x^{A[i]}\) 这一项
-
令 \(G=\prod(1-x^{A[i]})\),
则 \([x^M]F\) 等于(选奇数个项的方案数)\(+\)(选偶数个项的方案数),\([x^M]G\) 等于(选偶数个项的方案数)\(-\)(选奇数个项的方案数),因为在 \(G\) 中每选一项 \(-x^{A[i]}\) 就会多乘一个 \(-1\)
-
\(ans=[x^M]\frac {F-G}2\)
-
直接用分治 NTT 把 \(n\) 个因子乘起来计算 F、G 的复杂度为 \(O(n\log^2 n)\),不一定能过,可以考虑优化
-
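观察题目条件,\(1\le A[i]\le 10\),说明 F、G 中不同的因子 \(1\pm x^{A[i]}\) 只有 10 种。可记录 \(cnt[i]\) 表示因子 \(1\pm x^i\) 出现的次数,用二项式定理展开:\((1\pm x^i)^{cnt[i]}=\sum_{j=0}^{cnt[i]}(\pm 1)^j\binom{cnt[i]}{j}x^{ij}\),再把这 10 个多项式用分治 NTT 乘起来,即求 F、G 的复杂度为 \(O(n\log n\cdot\log 10)\)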
代码
#include <bits/stdc++.h>
using namespace std;
// Modulus for every modular operation in this file.
// 998244353 = 119 * 2^23 + 1, so md - 1 is divisible by 2^23; ntt::init()
// derives the largest supported transform size from this factorisation.
const int md = 998244353;
// Modular addition in place: x <- (x + y) mod md.
// Both operands are expected to be already reduced to [0, md); a single
// conditional subtraction then brings the sum back into range.
inline void add(int &x, int y) {
    const int sum = x + y;
    x = (sum >= md) ? sum - md : sum;
}
// Modular subtraction in place: x <- (x - y) mod md.
// Both operands are expected to be already reduced to [0, md); a single
// conditional addition then lifts a negative difference back into range.
inline void sub(int &x, int y) {
    const int diff = x - y;
    x = (diff < 0) ? diff + md : diff;
}
// Returns (x * y) mod md.
// The product is formed in 64 bits so it cannot overflow before reduction.
inline int mul(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % md);
}
// Returns x^y mod md using binary exponentiation.
// The exponent is consumed from its least significant bit upwards while x is
// squared once per bit. y is expected to be non-negative.
inline int power(int x, int y) {
    int result = 1;
    while (y) {
        if (y & 1) {
            result = mul(result, x);
        }
        y >>= 1;
        x = mul(x, x);
    }
    return result;
}
// Returns the modular inverse of a modulo md, computed with the extended
// Euclidean algorithm. The argument may be negative or >= md; it is first
// normalised into [0, md).
// When a is congruent to 0 the loop never runs and 0 is returned.
inline int inv(int a) {
    a %= md;
    if (a < 0) {
        a += md;
    }
    int rem = md;        // the other value of the Euclidean pair (rem, a)
    int coef = 0;        // Bezout coefficient tracked alongside rem
    int next_coef = 1;   // Bezout coefficient tracked alongside a
    while (a != 0) {
        const int quotient = rem / a;
        // (rem, a) <- (a, rem - quotient * a)
        const int reduced = rem - quotient * a;
        rem = a;
        a = reduced;
        // (coef, next_coef) <- (next_coef, coef - quotient * next_coef)
        const int updated = coef - quotient * next_coef;
        coef = next_coef;
        next_coef = updated;
    }
    // The coefficient may come out negative; map it into [0, md).
    return (coef < 0) ? coef + md : coef;
}
// Number-theoretic transform (NTT) modulo md, together with polynomial
// multiplication and power-series inversion built on top of it.
//
// Relies on the file-level helpers md, power(), mul() and inv().
// The tables (base, rev, roots) are grown lazily by ensure_base() and are
// shared by every call; nothing in here is synchronised.
namespace ntt {
// base:     log2 of the largest transform size the tables currently cover.
// root:     element of multiplicative order 2^max_base, found by init().
// max_base: exponent of 2 in md - 1; stays -1 until init() has run.
int base = 1, root = -1, max_base = -1;
// rev[i]:  bit reversal of i over `base` bits.
// roots[]: twiddle factors; entries [i, 2i) serve the butterfly stage whose
//          half-length is i.
vector<int> rev = {0, 1}, roots = {0, 1};

// Computes max_base from md - 1 and searches upwards from 2 for the first
// value whose multiplicative order is exactly 2^max_base.
void init() {
    int temp = md - 1;
    max_base = 0;
    while (temp % 2 == 0) {
        temp >>= 1;
        ++max_base;
    }
    root = 2;
    while (true) {
        // order divides 2^max_base but not 2^(max_base - 1)  =>  order is 2^max_base
        if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
            break;
        }
        ++root;
    }
}

// Makes sure the tables support transforms of size 2^nbase.
// Runs init() on first use; asserts that nbase does not exceed max_base.
void ensure_base(int nbase) {
    if (max_base == -1) {
        init();
    }
    if (nbase <= base) {
        return;
    }
    assert(nbase <= max_base);
    // Rebuild the whole bit-reversal table for the new width.
    rev.resize(1 << nbase);
    for (int i = 0; i < (1 << nbase); ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
    }
    // Extend the twiddle table one level at a time: each existing entry i
    // produces entries 2i (copied) and 2i + 1 (multiplied by z).
    roots.resize(1 << nbase);
    while (base < nbase) {
        int z = power(root, 1 << (max_base - 1 - base));
        for (int i = 1 << (base - 1); i < (1 << base); ++i) {
            roots[i << 1] = roots[i];
            roots[i << 1 | 1] = mul(roots[i], z);
        }
        ++base;
    }
}

// In-place forward transform of a. a.size() must be a power of two (asserted);
// an empty vector is left untouched.
// Callers obtain the inverse transform by scaling by 1/size, reversing
// a[1..size) and calling dft() again (see multiply()).
void dft(vector<int> &a) {
    if (a.empty()) {
        return;  // __builtin_ctz(0) below would be undefined
    }
    int n = a.size(), zeros = __builtin_ctz(n);
    assert((n & (n - 1)) == 0);
    ensure_base(zeros);
    // rev[] is stored for `base` bits; drop the surplus low bits for size n.
    int shift = base - zeros;
    for (int i = 0; i < n; ++i) {
        if (i < (rev[i] >> shift)) {
            swap(a[i], a[rev[i] >> shift]);
        }
    }
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j += i << 1) {
            for (int k = 0; k < i; ++k) {
                int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
                a[j + k] = (x + y) % md;
                a[j + k + i] = (x + md - y) % md;
            }
        }
    }
}

// Returns the product of polynomials a and b (coefficients modulo md), of
// length a.size() + b.size() - 1.
// If either operand is empty the result is empty; previously two empty
// operands produced a length of -1 and the final resize threw.
vector<int> multiply(vector<int> a, vector<int> b) {
    if (a.empty() || b.empty()) {
        return vector<int>();
    }
    int need = a.size() + b.size() - 1, nbase = 0;
    while ((1 << nbase) < need) {
        ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    // Squaring: one forward transform is enough.
    bool equal = a == b;
    dft(a);
    if (equal) {
        b = a;
    } else {
        dft(b);
    }
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
        a[i] = mul(mul(a[i], b[i]), inv_sz);
    }
    // Inverse transform = reverse all but the first entry, then forward transform.
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(need);
    return a;
}

// Returns the first a.size() coefficients of the power-series inverse of a,
// by Newton iteration: b <- b * (2 - a * b), doubling the precision each level.
// Precondition: a[0] must be invertible modulo md (inv() returns 0 otherwise,
// and the result is then meaningless).
// An empty input yields an empty result; previously it recursed without end.
vector<int> inverse(vector<int> a) {
    int n = a.size(), m = (n + 1) >> 1;
    if (n == 0) {
        return vector<int>();
    }
    if (n == 1) {
        return vector<int>(1, inv(a[0]));
    } else {
        // Inverse of the first ceil(n / 2) coefficients.
        vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
        int need = n << 1, nbase = 0;
        while ((1 << nbase) < need) {
            ++nbase;
        }
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz);
        b.resize(sz);
        dft(a);
        dft(b);
        int inv_sz = inv(sz);
        for (int i = 0; i < sz; ++i) {
            // Pointwise (2 - a*b) * b, pre-scaled for the inverse transform.
            a[i] = mul(mul(md + 2 - mul(a[i], b[i]), b[i]), inv_sz);
        }
        reverse(a.begin() + 1, a.end());
        dft(a);
        a.resize(n);
        return a;
    }
}
}
// Bring the NTT-based polynomial routines into file scope.
// multiply is called by operator*= below; inverse is not referenced anywhere
// else in the visible code.
using ntt::multiply;
using ntt::inverse;
// Coefficient-wise polynomial addition modulo md, in place.
// a is zero-extended to b's length when it is shorter; returns a.
vector<int>& operator += (vector<int> &a, const vector<int> &b) {
    const size_t terms = b.size();
    if (a.size() < terms) {
        a.resize(terms);
    }
    for (size_t idx = 0; idx < terms; ++idx) {
        add(a[idx], b[idx]);
    }
    return a;
}
// Returns the coefficient-wise sum a + b modulo md as a new polynomial.
vector<int> operator + (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    c += b;
    // Return the local by name: `return c += b;` would return through the
    // reference yielded by operator+= and force a second full copy.
    return c;
}
// Coefficient-wise polynomial subtraction modulo md, in place.
// a is zero-extended to b's length when it is shorter; returns a.
vector<int>& operator -= (vector<int> &a, const vector<int> &b) {
    const size_t terms = b.size();
    if (a.size() < terms) {
        a.resize(terms);
    }
    for (size_t idx = 0; idx < terms; ++idx) {
        sub(a[idx], b[idx]);
    }
    return a;
}
// Returns the coefficient-wise difference a - b modulo md as a new polynomial.
vector<int> operator - (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    c -= b;
    // Return the local by name: `return c -= b;` would return through the
    // reference yielded by operator-= and force a second full copy.
    return c;
}
// Polynomial multiplication modulo md, in place: a <- a * b. Returns a.
//
// When the shorter operand has fewer than 128 coefficients the schoolbook
// double loop is used; otherwise the work is delegated to the NTT-based
// multiply().
//
// If either operand is empty the product is the empty polynomial.
// Safe when b aliases a (x *= x): the schoolbook branch accumulates into a
// separate buffer, so the operands are never overwritten while being read.
vector<int>& operator *= (vector<int> &a, const vector<int> &b) {
    if (a.empty() || b.empty()) {
        // a.size() + b.size() - 1 would underflow for two empty operands.
        a.clear();
        return a;
    }
    if (min(a.size(), b.size()) < 128) {
        vector<int> result(a.size() + b.size() - 1, 0);
        for (size_t i = 0; i < a.size(); ++i) {
            for (size_t j = 0; j < b.size(); ++j) {
                add(result[i + j], mul(a[i], b[j]));
            }
        }
        a.swap(result);
    } else {
        a = multiply(a, b);
    }
    return a;
}
// Returns the polynomial product a * b modulo md as a new polynomial.
vector<int> operator * (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    c *= b;
    // Return the local by name: `return c *= b;` would return through the
    // reference yielded by operator*= and force a second full copy.
    return c;
}
vector<int> fz_ntt(vector<vector<int> > &f, int l, int r)
{
if (l == r)
return f[l];
int mid = l + r >> 1;
auto L = fz_ntt(f, l, mid);
auto R = fz_ntt(f, mid + 1, r);
return L * R;
}
// 64-bit integer shorthand used for intermediate modular products.
typedef long long ll;
// Capacity of the factorial tables; presolve(n) writes indices 0..n, so it
// needs n < N.
const int N = 1e5 + 10;
// cnt[v]: number of input values equal to v (main() indexes it directly by
// the value read; only indices 1..10 are consumed afterwards).
int cnt[12];
// fac[i] = i! mod md and finv[i] = inverse of i! mod md, filled by presolve().
ll fac[N], finv[N];
// Returns a^b mod md by binary exponentiation on 64-bit values.
// a is not reduced on entry, so it must be small enough that a * a fits in
// a long long; b is expected to be non-negative.
ll qmi(ll a, ll b)
{
    ll result = 1;
    for (; b; b >>= 1)
    {
        if (b & 1)
        {
            result = result * a % md;
        }
        a = a * a % md;
    }
    return result;
}
// Fills fac[0..n] with factorials and finv[0..n] with inverse factorials,
// all modulo md. Only one modular exponentiation is needed: finv[n] comes
// from Fermat's little theorem and the rest follow from
// finv[i - 1] = finv[i] * i.
// Requires 0 <= n < N.
void presolve(int n)
{
    fac[0] = 1;
    for (int i = 0; i < n; ++i)
    {
        fac[i + 1] = fac[i] * (i + 1) % md;
    }
    finv[n] = qmi(fac[n], md - 2);
    for (int i = n; i > 1; --i)
    {
        finv[i - 1] = finv[i] * i % md;
    }
    finv[0] = 1;
}
// Binomial coefficient "n choose m" modulo md, read from the tables built by
// presolve(). Returns 0 when m lies outside [0, n].
ll C(int n, int m)
{
    const bool out_of_range = (m < 0) || (m > n);
    if (out_of_range)
    {
        return 0;
    }
    const ll numerator = fac[n];
    const ll denominator = finv[m] % md * finv[n - m] % md;
    return numerator * finv[m] % md * finv[n - m] % md + 0 * denominator;
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
presolve(n);
for (int i = 1; i <= n; i++)
{
int x;
scanf("%d", &x);
cnt[x]++;
}
vector<vector<int> > pa, pb;
for (int i = 1; i <= 10; i++)
{
vector<int> a(cnt[i] * i + 1, 0);
vector<int> b(cnt[i] * i + 1, 0);
for (int j = 0; j <= cnt[i]; j++)
{
a[j * i] = C(cnt[i], j);
b[j * i] = (j % 2 == 0 ? a[j * i] : (md - a[j * i]));
}
pa.push_back(a);
pb.push_back(b);
}
auto ans1 = fz_ntt(pa, 0, 9);
auto ans2 = fz_ntt(pb, 0, 9);
if (ans1.size() - 1 < m)
{
puts("0");
return 0;
}
int ans = (ll)(ans1[m] - ans2[m]) * (md + 1) / 2 % md;
if (ans < 0)
ans += md;
printf("%d\n", ans);
return 0;
}
标签:md,return,int,Sum,ABC267Ex,vector,const,Odd,size
From: https://www.cnblogs.com/hzy717zsy/p/16755235.html