快速傅里叶变换(FFT)
前言
本文为个人学习笔记,大量参考了 oi-wiki 以及其他博客的内容。
问题
记:
\[f(x) = c_0 + c_1 x + c_2 x^2 + \cdots + c_{n}x^{n} \\ g(x) = d_0 + d_1 x + d_2 x^2 + \cdots + d_{m}x^{m} \\ h(x) = f(x) \times g(x) \]在 \(\mathcal O(n \log n)\) 内解决两个多项式乘法后的系数(即给定 \(f(x)\) 和 \(g(x)\) 的系数,要你求出 \(h(x)\) 的系数)。
分析
暴力显然是 \(\mathcal O(n^2)\) 的,优化的想法是先考虑点值表示,再考虑从点值表示转换为系数表示。
具体如下:
点值表示的意思是,你需要求出(\(\omega_n^k\) 是什么先忽略,当作是 \(n\) 个已知量即可):
\[f(\omega_n^0), f(\omega_n^1), \cdots f(\omega_n^{n-1}) \\ g(\omega_n^0), g(\omega_n^1), \cdots g(\omega_n^{n-1}) \]那么:
\[h(\omega_n^k) = f(\omega_n^k) \times g(\omega_n^k) \]实际上,\(n\) 个点的点值表示法也确定了一个 \(n - 1\) 次的多项式,因此,一定存在某个算法能将点值表示法转化为系数表示(这个后面再说)。
至此,FFT 的核心思想已经说清楚了,就是考虑求出 \(f, g\) 的点值表示,那么 \(h\) 的点值表示就可以在 \(\mathcal O(n)\) 的复杂度内求出,而后再考虑从点值表示转化为系数表示。
问题一:求出某个多项式的点值表示(离散傅里叶变换 DFT)
实际上这个问题的真实含义是:怎么选取 \(\omega_n^k\) 这个已知量才能使得在一个优秀的复杂度内求出多项式的点值表示。
记 \(\omega_n^k\) 表示将复数坐标系的单位圆平均分成 \(n\) 份,从 \(x\) 轴逆时针出发的第 \(k\) 条分界箭头的复数表示。
选取这个 \(\omega_n^k\) 的原因是它有某些性质,能 " 在一个优秀的复杂度内求出多项式的点值表示 "。
性质:
\[1) \ \omega_n^k = \omega_{\frac{n}{2}}^{\frac{k}{2}} \ \ \ \ \ \ \ \\ 2) \ \omega_n^{k + \frac{n}{2}} = - \omega_n^k \]然后开始推式子:
令
\[f_1(x) = c_0 + c_2x + \cdots + c_{n - 2}x^{\frac{n - 2}{2}} \\ f_2(x) = c_1 + c_3x + \cdots + c_{n - 1}x^{\frac{n - 2}{2}} \]显然有
\[f(x) = f_1(x^2) + xf_2(x^2) \]将 \(\omega _n^k(0 \leq k < \frac{n}{2})\) 代入有:
\[\begin {split} f(\omega_{n}^k) &= f_1(\omega _n^{2k}) + \omega_n^k f_2(\omega_n^{2k}) \\ &= f_1(\omega _{\frac{n}{2}}^{k}) + \omega_n^k f_2(\omega _{\frac{n}{2}}^{k}) \end {split} \]将 \(\omega_n^{k + \frac{n}{2}}(0 \leq k < \frac{n}{2})\) 代入有:
\[\begin {split} f(\omega_n^{k + \frac{n}{2}}) &= f_1(\omega _n^{2k + n}) + \omega_n^{k + \frac{n}{2}} f_2(\omega_n^{2k + n}) \\ &= f_1(\omega _{\frac{n}{2}}^{k}) - \omega_n^k f_2(\omega _{\frac{n}{2}}^{k}) \end {split} \]递归求解即可,有 \(\log n\) 层,时间复杂度为 \(\mathcal O(n \log n)\),为了方便处理,一般把 \(n\) 处理为 \(\geq(n + m)\) 的二次幂,多出来的部分系数补为 \(0\) 即可。
具体实现中有以下注意事项:
1、虚数可以使用 C++ STL 库中的 complex 类型;
代码
#include <bits/stdc++.h>
template < typename T >
inline void read(T &cnt) {
cnt = 0; char ch = getchar(); bool op = 1;
for (; ! isdigit(ch); ch = getchar())
if (ch == '-') op = 0;
for (; isdigit(ch); ch = getchar())
cnt = cnt * 10 + ch - 48;
cnt = op ? cnt : - cnt;
}
const int N = (1 << 22) + 5;
const double PI = acos(-1);
inline void FFT(std::complex < double > *A, int n) {
if (n == 1) return;
int m = (n >> 1);
std::complex < double > A0[m], A1[m];
for (int i = 0; i < m; ++ i) {
A0[i] = A[i * 2];
A1[i] = A[i * 2 + 1];
}
FFT(A0, m); FFT(A1, m); // 递归处理
auto W = std::complex < double > (cos(2.0 * PI / n), sin(2.0 * PI / n)),
w = std::complex < double > (1.0, 0.0); // 从 w_n^0 出发
for (int i = 0; i < m; ++ i) { // 根据式子计算 A 即可
A[i] = A0[i] + w * A1[i];
A[i + m] = A0[i] - w * A1[i];
w *= W; // 等价于 w_n^k -> w_n^{k + 1}
}
}
int n, m;
std::complex < double > F[N], G[N];
int main() {
read(n), read(m);
for (int i = 0; i <= n; ++ i) {
int x; read(x);
F[i] = x;
}
for (int i = 0; i <= m; ++ i) {
int x; read(x);
G[i] = x;
}
int sum = 1;
while (sum <= n + m) sum *= 2; // 补齐成二次幂
FFT(F, sum);
FFT(G, sum);
for (int i = 0; i < sum; ++ i)
F[i] *= G[i];
return 0;
}
优化
递归实在太慢了!
以 \(8\) 项多项式为例,模拟拆分的过程:
- 初始序列为 \(\{x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7\}\)
- 一次二分之后 \(\{x_0, x_2, x_4, x_6\},\{x_1, x_3, x_5, x_7 \}\)
- 两次二分之后 \(\{x_0,x_4\} \{x_2, x_6\},\{x_1, x_5\},\{x_3, x_7 \}\)
- 三次二分之后 \(\{x_0\}\{x_4\}\{x_2\}\{x_6\}\{x_1\}\{x_5\}\{x_3\}\{x_7 \}\)
规律:其实就是原来的那个序列,每个数用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标。比如 \(x_1\) 是 001,翻转是 100,也就是 4,而且最后那个位置确实是 4。我们称这个变换为位逆序置换(bit-reversal permutation),证明留给读者自证。
实际上,位逆序置换可以 \(\mathcal O(n)\) 从小到大递推实现,设 \(len=2^k\),其中 k 表示二进制数的长度,设 \(R(x)\) 表示长度为 \(k\) 的二进制数 \(x\) 翻转后的数(高位补 0)。我们要求的是 \(R(0),R(1),\cdots,R(n-1)\)。
首先 \(R(0)=0\)。
我们从小到大求 \(R(x)\)。因此在求 \(R(x)\) 时,\(R\left(\left\lfloor \dfrac{x}{2} \right\rfloor\right)\) 的值是已知的。因此我们把 \(x\) 右移一位(除以 \(2\)),然后翻转,再右移一位,就得到了 \(x\) 除了(二进制)个位之外其它位的翻转结果。
考虑个位的翻转结果:如果个位是 0,翻转之后最高位就是 0。如果个位是 1,则翻转后最高位是 1,因此还要加上 \(\dfrac{len}{2}=2^{k-1}\)。综上
\[R(x)=\left\lfloor \frac{R\left(\left\lfloor \frac{x}{2} \right\rfloor\right)}{2} \right\rfloor + (x\bmod 2)\times \frac{len}{2} \]举个例子:设 \(k=5\),\(len=(100000)_2\)。为了翻转 \((11001)_2\):
- 考虑 \((1100)_2\),我们知道 \(R((1100)_2)=R((01100)_2)=(00110)_2\),再右移一位就得到了 \((00011)_2\)。
- 考虑个位,如果是 \(1\),它就要翻转到数的最高位,即翻转数加上 \((10000)_2=2^{k-1}\),如果是 \(0\) 则不用更改。
蝶形运算优化
已知 \(f_1(\omega_{n/2}^k)\) 和 \(f_2(\omega_{n/2}^k)\) 后,需要使用下面两个式子求出 \(f(\omega_n^k)\) 和 \(f(\omega_n^{k+n/2})\):
\[\begin{aligned} f(\omega_n^k) & = f_1(\omega_{n/2}^k) + \omega_n^k \times f_2(\omega_{n/2}^k) \\ f(\omega_n^{k+n/2}) & = f_1(\omega_{n/2}^k) - \omega_n^k \times f_2(\omega_{n/2}^k) \end{aligned} \]使用位逆序置换后,对于给定的 \(n, k\):
- \(f_1(\omega_{n/2}^k)\) 的值存储在数组下标为 \(k\) 的位置,\(f_2(\omega_{n/2}^k)\) 的值存储在数组下标为 \(k + \dfrac{n}{2}\) 的位置。
- \(f(\omega_n^k)\) 的值将存储在数组下标为 \(k\) 的位置,\(f(\omega_n^{k+n/2})\) 的值将存储在数组下标为 \(k + \dfrac{n}{2}\) 的位置。
因此可以直接在数组下标为 \(k\) 和 \(k + \frac{n}{2}\) 的位置进行覆写,而不用开额外的数组保存值。此方法即称为 蝶形运算,或更准确的,基 - 2 蝶形运算。
再详细说明一下如何借助蝶形运算完成所有段长度为 \(\frac{n}{2}\) 的合并操作:
1、令段长度为 \(s = \frac{n}{2}\);
2、同时枚举序列 \(\{f_1(\omega_{n/2}^k)\}\) 的左端点 \(l_g = 0, 2s, 4s, \cdots, N-2s\) 和序列 \(\{f_2(\omega_{n/2}^k)\}\) 的左端点 \(l_h = s, 3s, 5s, \cdots, N-s\);
3、合并两个段时,枚举 \(k = 0, 1, 2, \cdots, s-1\),此时 \(f_1(\omega_{n/2}^k)\) 存储在数组下标为 \(l_g + k\) 的位置,\(f_2(\omega_{n/2}^k)\) 存储在数组下标为 \(l_h + k\) 的位置;
4、使用蝶形运算求出 \(f(\omega_n^k)\) 和 \(f(\omega_n^{k+n/2})\),然后直接在原位置覆写。
代码
#include <bits/stdc++.h>
template < typename T >
inline void read(T &cnt) {
cnt = 0; char ch = getchar(); bool op = 1;
for (; ! isdigit(ch); ch = getchar())
if (ch == '-') op = 0;
for (; isdigit(ch); ch = getchar())
cnt = cnt * 10 + ch - 48;
cnt = op ? cnt : - cnt;
}
const int N = (1 << 22) + 5;
const double PI = acos(-1);
int rev[N];
inline void change(std::complex < double > *A, int n) {
for (int i = 0; i < n; ++ i) { // 求 R 数组
rev[i] = rev[i >> 1] >> 1;
if (i & 1) {
rev[i] |= (n >> 1);
}
}
for (int i = 0; i < n; ++ i) // 将原序列 变为 底层对应的序列
if (i < rev[i]) std::swap(A[i], A[rev[i]]);
}
inline void FFT(std::complex < double > *A, int n) {
change(A, n);
for (int m = 2; m <= n; m *= 2) { // m 是当前处理的每段长度
auto W = std::complex < double >
(cos(2.0 * PI / m), sin(2.0 * PI / m));
for (int x = 0; x < n; x += m) { // x 是每段的开头
auto w = std::complex < double > (1.0, 0.0);
for (int i = x; i < x + m / 2; ++ i) { // 求出每段的点值表示 根据公式求即可
auto A0 = A[i], A1 = A[i + m / 2];
A[i] = A0 + w * A1;
A[i + m / 2] = A0 - w * A1;
w *= W;
}
}
}
}
int n, m;
std::complex < double > F[N], G[N];
int main() {
change(F, 8);
read(n), read(m);
for (int i = 0; i <= n; ++ i) {
int x; read(x);
F[i] = x;
}
for (int i = 0; i <= m; ++ i) {
int x; read(x);
G[i] = x;
}
int sum = 1;
while (sum <= n + m) sum *= 2; // 补齐成二次幂
FFT(F, sum);
FFT(G, sum);
for (int i = 0; i < sum; ++ i)
F[i] *= G[i];
FFT(F, sum);
return 0;
}
问题二:将点值表示转化为系数表示(傅里叶反变换 IDFT)
点值表示的矩阵形式为:
\[\begin{bmatrix}f(\omega_n^0) \\ f(\omega_n^1) \\ f(\omega_n^2) \\ f(\omega_n^3) \\ \vdots \\ f(\omega_n^{n-1}) \end{bmatrix} = \begin{bmatrix}1 & 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n^1 & \omega_n^2 & \omega_n^3 & \cdots & \omega_n^{n-1} \\ 1 & \omega_n^2 & \omega_n^4 & \omega_n^6 & \cdots & \omega_n^{2(n-1)} \\ 1 & \omega_n^3 & \omega_n^6 & \omega_n^9 & \cdots & \omega_n^{3(n-1)} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \omega_n^{3(n-1)} & \cdots & \omega_n^{(n-1)^2} \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_{n-1} \end{bmatrix} \]怎么求系数 \(a\) 呢?根据线性代数的知识:
\[Ax = b \\ x = A^{-1}b \]如果能求出 \(A^{-1}\),那么 \(A^{-1} b\) 也是两个多项式相乘的结果,FFT 即可。
唯一的问题变为怎么求解 \(A^{-1}\)。
根据矩阵的逆的定义,有
\[A^{-1} \cdot A = E \]设 \(V\) 为原矩阵,\(G\) 为逆矩阵,考虑最终落在 \(E(i, j)\) 的值:
\[E(i, j) = \sum_{k=0}^{n-1} G(i, k) \cdot V(k, j) = \sum_{k=0}^{n-1} G(i, k) \cdot \omega_n^{kj} = [i == j] \]引理
当 \(k\) 不是 \(n\) 的倍数时,
\[\sum_{i=0}^{n-1}\omega_n^{ki} = 0 \]证明如下:
\[\sum_{i=0}^{n-1}\omega_n^{ki} = \frac{\omega_n^{kn} - 1}{1 - \omega_n^{k}} = \frac{1 - 1}{1 - \omega_n^{k}} = 0 \]令 \(G(i, k) = \omega_n^{-ik}\),则:
\[\sum_{k=0}^{n-1} G(i, k) \cdot \omega_n^{kj} = \sum_{k=0}^{n-1} \omega_n^{-ik} \cdot \omega_n^{kj} = \sum_{k=0}^{n-1} \omega_n^{k(j-i)} \]当 \(j-i\) 不为 \(n\) 的倍数(0)时,上式为 0;
反之,有:
\[\sum_{k=0}^{n-1} \omega_n^{k(j-i)} = \sum_{k=0}^{n-1} \omega_n^{0} = n \]再前面补个系数 \(\frac{1}{n}\) 即可,故:
\[G(i, k) = \frac{1}{n}\omega_n^{-ik} \]int main() {
change(F, 8);
read(n), read(m);
for (int i = 0; i <= n; ++ i) {
int x; read(x);
F[i] = x;
}
for (int i = 0; i <= m; ++ i) {
int x; read(x);
G[i] = x;
}
int sum = 1;
while (sum <= n + m) sum *= 2; // 补齐成二次幂
FFT(F, sum);
FFT(G, sum);
for (int i = 0; i < sum; ++ i)
F[i] *= G[i];
FFT(F, sum);
std::reverse(F + 1, F + sum); // 从第一位开始翻转
// 翻转后变为 0 1-n, 2-n, ..., -1
// 实际上等价于 0, 1, 2, ..., n-1
for (int i = 0; i <= n + m; ++ i) { // 四舍五入
std::cout << (int)(F[i].real() / sum + 0.5) << ' ';
}
return 0;
}
标签:cnt,ch,frac,变换,傅里叶,FFT,int,cdots,omega
From: https://www.cnblogs.com/chzhc-/p/18514743