leetcode 「10·24」程序员节编程竞赛 计算子集
给你三个整数 \(n, k, m\) 。定义 \(S=\{i\mid 1\le i\le nm+k,i\in \mathbb Z\}\)
请返回一个下标从 \(0\) 开始、长度为 \(m\) 的数组 answer
,其中 answer[i]
表示符合下列条件集合 \(T\) 的个数。
集合 \(T\) 是集合 \(S\) 的子集。
集合 \(T\) 中所有元素的和对 \(m\) 取余的值恰好为 \(i\) 。
由于答案可能很大,你只需要求出它模 998244353 的结果。
0 <= n <= 10^13
0 <= k < min(m, 500)
1 <= m <= 10^5
n * m + k > 0
\(k\) 是没有用的求出 \(nm\) 的答案再把多出来的那 \(k\) 个物品每个暴力 \(\mathcal{O}(m)\) 合并进去就行。那么现在就是求:
\[\left(\prod_{i=0}^{m-1}(1+x^i)\right)^n\pmod {x^m-1} \]循环卷积要做 dft,这里思路 和 UOJ 310 黎明前的巧克力 相同,UOJ 310 是直接把 fwt 写下来,这里来试着写一些 DFT,\(f_k\) 即为代入 \(\omega_m^k\) 的值,这里令 \(d=\gcd(k,m)\):
\[\begin{aligned} f_k&=\left(\prod_{i=0}^{m-1}(1+\omega_m^{ik})\right)^n\\ &=\left(\prod_{i=0}^{m/d-1}(1+\omega_{m/d}^{i})\right)^{dn} \end{aligned} \]这一步是因为 \(i\) 模 \(m/d\) 相同的 \(\omega_m^{ik}\) 值是一样的(单位根的定义),这样化简之后发现指数仅和 \(i\) 有关了,括号里面的形式非常好看,考虑比较经典的等式 \(\prod\limits_{i=0}^{m-1}(x-\omega_m^i)=x^m-1\)(左右两侧均为 \(m\) 次多项式且 \([x^m]\) 均为 \(1\) 故两式相等),将其代入 \(x=-1\) 就有:
\[\begin{aligned} &\prod_{i=0}^{m/d-1}(1+\omega_{m/d}^{i})\\ =&(-1)^{m/d}\prod_{i=0}^{m/d-1}(-1-\omega_{m/d}^i)\\ =&(-1)^{m/d}((-1)^{m/d}-1)\\ =&[m/d\bmod 2=1]2 \end{aligned} \]于是 \(f_k=[m/d\bmod 2=1]2^{dn}\),这里发现 \(f_k\) 的值仅和 \(d\) 有关,那么在 idft 的时候就可以利用这个转成枚举 \(d\):
\[\begin{aligned} g_k=&\frac{1}{m}\sum_{i=0}^{m-1}f_k\omega_m^{-ik}\\ =&\frac{1}{m}\sum_{d\mid m}f_d\sum_{i=0}^{m/d-1}[\gcd(i,m/d)=1]\omega_{m/d}^{-ik}\\ =&\frac{1}{m}\sum_{d\mid m}f_d\sum_{i=0}^{m/d-1}\sum_{e\mid i,e\mid m/d}\mu(e)\omega_{m/d}^{-ik}\\ =&\frac{1}{m}\sum_{de\mid m}f_d\mu(e)\sum_{0\leq i<m/de}\omega_{m/de}^{-ik}\\ =&\frac{1}{m}\sum_{de\mid m}f_d\mu(e)\frac{m}{de}[\frac{m}{de}\mid k]\\ \end{aligned} \]我朴素实现了个 \(\mathcal{O}(m(\log m+d(m)))\)。
代码
class Solution {
public:
typedef long long ll;
typedef vector<int>vi;
#define pb emplace_back
static const int mod=998244353;
static const int N=100010;
inline void cadd(int &x,int y){x=(x+y>=mod)?(x+y-mod):(x+y);}
inline void cdel(int &x,int y){x=(x-y<0)?(x-y+mod):(x-y);}
inline int add(int x,int y){return (x+y>=mod)?(x+y-mod):(x+y);}
inline int del(int x,int y){return (x-y<0)?(x-y+mod):(x-y);}
int qpow(int x,ll y){
int s=1;
while(y){
if(y&1)s=1ll*s*x%mod;
x=1ll*x*x%mod;
y>>=1;
}
return s;
}
ll n,m,k;
int vis[N];
vi pr;
int f[N],g[N];
vi vec[N];
void init(){
for(int i=2;i<=m;i++){
if(!vis[i])
pr.pb(i);
for(auto j:pr){
if(i*j>m)break;
vis[i*j]=1;
if(i%j==0)break;
}
}
}
ll gcd(ll x,ll y){return !y?x:gcd(y,x%y);}
vector<int> subsetCounting(long long q, int w, int e) {
n=q;k=w;m=e;
if(n){
init();
for(int i=1;i<=m;i++){
int d=__gcd(1ll*i,m);
if(m/d%2==1){
f[i]=qpow(2,1ll*d*n);
}
}
for(auto i:pr){
for(int j=m/i;j;j--)
cdel(f[i*j],f[j]);
}
for(int i=1;i<=m;i++)if(m%i==0)vec[m].pb(i);
int iv=qpow(m,mod-2);
for(int i=1;i<=m;i++){
for(auto de:vec[m]){
if(i%(m/de)==0){
cadd(g[i],1ll*f[de]*(m/de)%mod);
}
}
g[i]=1ll*g[i]*iv%mod;
}
g[0]=g[m];
g[m]=0;
}
else g[0]=1;
for(int i=0;i<=m;i++)f[i]=0;
for(int i=1;i<=k;i++){
for(int j=0;j<m;j++)
cadd(f[(j+i)%m],g[j]);
for(int j=0;j<m;j++)
g[j]=add(f[j],g[j]),f[j]=0;
}
vi vec;
for(int i=0;i<m;i++)vec.pb(g[i]);
return vec;
}
};