FFT入门
给一个非常好的入门视频:
复数与单位根
定义:\(i^2=-1\)为虚数单位,我们称形如\(a+bi(a,b\in R)\)的数为复数。
我们可以用复数在复平面上表示点\((0,0)->(a,b)\)的向量,我们称\(x\)的正轴与该向量的夹角为幅角,\(\sqrt{a^2+b^2}\)为模长。
下文中,我们默认\(n,m\)为2的幂。
我们以原点为起点,单位圆的\(n\)等分点为终点,设幅角为正且最小的向量所对应的复数为\(w_n^0\),表示\(n\)次单位根,设其余\(n-1\)个单位根为\(w_n^1,w_n^2…\)。
这里显然有:\(w_n^n=w_n^0=1\)。
引出欧拉公式:
\[w_n^k=\cos \frac{2k\pi}{n}+i\sin\frac{2k\pi}{n} \]那么由向量的运算法则(模长相乘,幅角相加),显然有\(w_n^k=(w_n^1)^k\)。
那么几个小性质:
- \(w_n^0=w_n^n=1\)
- \(w_{dn}^{dk}=w_n^k\)
- \(w^k_n=(w^1_n)^k\)
- \(w^{k+\frac{n}{2}}_n=-w^k_n\)
关于性质4的证明,可以将\(k+\frac{n}{2}\)看作旋转180度的结果,自然为复。
给张图
点值表示法
对于多项式\(A(x)=a_0+a_1x+a_2x^2+…+a_{n-1}x^{n-1}\),\((a_0,a_1,a_2…)\)为其系数表示法。
而带入\(n\)个不同的\(x\),得到的\(n\)个二元组\((x_0,y_0),(x_1,y_1),(x_2,y_2)……\)即为多项式的点值表示法。
对于两个多项式\(A(x),B(x)\),计算他们的乘积\(C(x)=A(x)B(x)\)的点值表示就只需要将\(A(x),B(x)\)对应的点值表示的\(y\)值相乘即可,也即\((x_0,y_{0,A}\times y_{0,B})……\),但我们需要\(n+m\)个点。
可见点值表示法求多项式的积是非常方便的,这引出了一个极为伟大的思想:求出\(A,B\)的点值表示->求出\(A(x)\times B(x)\)的点值表示->将点值表示化为系数表示
其中第二步可以\(O(n+m)\)做到。其中表示\(A,B\)分别为\(n,m\)阶多项式。
因为\(C\)是\(n+m\)阶多项式,所以可以将\(A,B\)不足的位补0。下面我们来尝试优化第一个和第三个步骤,这就是FFT所做的事情。
FFT流程
我们现在讨论如何求出\(A\)的点值表达。将单位根\(w_n^k\)带入,得到:
\(A(w_n^k)=\sum_{k'=0}^{n-1}a_kw^{kk'}_n\)。
因为\(n\)是\(2\)的幂。考虑奇偶分组:
\(A1(x)=a_0+a_2x^2+a_4x^4…,A2(x)=a_1x+a_3x^3+a_5x^5…\)。
得到:\(A(x)=A1(x^2)+xA2(x^2)\)。
将\(w_n^k(k<\frac{n}{2})\)带入得到:
\[A(w^k_n)=A1(w^{2k}_n)+w^k_nA2(w^{2k}_n) \]将\(w^{k+\frac{n}{2}}_n\)带入得到:
\[A(w^{k+\frac{n}{2}}_n)=A1(w^{2k+n}_n)+w^{k+\frac{n}{2}}_nA2(w^{2k+n}_n)=A1(w^n_nw^{2k}_n)-w^{k}_nA2(w^{2k}_n) \]\(k,k+\frac{n}{2}\)正好取遍\([0,n-1]\)。
这个式子告诉我们,只需要我们处理左半区间的\(A1,A2\)的点值表达式,就可以快速得到\(A\)的整个区间的点值表达式。故可以分治处理,复杂度\(O(n\log n)\)。
可以写出代码:
//node 是复数类
db pi=acos(-1.0);
void solve(int limit,node a[],int tag){
if(limit==1)return ;//单个常数没有用
node a1[(limit>>1)+5],a2[(limit>>1)+5];
for(int i=0;i<=limit;i+=2){
a1[i>>1]=a[i];
a2[i>>1]=a[i+1];
}//处理出系数
solve(limit>>1,a1,tag);
solve(limit>>1,a2,tag);
node wn={cos(2.0*pi/limit),tag*sin(2.0*pi/limit)},w={1.0,0};//pi是圆周率
for(int i=0;i<(limit>>1);i++,w=w*wn){
a[i]=a1[i]+w*a2[i];
a[i+(limit>>1)]=a1[i]-w*a2[i];
}
}
//tag=1
a
数组就给出了\(x=w_n^0,w_n^1,w_n^2…\)的\(A\)的点值表示。
至于我们为什么要搞一个tag
,会告诉你答案的。
点值化系数
直觉告诉我们,因为上文的solve
在对其进行系数化点值的时候,我们搞了个变量tag=1
,不难想到它的逆过程就是tag=-1
,就可以让点值化系数了!
但是,这个系数不是我们想要的。
设上述FFT把\(A(x)\)的点值表示求出为\((y_0,y_1,y_2…)\),逆过程求出来的系数是\((c_0,c_1,c_2…)\),则\(c_k=\sum_{j=0}^{n-1}y_jw^{-jk}_n\)。
展开:
\[c_k=\sum_{j=0}^{n-1}\left(\sum_{i=0}^{n-1}a_iw_n^{ij}\right)w^{-jk}_n=\sum_{i=0}^{n-1}a_i\sum_{j=0}^{n-1}w_n^{j(i-k)} \]将\(w^{i-k}_n\)看作整体,运用等比数列求和公式,得到:\(c_k=\sum_{i=0}^{n-1}a_i\frac{w_n^{(i-k)^{n}}-1}{w_n^{i-k}-1}\)
因为\(w_n^{(i-k)^n}=(w^n_n)^{i-k}=1\),故分子为0,分母显然不为0,但当\(i=k\)时,得到其为:\(na_i\)显然可以给出:\(c_k=na_k\),故可以给出:
\[a_k=\frac{c_k}{n} \]这就是FFT求逆。点值化系数。
这有什么用呢?我们在求出来\(A,B\)的点值表示并且相乘后,再求逆得出\((c_0,c_1…)\),再对每一个\(c\)除以\(n+m\)就可以得出\(C\)的系数了!
所以整个过程是这样的。盗用一张图:
代码实现:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define db double
#define N 10000050
struct node{
db x,y;
node operator+(const node b){
return {x+b.x,y+b.y};
}
node operator-(const node b){
return {x-b.x,y-b.y};
}
node operator*(const node b){
return {x*b.x-y*b.y,x*b.y+y*b.x};
}
}a[N],b[N],c[N];
int n,m,limit=1;
db pi=acos(-1.0);
void solve(int limit,node a[],int tag){
if(limit==1)return ;//单个常数没有用
node a1[(limit>>1)+5],a2[(limit>>1)+5];
for(int i=0;i<=limit;i+=2){
a1[i>>1]=a[i];
a2[i>>1]=a[i+1];
}//处理出系数
solve(limit>>1,a1,tag);
solve(limit>>1,a2,tag);
node wn={cos(2.0*pi/limit),tag*sin(2.0*pi/limit)},w={1.0,0};//pi是圆周率
for(int i=0;i<(limit>>1);i++,w=w*wn){
a[i]=a1[i]+w*a2[i];
a[i+(limit>>1)]=a1[i]-w*a2[i];
}
}
int main(){
// freopen("data.in","r",stdin);
// freopen("data.out","w",stdout);
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<=n;i++)cin>>a[i].x;
for(int i=0;i<=m;i++)cin>>b[i].x;
while(limit<=n+m)limit<<=1;
solve(limit,a,1);
solve(limit,b,1);
for(int i=0;i<=limit;i++)c[i]=a[i]*b[i];
solve(limit,c,-1);
for(int i=0;i<=n+m;i++)cout<<(int)(c[i].x/limit+0.5)<<" ";
}
迭代换递归
不难看出,递归的解法虽然简洁明了,但却有极大的空间和常数开销,我们考虑将其从递归转化为迭代:
观察这一张图。
我们将底层的数的下标写作二进制数:
原序列 | 000 | 001 | 010 | 011 | 100 | 101 | 110 | 111 |
---|---|---|---|---|---|---|---|---|
重排后 | 000 | 100 | 010 | 110 | 001 | 101 | 011 | 111 |
乍一看,我们好像发现:这个二进制数好像上下对应是反过来了。考虑证明这个性质,实际也不难,每一次划分就对一个二进制位进行了交换,自然会反过来。
将一个l
位的二进制数\(x\)倒置,可以这样做:设r[x]
为\(x\)倒置后的结果,则有r[x]=(r[x>>1]>>1)|((x&1)<<(l-1))
,其中\(limit=2^l\)。
所以我们可以预处理出这个倒置的数组r
,即可处理出合并的顺序,然后将长度为二的幂的区间自大到小合并即可,就省去了自顶向下带来的巨大常数和空间开销。
#define db double
#define N 3000050
struct node{
db x,y;
node operator+(const node b){
return {x+b.x,y+b.y};
}
node operator-(const node b){
return {x-b.x,y-b.y};
}
node operator*(const node b){
return {x*b.x-y*b.y,x*b.y+y*b.x};
}
}a[N],b[N],c[N];
int n,m,limit=1,l,r[N];
db pi=acos(-1.0);
void solve(int limit,node a[],int tag){
for(int i=0;i<limit;i++){
if(i<r[i])swap(a[i],a[r[i]]);
}
for(int len=1;len<limit;len<<=1){
for(int j=0;j<limit;j+=(len<<1)){
node wn={cos(pi/len),tag*sin(pi/len)},w={1,0};//需要合并的区间长度是len<<1,故这里约去一个2
for(int i=0;i<len;i++,w=w*wn){
node x=a[j+i],y=w*a[j+i+len];
a[j+i]=x+y;
a[j+i+len]=x-y;
}
}
}
}
int main(){
// freopen("data.in","r",stdin);
// freopen("data.out","w",stdout);
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<=n;i++)cin>>a[i].x;
for(int i=0;i<=m;i++)cin>>b[i].x;
while(limit<=n+m)limit<<=1,++l;//这里必须是<=
for(int i=0;i<=limit;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
solve(limit,a,1);
solve(limit,b,1);
for(int i=0;i<=limit;i++)c[i]=a[i]*b[i];
solve(limit,c,-1);
for(int i=0;i<=n+m;i++)cout<<(int)(c[i].x/limit+0.5)<<" ";
}
NTT
首先,类比FFT,我们利用了单位根的以下性质:
- \(w_n^0=w_n^n=1\)
- \(w_{dn}^{dk}=w_n^k\)
- \(w^k_n=(w^1_n)^k\)
- \(w^{k+\frac{n}{2}}_n=-w^k_n\)
而NTT是解决在模意义下的多项式乘法。
为什么我们需要NTT?因为FFT它炸精了!
而NTT的重要思想就是在整数域,模意义下的同样具有以上性质的整数,这让我们发现了——原根!
前置知识
阶
定义:
若\(a,p\)互素,那么满足\(a^n\equiv 1(\bmod p)\)的最小正整数\(n\)即为\(a\)模\(p\)的阶,记作\(\delta_p(a)\)
例如\(\delta_7(2)=3\)
原根
设\(p\in \mathbb{N^+},a\in \mathbb{Z}\),若满足\(\delta_p(a)=\varphi(p)\),则称\(a\)为模\(p\)的一个原根。
注意,对于模数\(p\),如果它有原根,那么它的原根数量是\(\varphi(\varphi(p))\)。
对于\(m\)来说,存在模\(m\)的原根的重要条件是:\(m=2,4,p^a,2p^a(p\in Prime,a\in \mathbb{N^+})\)
性质
对于原根,存在一个非常重要的定理:
设\(p\)是素数,\(g\)是模\(p\)的一个原根,那么\(g^i\bmod p(1<g<p,0\le i<p-1)\)互不相同。
用原根代替单位根
这里因为\(n\)是2的幂,所以我们对\(p\)有一定要求,\(p=a2^x+1\),常见的有:
\(998244353=119\times 2^{23}+1,1004535809=479\times 2^{21}+1\)。\(3\)是他们的原根之一。
设\(g_n^n\equiv 1(\bmod p)\)且\(g_n^1,g^1_n……g^{n-1}_n\)在模\(p\)下互不相同,则\(g_n\equiv g^{\frac{p-1}{n}}(\bmod p)\)等价于\(w_n^1\)。
证明:
- \(g^n_n\equiv 1(\bmod p)\)
根据定义显然
- \(g_n^1,g^1_n……g^{n-1}_n\)在模\(p\)下互不相同
根据定义显然
- \(w^{k+\frac{n}{2}}_n=-w^{k}_n,w^2_n=w^1_{\frac{n}{2}}\)
由于\(g_n^1=g^\frac{p-1}{n}\),设\(m=\frac{p-1}{n}\),则\(nm=p-1\),当\(n'=\frac{n}{2}\)时,\(m'=2m\)。所以\(g^2_n=g^{2m}=g^1_{\frac{n}{2}}\),这样我们就证明了后面一条定理
然后对于\(g^{k+\frac{n}{2}}_n=g^k_n·g^{\frac{n}{2}}_n\),\(g^{\frac{n}{2}}_n=g^{\frac{p-1}{2}}\),根据费马小定理,可以得到\(g^{\frac{p-1}{2}}=1\text{或}-1\)。然后因为\(g^0\equiv 1(\bmod p)\),根据性质2可以得到\(g^{\frac{p-1}{2}}=-1\)。所以带入即可得证。
- \(\sum_{j=0}^{n-1}g_n^{j(i-k)}\)当且仅当在\(i-k=0\)时为\(n\),否则为\(0\)
同理当\(i-k=0\)时显然为\(n\)
当\(i\neq k\)时,根据等比数列求和公式可以得到\(\frac{g_n^{(i-k)^{n}}-1}{g_n^{i-k}-1}\),根据原根的定义和费马小定理:\(g_n^n=g^{p-1}\equiv 1(\bmod p)\),所以也可以得到分子为0。
综上,\(g_n^1=g^{\frac{p-1}{n}}\)为一个可替代\(w\)的选择。
在上面的FFT代码中,我们仅仅需要写个快速幂,再更改几行:
#define N 3000050
#define ll long long
#define p 998244353
const int g=3;//模数,原根3
ll a[N],b[N],c[N];
int n,m,limit=1,l,r[N],inv_g;
int power(int a,int b){
int ans=1;
while(b){
if(b&1)ans=1ll*ans*a%p;
a=1ll*a*a%p;
b>>=1;
}
return ans;
}
void solve(int limit,ll a[],int tag){
for(int i=0;i<limit;i++){
if(i<r[i])swap(a[i],a[r[i]]);
}
for(int len=1;len<limit;len<<=1){
for(int j=0;j<limit;j+=(len<<1)){
ll gn=tag==-1?power(inv_g,(p-1)/(len<<1)):power(g,(p-1)/(len<<1)),g0=1;//如定义所说,但做逆变换的时候是用的g的逆元
for(int i=0;i<len;i++,g0=g0*gn%p){
ll x=a[j+i],y=g0*a[j+i+len]%p;
a[j+i]=(x+y)%p;
a[j+i+len]=((x-y)%p+p)%p;
}
}
}
}
int main(){
// freopen("data.in","r",stdin);
// freopen("data.out","w",stdout);
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<=n;i++)cin>>a[i];
for(int i=0;i<=m;i++)cin>>b[i];
while(limit<=n+m)limit<<=1,++l;
inv_g=power(g,p-2);
for(int i=0;i<=limit;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
solve(limit,a,1);
solve(limit,b,1);
for(int i=0;i<=limit;i++)c[i]=a[i]*b[i]%p;
solve(limit,c,-1);
ll inv=power(limit,p-2);
for(int i=0;i<=n+m;i++)cout<<c[i]*inv%p<<" ";//除法变为乘逆元
}
不过为什么我的NTT跑不过FFT啊,还慢得一批。
附上一张常用表:
//(\(g\) 是\(\bmod(r2^k+1)\)的原根)
素数 | r | k | g |
---|---|---|---|
3 | 1 | 1 | 2 |
5 | 1 | 2 | 2 |
17 | 1 | 4 | 3 |
97 | 3 | 5 | 5 |
193 | 3 | 6 | 5 |
257 | 1 | 8 | 3 |
7681 | 15 | 9 | 17 |
12289 | 3 | 12 | 11 |
40961 | 5 | 13 | 3 |
65537 | 1 | 16 | 3 |
786433 | 3 | 18 | 10 |
5767169 | 11 | 19 | 3 |
7340033 | 7 | 20 | 3 |
23068673 | 11 | 21 | 3 |
104857601 | 25 | 22 | 3 |
167772161 | 5 | 25 | 3 |
469762049 | 7 | 26 | 3 |
998244353 | 119 | 23 | 3 |
1004535809 | 479 | 21 | 3 |
2013265921 | 15 | 27 | 31 |
2281701377 | 17 | 27 | 3 |
3221225473 | 3 | 30 | 5 |
75161927681 | 35 | 31 | 3 |
77309411329 | 9 | 33 | 7 |
206158430209 | 3 | 36 | 22 |
2061584302081 | 15 | 37 | 7 |
2748779069441 | 5 | 39 | 3 |
6597069766657 | 3 | 41 | 5 |
39582418599937 | 9 | 42 | 5 |
79164837199873 | 9 | 43 | 5 |
263882790666241 | 15 | 44 | 7 |
1231453023109121 | 35 | 45 | 3 |
1337006139375617 | 19 | 46 | 3 |
3799912185593857 | 27 | 47 | 5 |
4222124650659841 | 15 | 48 | 19 |
7881299347898369 | 7 | 50 | 6 |
31525197391593473 | 7 | 52 | 3 |
180143985094819841 | 5 | 55 | 6 |
1945555039024054273 | 27 | 56 | 5 |
4179340454199820289 | 29 | 57 | 3 |
\(998244353,1004535809,469762049\),他们都存在原根\(3\)且都在int 范围内。 |
任意模数NTT
它大概是这样的,由于算法竞赛中常见的模数是\(10^9\)次方级别的,最常用的是998244353,1e9+7
这两个,于是对于多项式乘法\(A(x)B(x)\),设两个多项式分别是\(n,m(n<m)\)阶的,那么乘法所产生的最大值是\(10^9\times 10^9m\),由于\(m\)不会大于\(10^6\)级别,所以答案的值不会超过\(10^{24}\)。可以这样考虑,我们选择3个1e9
级别的模数,比如上文所说的\(998244353,1004535809,469762049\),这三者的乘积显然是大于1e24
的,我们分别对于这些数跑NTT,这很方便,但需要9次NTT。然后我们便得到了如下这个方程组:
因为直接合并三个质数会爆long long
,于是我们需要科技。
根据CRT的方法,可以先合并前两个式子,这样\(p_1p_2\)不会爆long long
,得到:
其中\(P=p_1p_2,A=a_1p_2p_2^{-1}+a_2p_1p_1^{-1}\),\(p_1^{-1},p_2^{-1}\)分别是在模\(p_2,p_1\)时的逆元。
然后就有:\(x=kP+A\equiv a_3(\bmod p_3)\),立即可推得:\(k\equiv(a_3-A)\times P^{-1}(\bmod p_3)\)。
然后就给出了答案:\(x=kP+A(\bmod p)\)。
这里在做\(P^{-1}\)的时候long long
会挂掉。而计算逆元用费马小定理,故我们需要龟速乘!奇淫技巧!
当然偷懒可以直接__int128
。奇淫技巧类似于:
ll mul(ll a,ll b,ll P){
a=(a%P+P)%P,b=(b%P+P)%P;
return ((a*b-(ll)((long double)a/P*b+1e-6)*P)%P+P)%P;
}
建议背下。
于是稍加更改便可以写出三模数NTT代码:
细节蛮多的,我会尽量在代码中标注。
#define ll long long
#define N 300350
const ll p[3]={998244353,1004535809,469762049},g=3ll;//设置成const取模会快很多
ll P,inv[3];
ll a[N][5],b[N][5],c[N][5],K[N];//注意都得开long long
int n,m,limit=1,l,r[N];
ll p_ture;
ll power(ll a,int b,ll p){
ll ans=1;
while(b){
if(b&1)ans=ans*a%p;
a=a*a%p;
b>>=1;
}
return ans;
}
ll mul(ll a,ll b,ll p){
a=(a%p+p)%p,b=(b%p+p)%p;
return ((a*b-(ll)((long double)a/p*b+1e-6)*p)%p+p)%p;
}//奇淫技巧,原理即为取模运算的另一个式子
ll expower(ll a,ll b,ll p){
ll ans=1;
while(b){
if(b&1)ans=mul(ans,a,p)%p;
a=mul(a,a,p)%p;
b>>=1;
}
return ans;
}//爆ll专用快速幂
void solve(int limit,ll a[][5],int tag,ll p,int id){
for(int i=0;i<limit;i++){
if(i<r[i])swap(a[i][id],a[r[i]][id]);
}
for(int len=1;len<limit;len<<=1){
for(int j=0;j<limit;j+=len<<1){
ll g0=1,gn=power(tag==1?g:inv[id],(p-1)/(len<<1),p);
for(int i=0;i<len;i++,g0=g0*gn%p){
ll x=a[i+j][id],y=g0*a[i+j+len][id]%p;
a[i+j][id]=(x+y)%p;
a[i+j+len][id]=((x-y)%p+p)%p;
}//和普通NTT并无区别
}
}
return ;
}
void init(){
cin>>n>>m>>p_ture;
for(int i=0;i<=n;i++)cin>>a[i][0];
for(int i=0;i<=m;i++)cin>>b[i][0];
for(int i=0;i<=n;i++)a[i][1]=a[i][2]=a[i][0]%=p_ture;
for(int i=0;i<=m;i++)b[i][1]=b[i][2]=b[i][0]%=p_ture;//注意这里得先模,否则炸精谁也救不了你
for(int i=0;i<=2;i++)inv[i]=power(g,p[i]-2,p[i]);
while(limit<=n+m)limit<<=1,l++;
for(int i=0;i<limit;i++){
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
for(int i=0;i<=2;i++){
solve(limit,a,1,p[i],i);
solve(limit,b,1,p[i],i);
ll Inv=power(limit,p[i]-2,p[i]);
for(int k=0;k<limit;k++)c[k][i]=a[k][i]%p[i]*b[k][i]%p[i];
solve(limit,c,-1,p[i],i);
for(int k=0;k<limit;k++)c[k][i]=c[k][i]*Inv%p[i];
}//正常NTT板子,处理三次
P=p[1]*p[2];
ll inv_p1=expower(p[1],p[2]-2,p[2]);
ll inv_p2=expower(p[2],p[1]-2,p[1]);
ll Inv_P_0=expower(P,p[0]-2,p[0]);
ll X=mul(p[2],inv_p2,P),Y=mul(p[1],inv_p1,P);//一定要先求,减小常数开销,不然就是TLE
for(int i=0;i<limit;i++){
c[i][3]=mul(c[i][1],X,P)%P+mul(c[i][2],Y,P)%P;
c[i][3]=(c[i][3]%P+P)%P;
}
for(int i=0;i<limit;i++){
K[i]=((c[i][0]-c[i][3])%p[0]+p[0])%p[0]*Inv_P_0%p[0];//这里减法会炸负数,要二次取模
}
for(int i=0;i<=n+m;i++){
ll ans=K[i]%p_ture*(P%p_ture)%p_ture+c[i][3]%p_ture;
cout<<(ans%p_ture+p_ture)%p_ture<<" ";
}
return ;
}
int main(){
// freopen("data.in", "r", stdin);
// freopen("data.out","w",stdout);
ios::sync_with_stdio(false);
init();
}
事实上:
简直慢的要死,就差0.2s甚至不能过了。
果然得学拆系数FFT了,学了之后补上。
艹,为什么MTT比这玩意快百倍啊