三模数 NTT
常数大、速度慢、精度高是它的特点。
在考虑三模数 NTT 之前先考虑一下中国剩余定理吧。
已知
\[\begin{cases} x\equiv x_1(\bmod m_1)\\ x\equiv x_2(\bmod m_2)\\ x\equiv x_3(\bmod m_3)\\ \end{cases} \]求 \(x\bmod m_1m_2m_3\)。
有
\[\begin{aligned} &k_1m_1+x_1=k_2m_2+x_2\\ &k_1\equiv \frac{x_2-x_1}{m_1}(\bmod m_2)\\ &x\equiv k_1m_1+x_1(\bmod m_1m_2)\\ &x_4=(k_1m_1+x_1)\bmod m_1m_2\\ &k_4m_1m_2+x_4=k_3m_3+x_3\\ &k_4\equiv \frac{x_3-x_4}{m_1m_2}(\bmod m_1m_2m_3)\\ &x\equiv k_4m_4+x_4(\bmod m_1m_2m_3)\\ \end{aligned} \]一点疑惑的解答(自言自语):
因为 \(k_1\equiv \frac{x_2-x_1}{m_1}(\bmod m_2)\),所以 \(k_1=\frac{x_2-x_1}{m_1}+km_2\)。又因为 \(k_1m_1\le m_1m_2\),所以 \(k_1\le m_2\)。所以 \(k\ge 0\),所以 \(k_1\) 最小为 \(\frac{x_2-x_1}{m_1}\),即 \(x\equiv k_1m_1+x_1(\bmod m_1m_2)\\\)。
进入正题:
所谓的三模数 NTT 指的是 以 \(998244353,1004535809,469762049\) 为模数(经典 NTT 模数,原根均为 \(3\))分别进行 NTT,最后用上文的计算方式计算即可。
因为以上三个模数的乘积为很大,一般答案即使不取模也不会大于该数,所以上式的 \(k_4m_4+x_4\) 就是原答案,直接对题目给出的模数取模即可。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define Big __int128
const int N=3e5+1;
const ll mo1=998244353,mo2=1004535809,mo3=469762049,G=3;
inline Big Ksm(Big x,Big y,ll mo){
Big res=1;
for(;y;y>>=1,x=x*x%mo)
if(y&1)res=res*x%mo;
return res;
}
ll MOD;
const ll inv1=Ksm(mo1,mo2-2,mo2),inv2=Ksm(mo1*mo2%mo3,mo3-2,mo3);
struct Int{
ll a,b,c;
Int(ll _x=0){a=b=c=_x;}
Int(ll _a,ll _b,ll _c){a=_a,b=_b,c=_c;}
inline Int operator + (const Int &x){return Int((ll)(a+x.a)%mo1,(ll)(b+x.b)%mo2,(ll)(c+x.c)%mo3);}
inline Int operator - (const Int &x){return Int((ll)(a-x.a+mo1)%mo1,(ll)(b-x.b+mo2)%mo2,(ll)(c-x.c+mo3)%mo3);}
inline Int operator * (const Int &x){return Int((ll)a*x.a%mo1,(ll)b*x.b%mo2,(ll)c*x.c%mo3);}
inline Int operator * (ll x){return Int((ll)a*x%mo1,(ll)b*x%mo2,(ll)c*x%mo3);}
void mulinv(ll x){
a=a*Ksm(x,mo1-2,mo1)%mo1;
b=b*Ksm(x,mo2-2,mo2)%mo2;
c=c*Ksm(x,mo3-2,mo3)%mo3;
}
void inv(){
a=Ksm(a,mo1-2,mo1)%mo1;
b=Ksm(b,mo2-2,mo2)%mo2;
c=Ksm(c,mo3-2,mo3)%mo3;
}
ll gettrue(){
Big x=(Big)(b-a+mo2)%mo2*inv1%mo2*(Big)mo1+(Big)a;
return (((Big)(c-x%mo3+mo3)%mo3*inv2%mo3*(mo1%MOD*mo2%MOD)%MOD+x%MOD)%MOD+MOD)%MOD;
}
}; // mtt
int rev[N];
Int w[N];
void NTT(Int *a,int Len,bool type){
for(int i=0;i<Len;i++){
rev[i]=(rev[i>>1]>>1)+(i&1?Len>>1:0);
if(rev[i]>i)swap(a[rev[i]],a[i]);
}
for(int d=1;d<Len;d<<=1){
Int W=Int(Ksm(G,(mo1-1)/(d*2),mo1),Ksm(G,(mo2-1)/(d*2),mo2),Ksm(G,(mo3-1)/(d*2),mo3));
if(type)W.inv();
w[0]=Int(1); for(int i=1;i<d;i++)w[i]=w[i-1]*W;
for(int fir=0;fir<Len;fir+=d<<1){
int sec=fir+d;
for(int i=0;i<d;i++){
Int a0=a[fir+i],a1=w[i]*a[sec+i];
a[fir+i]=a0+a1,a[sec+i]=a0-a1;
}
}
}
if(type){for(int i=0;i<Len;i++)a[i].mulinv(Len);}
}
int n,m;
Int f[N],g[N];
int main(){
cin>>n>>m>>MOD;
for(int i=0,x;i<=n;i++)cin>>x,x%=MOD,f[i]=Int(x);
for(int i=0,x;i<=m;i++)cin>>x,x%=MOD,g[i]=Int(x);
int Len=1;
while(Len<=(n+m+4))Len<<=1;
NTT(f,Len,0),NTT(g,Len,0);
for(int i=0;i<Len;i++)f[i]=f[i]*g[i];
NTT(f,Len,1);
for(int i=0;i<=n+m;i++)cout<<f[i].gettrue()<<' ';
cout<<'\n';
return 0;
}
拆系数 FFT
常数小,速度快,精度低(\(\operatorname{long double}\) 信仰跑)是它的特色。
如果直接对原数列进行 FFT 的话会炸精度的。考虑拆系数,即 \(A_i=J\times A'_i+A''_i\)(\(A''_i< J\))。
那么:
\[\begin{aligned} F&=A\times B=(J\times A'+A'')\times(J\times B'+B'')\\ &=J^2A'B'+J(A'B''+A''B')+A''B''\\ \end{aligned} \]如果直接计算的话需要四次 dft,三次 idft,和九次 ntt 的三模数 NTT 差距并不大。
考虑优化,然而 dft/idft 中有什么地方没有用到捏?虚部!考虑将 \(A'\),\(A''\),\(B'\),\(B''\) 合并在一起进行 dft。
设:
\[\begin{aligned} P_i=A'_i+A''_ii\\ Q_i=A'_i-A''_ii\\ E_i=B'_i+B''_ii\\ \end{aligned} \]有:
\[\begin{aligned} &W_i=(P\times E)_i=(A'_iB'_i-A''_iB''_i)+(A'_iB''_i+A''_iB'_i)i\\ &R_i=(Q\times E)_i=(A'_iB'_i+A''_iB''_i)+(A'_iB''_i-A''_iB'_i)i\\ \end{aligned} \]我们可以通过 \(W\) 和 \(R\) 的加减得到我们想要的系数。
\[\begin{aligned} &W_i+R_i=2\times(A'_iB'_i+A'B''_ii)\\ &R_i-W_i=2\times(A''_iB''_i+A''_iB'_ii)\\ \end{aligned} \]注意: 是先除以二再取整!!!(代码 \(\texttt{39}\) 行)。
#include <bits/stdc++.h>
#define poly vector<int>
using namespace std;
const int N=5e5+11;
int mo;
const int base=32768;
namespace Poly{
using db = long double;
const db pi=acos(-1);
struct cp{
db x,y;
cp operator + (const cp &a){return {x+a.x,y+a.y};}
cp operator - (const cp &a){return {x-a.x,y-a.y};}
cp operator * (const cp &a){return {x*a.x-y*a.y,x*a.y+y*a.x};}
};
cp w[N]; int rev[N];
void init_rev(int Len){
for(int i=0;i<Len;i++)
rev[i]=(rev[i>>1]>>1)+(i&1?Len>>1:0);
}
void FFT(cp *a,int Len,bool type){
for(int i=0;i<Len;i++)if(rev[i]>i)swap(a[rev[i]],a[i]);
for(int d=1;d<Len;d<<=1){
cp W={cos(pi/d),sin(pi/d)};
if(type)W.y=-W.y;
w[0]={1,0};
for(int i=1;i<d;i++)w[i]=w[i-1]*W;
for(int fir=0;fir<Len;fir+=d<<1){
int sec=fir+d;
for(int i=0;i<d;i++){
cp a0=a[fir+i],a1=w[i]*a[sec+i];
a[fir+i]=a0+a1,a[sec+i]=a0-a1;
}
}
}
if(type)for(int i=0;i<Len;i++)a[i].x/=Len,a[i].y/=Len;
}
cp f[N],g[N],e[N];
long long C(db x){return (long long)(x/2.+0.5)%mo;} // important!!!
poly mul(poly x,poly y){
int tot=x.size()+y.size()-1,Len=1;
while(Len<=(tot+2))Len<<=1;
init_rev(Len);
for(int i=0;i<=Len;i++)f[i]=g[i]=e[i]={0,0};
for(int i=0;i<x.size();i++){
int a0=x[i]/base,a1=x[i]%base;
f[i]={a0,a1},g[i]={a0,-a1};
}
for(int i=0;i<y.size();i++){
int b0=y[i]/base,b1=y[i]%base;
e[i]={b0,b1};
}
FFT(f,Len,0),FFT(g,Len,0),FFT(e,Len,0);
for(int i=0;i<Len;i++)f[i]=f[i]*e[i],g[i]=g[i]*e[i];
FFT(f,Len,1),FFT(g,Len,1);
poly ret(tot,0);
for(int i=0;i<tot;i++){
ret[i]=1ll*base*base%mo*(C(f[i].x+g[i].x))%mo;
ret[i]+=1ll*base*((C(f[i].y+g[i].y))+(C(f[i].y-g[i].y)))%mo;
ret[i]%=mo;
ret[i]+=(C(g[i].x-f[i].x))%mo;
ret[i]%=mo;
}
return ret;
}
} using Poly::mul;
int a[N],n,m;
poly solve(int l,int r){
if(l==r)return {1,a[l]};
int mid=l+r>>1;
return mul(solve(l,mid),solve(mid+1,r));
}
int main(){
cin>>n>>m>>mo;
poly a(n+1,0),b(m+1,0);
for(int i=0;i<=n;i++) cin>>a[i];
for(int i=0;i<=m;i++) cin>>b[i];
a=mul(a,b);
for(int i:a)cout<<i<<' ';
return 0;
}
标签:Int,mo3,int,多项式,ll,MTT,模数,mo2,mo1
From: https://www.cnblogs.com/dadidididi/p/17478662.html