怎么有人省选后才来学FFT啊
由于时间原因,本篇笔记仅为个人总结,真正想要学习FFT的请参看这篇博客。
前置知识
单位根性质:
- $ w_n^{2k}= w_{n/2}^k $
- $ w_n^a +w_n^b =w_n^{a+b} $
算法原理
可知 n+1 个点可以唯一确定一条 n 次多项式,于是可以用 n 个点的点对集合表示一条曲线。
那么如果有两个用点值表示的多项式,将点值对应相乘就可得到两个多项式乘出的多项式。如果两个多项式的最高项次数分别为 n 、m,那么确定最终多项式的点需要 n+m+1 个。
设多项式的点值表示的序列为 \(F(x)\),那么令奇偶下标分别为一组,奇数为 \(Fl(x)\),偶数为 \(Fr(x)\),可得到:
\(F(x) = Fl(x^2) + x\times Fr(x^2)\)
如果将 x 以单位根替换,则:
\(F(w_n^k) = Fl(w_{n/2}^k) + w_n^k \times Fr(w_{n/2}^k)\)
\(F(w_n^{k+n/2}) = Fl(w_{n/2}^k) - w_n^k \times Fr(w_{n/2}^k)\)
可发现每次的数据范围是缩小了一半的,于是我们可以在 \(n \log n\) 的时间内分治求出多项式点值表示。
上述过程被称为 \(DFT\) 。
求出多项式的点值表示,很多时候我们需要的是原多项式及其系数,下列通过点值还原多项式的做法被称作 \(IDFT\)。
设计算出的点值序列为 \(G(x)\) ,则
$ G(x)=\sum_{i=0}^{n-1} F(x) ( w_n^i )^i $
$ n\times F(x) = \sum_{i=0}^{n-1} G(x) ( w_n^{-i} )^i $
证明将 1 式代入 2 式即可。
这其实相当于再做一遍\(DFT\),不过将单位根变成负的并将最终答案除二。
快速记忆:
F(x)=Fl(x)+x*Fr(x)
正负号考虑单位根的正负性,IDFT要除n
算法实现
上述实现有两个问题:A.常数巨大 B.精度问题
对于 B ,有更优秀的 \(NTT\) 来解决,现在暂且不提。
对于 A ,各路神仙对其有不同的常数优化方式。
首先单位根的计算是多次的,我们可以将递归展开,从底层向上遍历的方式减少计算。
然后数组的拷贝也是一大问题,下文介绍蝴蝶变换可以将数组顺序直接转换为最底层的顺序。
原来的递归版(数组下标,先偶后奇,从0开始):
0 1 2 3 4 5 6 7 第1层
0 2 4 6|1 3 5 7 第2层
0 4|2 6|1 5|3 7 第3层
0|4|2|6|1|5|3|7 第4层
我们要求的就是第四层的顺序。
发现一个数最终的位置就是它的二进制翻转,如 (6,110) 变为 (6,011),这也是可以用递归来做的。
代码更容易理解:
for(int i=0;i<n;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
还有其他的优化什么的,建议参考各路神仙的博客。
luogu模板题代码
#include<bits/stdc++.h>
using namespace std;
inline int rd(){
int f=1,j=0;
char w=getchar();
while(!isdigit(w)){
if(w=='-')f=-1;
w=getchar();
}
while(isdigit(w)){
j=j*10+w-'0';
w=getchar();
}
return f*j;
}
const int N=2000010,M=2350000;
const double pai=acos(-1);
int n,m;
struct cp{
cp (double xx=0,double yy=0){x=xx,y=yy;}
double x,y;
cp operator +(cp const &b)const{return cp(x+b.x,y+b.y);}
cp operator -(cp const &b)const{return cp(x-b.x,y-b.y);}
cp operator *(cp const &b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
}f[N*2],p[N*2];
int tr[N*2];
void fft(cp *f,bool flg){
for(int i=0;i<n;i++)if(i<tr[i])swap(f[i],f[tr[i]]);
for(int p=2;p<=n;p<<=1){
int len=p>>1;
cp tg(cos(2*pai/p),sin(2*pai/p));
if(!flg)tg.y*=-1;
for(int k=0;k<n;k+=p){
cp buf(1,0);
for(int l=k;l<k+len;l++){
cp tt=buf*f[len+l];
f[len+l]=f[l]-tt;
f[l]=f[l]+tt;
buf=buf*tg;
}
}
}
return ;
}
signed main(){
n=rd(),m=rd();
for(int i=0;i<=n;i++)f[i].x=rd();
for(int i=0;i<=m;i++)p[i].x=rd();
for(m+=n,n=1;n<=m;n<<=1);
for(int i=0;i<n;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
// cout<<"kkk\n";
fft(f,1),fft(p,1);
for(int i=0;i<n;i++)f[i]=f[i]*p[i];
fft(f,0);
for(int i=0;i<=m;i++)printf("%d ",(int)(f[i].x/n+0.49));
return 0;
}