Description
给一个数组 \({a_i}, i=1, \cdots, n\),对 \(j=0, 1,\cdots, m-1\) ,计算其中有多少个大小为 \(k\) 的子序列满足其异或和为 \(j\)。
- \(n\leq 10^5\)
- $ m\leq 65536$
Solution
首先答案是
\[[y^k]\prod_{i=1}^n (1+x^{a_i}y) \]这里对 \(y\) 做的是多项式乘法,对 \(x\) 做的是异或卷积。
正常做法就是直接 FWT,算完连乘积再 IFWT 回去。
用类似《黎明前的巧克力》那道题的套路,对每个单项式 \(x^{a_i}\),它 FWT 的结果每项都是 \(\pm 1\)。
根据 FWT 的线性性,所有单项式相加后 FWT 的结果等于每个单项式 FWT 的结果之和。我们算出单项式之和的 FWT 的每一项系数 \(c_i\),然后根据
\[x_{i,1}+x_{i,-1}=n,x_{i,1}-x_{i,-1}=c_i \]即可计算出每个位置 \(\pm 1\) 的数量。
最后 FWT 的连乘积,第 \(i\) 项就是 \([y^k](1+y)^{x_{1,i}}(1-y)^{x_{-1,i}}\)。
把它求出来,然后 IFWT 回去,就是最终答案。
瓶颈在算上面这个 \(y^k\) 的系数,直接算是 \(O(km)\) 的,会T。
令 \(t=x_{-1,i}\),下面对所有的 \(t=0,1,\dots,n\) 计算 \(y^k\) 的系数。
\[[y^k](1+y)^{n-t}(1-y)^t\\=\sum_{i=0}^k(-1)^i\binom ti\binom{n-t}{k-i}\\=\sum_{i=0}^t(-1)^i\binom ti\binom{n-t}{k-i}\\=t!(n-t)!\sum_{i=0}^t\frac {(-1)^i}{i!(k-i)!(t-i)!(n-t-k+i)!} \]令
\[f_i=\frac {(-1)^i}{i!(k-i)!}, g_i=\frac 1{i!(n-k-i)!} \]然后卷积即可。
这样总复杂度为 \(O(m\log m+n\log n)\)。
Code
#define LOCAL
#include "bits/stdc++.h"
using namespace std;
using ui=unsigned; using db=long double; using ll=long long; using ull=unsigned long long; using lll=__int128;
using pii=pair<int,int>; using pll=pair<ll,ll>;
template<class T1, class T2> istream &operator>>(istream &cin, pair<T1, T2> &a) { return cin>>a.first>>a.second; }
template <std::size_t Index=0, typename... Ts> typename std::enable_if<Index==sizeof...(Ts), void>::type tuple_read(std::istream &is, std::tuple<Ts...> &t) { }
template <std::size_t Index=0, typename... Ts> typename std::enable_if<Index < sizeof...(Ts), void>::type tuple_read(std::istream &is, std::tuple<Ts...> &t) { is>>std::get<Index>(t); tuple_read<Index+1>(is, t); }
template <typename... Ts>std::istream &operator>>(std::istream &is, std::tuple<Ts...> &t) { tuple_read(is, t); return is; }
template<class T1> istream &operator>>(istream &cin, vector<T1> &a) { for (auto &x:a) cin>>x; return cin; }
template<class T1> istream &operator>>(istream &cin, valarray<T1> &a) { for (auto &x:a) cin>>x; return cin; }
template<class T1, class T2> bool cmin(T1 &x, const T2 &y) { if (y<x) { x=y; return 1; } return 0; }
template<class T1, class T2> bool cmax(T1 &x, const T2 &y) { if (x<y) { x=y; return 1; } return 0; }
istream &operator>>(istream &cin, lll &x) { x=0; static string s; cin>>s; for (char c:s) x=x*10+(c-'0'); return cin; }
ostream &operator<<(ostream &cout, lll x) { static char s[60]; int tp=1; s[0]='0'+(x%10); while (x/=10) s[tp++]='0'+(x%10); while (tp--) cout<<s[tp]; return cout; }
#if !defined(ONLINE_JUDGE)&&defined(LOCAL)
#include "my_header/IO.h"
#include "my_header/defs.h"
#else
#define dbg(...) ;
#define dbgx(...) ;
#define dbg1(x) ;
#define dbg2(x) ;
#define dbg3(x) ;
#define DEBUG(msg) ;
#define REGISTER_OUTPUT_NAME(Type, ...) ;
#define REGISTER_OUTPUT(Type, ...) ;
#endif
#define all(x) (x).begin(),(x).end()
#define print(...) cout<<format(__VA_ARGS__)
#define println(...) cout<<format(__VA_ARGS__)<<'\n'
#define err(...) cerr<<format(__VA_ARGS__)
#define errln(...) cerr<<format(__VA_ARGS__)<<'\n'
namespace NTT
{
const ull g=3, p=998244353;
const int N=1<<19;//务必修改
ull w[N];
int r[N];
ull ksm(ull x, ull y)
{
ull r=1;
while (y)
{
if (y&1) r=r*x%p;
x=x*x%p;
y>>=1;
}
return r;
}
void init(int n)
{
static int pr=0, pw=0;
if (pr==n) return;
int b=__lg(n)-1, i, j, k;
for (i=1; i<n; i++) r[i]=r[i>>1]>>1|(i&1)<<b;
if (pw<n)
{
for (j=1; j<n; j=k)
{
k=j*2;
ull wn=ksm(g, (p-1)/k);
w[j]=1;
for (i=j+1; i<k; i++) w[i]=w[i-1]*wn%p;
}
pw=n;
}
pr=n;
}
int cal(int x) { return 1<<__lg(max(x, 1)*2-1); }
struct Q:vector<ull>
{
bool flag;
Q &operator%=(int n) { resize(n); return *this; }
Q operator%(int n) const
{
if (size()<=n)
{
auto f=*this;
return f%=n;
}
return Q(vector(begin(), begin()+n));
}
int deg() const
{
int n=size()-1;
while (n>=0&&begin()[n]==0) --n;
return n;
}
explicit Q(int x=1, bool f=0):flag(f), vector<ull>(cal(x)) { }//小心:{}会调用这条而非下一条
Q(const vector<ull> &o, bool f=0):Q(o.size(), f) { copy(all(o), begin()); }
void dft()
{
int n=size(), i, j, k;
ull y, *f, *g, *wn, *a=data();
init(n);
for (i=1; i<n; i++) if (i<r[i]) ::swap(a[i], a[r[i]]);
for (k=1; k<n; k*=2)
{
wn=w+k;
for (i=0; i<n; i+=k*2)
{
g=(f=a+i)+k;
for (j=0; j<k; j++)
{
y=g[j]*wn[j]%p;
g[j]=f[j]+p-y;
f[j]+=y;
}
}
if (k*2==n||k==1<<14) for (i=0; i<n; i++) a[i]%=p;
}
if (flag)
{
y=ksm(n, p-2);
for (i=0; i<n; i++) a[i]=a[i]*y%p;
reverse(a+1, a+n);
}
flag^=1;
}
};
Q &operator*=(Q &f, Q g)//卷积
{
if (f.flag|g.flag)
{
int n=f.size(), i;
assert(n==g.size());
if (!f.flag) f.dft();
if (!g.flag) g.dft();
for (i=0; i<n; i++) (f[i]*=g[i])%=p;
f.dft();
}
else
{
int n=cal(f.size()+g.size()-1), i, j;
int m1=f.deg(), m2=g.deg();
if ((ull)m1*m2>(ull)n*__lg(n)*8)
{
(f%=n).dft(); (g%=n).dft();
for (i=0; i<n; i++) (f[i]*=g[i])%=p;
f.dft();
}
else
{
vector<ull> r(max(0, m1+m2+1));
for (i=0; i<=m1; i++) for (j=0; j<=m2; j++) (r[i+j]+=f[i]*g[j])%=p;
f=Q(n);
copy(all(r), f.begin());
}
}
return f;
}
}
using NTT::p;
using poly=NTT::Q;
int cy[100005];
void init(int n, int k)
{
poly x(n+1), y(n+1);
vector<ull> fac(n+1), inv(n+1);
for(int i=fac[0]=inv[0]=inv[1]=1; i<=n; ++i) fac[i]=(ull)fac[i-1]*i%p;
for(int i=2; i<=n; ++i) inv[i]=(ull)inv[p%i]*(p-p/i)%p;
for(int i=2; i<=n; ++i) inv[i]=(ull)inv[i-1]*inv[i]%p;
for(int i=0; i<=k; ++i) x[i]=(ull)((i&1)?(p-1ull):1ull)*inv[i]%p*inv[k-i]%p;
for(int i=0; i+k<=n; ++i) y[i]=(ull)inv[i]*inv[n-k-i]%p;
x*=y;
for(int i=0; i<=n; ++i) cy[i]=(ull)fac[i]*fac[n-i]%p*(x[i]%p+p)%p;
}
void fwt_xor(vector<ui> &A)
{
ui n=A.size(),*a=A.data(),i,j,k,l,*f,*g;
for (i=1;i<n;i=l)
{
l=i*2;
for (j=0;j<n;j+=l)
{
f=a+j;g=a+j+i;
for (k=0;k<i;k++)
{
if ((f[k]+=g[k])>=p) f[k]-=p;
g[k]=(f[k]+2*(p-g[k]))%p;
}
}
}
}
void ifwt_xor(vector<ui> &A)
{
ui n=A.size(),*a=A.data(),i,j,k,l,*f,*g,x=p+1>>1,y=1;
for (i=1;i<n;i=l)
{
l=i*2;
for (j=0;j<n;j+=l)
{
f=a+j;g=a+j+i;
for (k=0;k<i;k++)
{
if ((f[k]+=g[k])>=p) f[k]-=p;
g[k]=(f[k]+2*(p-g[k]))%p;
}
}
y=(ull)y*x%p;
}
for (i=0;i<n;i++) a[i]=(ull)a[i]*y%p;
}
int main()
{
ios::sync_with_stdio(0); cin.tie(0);
cout<<fixed<<setprecision(15);
ll n, k, b, x;
cin >> n >> k >> b;
init(n, k);
int B = NTT::cal(b);
vector<ui> f(B), g(B);
for(int i=1;i<=n;++i) cin>>x, f[x]++;
fwt_xor(f);
for(int i=0;i<B;++i)
{
int pos=((n+p-f[i])%p+p)%p/2;
g[i]=cy[pos];
}
ifwt_xor(g);
for(int i=0;i<b;++i) cout<<g[i]<<" ";
}
标签:std,XOR,using,int,Challenge,cin,istream,Grand,return
From: https://www.cnblogs.com/PaperCloud/p/18470072