应用
FFT,中文“快速傅里叶变换”,用来加速多项式乘法和卷积,可以将 \(O(n^2)\) 的复杂度优化至 \(O(n \log n)\)。
多项式
系数表示法
一个 \(n\) 次 \(n+1\) 项的多项式 \(f(x)\) 可以表示为 \(f(x) = \sum\limits_{i = 0}^{n} a_ix^i\)。
也可以用每一项的系数表示 \(f(x)\),即 \(f(x) = \{a_0, a_1, \dots, a_n\}\)。这种表示方法称为系数表示法.
点值表示法
把多项式 \(f(x)\) 看成一个平面直角坐标系中的函数 \(y = f(x)\),再代入 \(n+1\) 个不同的 \(x\) 得到 \(n+1\) 个 \(y\),那么这 \(n\) 个点就可以唯一确定多项式 \(f(x)\),即有且仅有一个多项式 \(f(x)\),满足 \(\forall i \in [0, n], f(x_i) = y_i\)。因为可以看成是一个 \(n+1\) 元 \(n+1\) 次方程组,一定能解出该方程组。
那么 \(f(x)\) 还可以用 \(f(x) = \{(x_0, f(x_0)), (x_1, f(x_1)), \dots, (x_{n}, f(x_{n}))\}\) 表示,这就是点值表示法。
点值表示法与系数表示法的互相转化
-
对于一个 \(n\) 次多项式,在已知系数的情况下,将其转化为点值表示法,称为 DFT(离散傅里叶变换)。
对 \(n+1\) 个 \(x\) 分别代入计算,每次 \(O(n)\),总复杂度是 \(O(n^2)\)。 -
对于一个由点值表示法表示的 \(n\) 次多项式,将其转化为系数表示法,称为 IDFT(离散傅里叶逆变换)。
即解一个 \(n+1\) 元 \(n+1\) 次的方程组,使用高斯消元是 \(O(n^3)\) 的,可以用拉格朗日插值做到 \(O(n^2)\)。
两种表示法在多项式乘法上的区别
-
对于两个用系数表示法表示的多项式 \(A(x), B(x)\),需要分别枚举两个多项式的次数再将系数相乘,系数表示法做多项式乘法的复杂度为 \(O(n^2)\)。
-
对于两个用点值表示法表示的多项式:
\[\begin{aligned} f(x) &= \{(x_0, f(x_0)), \dots, (x_{n}, f(x_{n}))\},\\ g(x) &= \{(x_0, g(x_0)), \dots, (x_{n}, g(x_{n}))\} \end{aligned} \]两者相乘的结果记为 \(h(x)\),则:
\[h(x) = \{(x_0, f(x_0) \cdot g(x_0)), \dots, (x_{n}, f(x_{n}) \cdot g(x_{n}))\} \]时间复杂度 \(O(n)\)。
那么将系数表示法转化为点值,做乘法后再转换为系数表示法会不会更快呢?在使用 DFT 和 IDFT 的情况下显然是不会的,因为转化是 \(O(n^2)\) 的。如果想要更快,就需要 FFT 了。
快速傅里叶变换
复数
复数基础详见高中数学必修二第七章。
复数可以写成 \(a + bi\) 的形式,在复平面上表示为 \((a, b)\)。
在快速傅里叶变换中,我们考虑更换点值表示法中代入的 \(x\),使得能通过 \(x\) 的一些性质减少部分运算。而这些 \(x\) 即为 \(x^n = 1\) 的 \(n\) 个根,根据高中数学知识,这些根的分布情况应为:(\(n = 8\))
即这 \(n\) 个点与原点的线段平分了单位圆,将编号为 \(1\) 的根称作 \(x^n = 1\) 的单位复数根,记为 \(\omega_n\),根据欧拉公式,有:
\[e^{i\theta} = \cos \theta + i \sin \theta \]故 \(\omega_n = e^{i \cdot (2\pi / n)}\),其它根为 \(\omega_n\) 的若干次幂。
选择这 \(n\) 个点作为点值转化原因是它们满足以下两点性质。
-
消去引理:\(\forall n \in N, k \in N, d \in N^{*}\),有 \(\omega^{dk}_{dn} = \omega^{k}_n\)。
证明:
\[\omega^{dk}_{dn} = (e^{i \cdot (2 \pi / dn)})^{dk} = (e^{i \cdot (2 \pi / n)})^k=\omega^k_n \] -
折半引理:\(\forall n \in N, k \in N\) 且 \(n\) 为偶数,有 \((\omega^{k + n/ 2}_n)^2 = \omega^k_{n / 2}\)。
证明:
\[(\omega^{k + n/ 2}_n)^2 = \omega^{2k+n}_n = \omega^{2k}_n = \omega^{k}_{n/2} \]
FFT
考虑将 \(\omega^0_n, \omega^1_n, \dots, \omega^{n - 1}_n\) 代入求值。
(由于下面的推导需要用到折半定理,所以将 \(n\) 视为 \(2\) 的次幂)
有多项式:
\[A(x) = a_0 + a_1x + a_2x^2 + \dots + a_{n - 1}x^{n - 1} \]将 \(A(x)\) 中的单项按照下标奇偶性分开:
\[\begin{aligned} A(x) &= (a_0 + a_2x^2 + \dots + a_{n - 2}x^{n - 2}) + (a_1x + a_3x^3 + \dots + a_{n - 1}x^{n - 1})\\ &= (a_0 + a_2x^2 + \dots + a_{n - 2}x^{n - 2}) + x(a_1 + a_3x^2 + \dots + a_{n - 1}x^{n - 2}) \end{aligned} \]令:
\[\begin{aligned} A_1(x) &= a_0 + a_2x + \dots + a_{n - 2}x^{n / 2 - 1} \\ A_2(x) &= a_1 + a_3x + \dots + a_{n - 1}x^{n / 2 - 1} \end{aligned} \]则:
\[A(x) = A_1(x^2) + xA_2(x^2) \]令 \(k \in [0, n / 2 - 1]\),将 \(\omega^k_n\) 代入 \(A(x)\) 后得:
\[\begin{aligned} A(\omega^k_n) &= A_1((\omega^k_n)^2) + \omega^k_nA_2((\omega^k_n)^2) \\ &= A_1(\omega^{2k}_n) + \omega^k_nA_2(\omega^{2k}_n) \\ &= A_1(\omega^{k}_{n / 2}) + \omega^k_nA_2(\omega^{k}_{n/2}) \end{aligned} \]接着将 \(\omega^{k+n/2}_n\) 代入 \(A(x)\) 后得:
\[\begin{aligned} A(\omega^{k+n/2}_n) &= A_1((\omega^{k+n/2}_n)^2) + \omega^{k+n/2}_nA_2((\omega^{k+n/2}_n)^2) \\ &= A_1(\omega^{k}_{n / 2}) + \omega^{k+n/2}_nA_2(\omega^{k}_{n/2})\\ &= A_1(\omega^{k}_{n / 2}) - \omega^k_nA_2(\omega^{k}_{n/2}) \end{aligned} \]可以观察到 \(A(\omega^k_n)\) 与 \(A(\omega^{k+n/2}_n)\) 化简后只有符号不同,也就是说,只要求出了 \(A_1(\omega^{k}_{n / 2})\) 和 \(A_2(\omega^{k}_{n/2})\) 的值,就能求出 \(A(\omega^k_n)\) 和 \(A(\omega^{k+n/2}_n)\) 的值。
问题就转化为了求 \(A_1(x)\) 与 \(A_2(x)\) 在 \(x = \omega^0_{n/2}, \omega^1_{n / 2}, \dots, \omega^{n / 2 - 1}_{n / 2}\) 的取值,递归求解即可,时间复杂度 \(O(n \log n)\)。
IFFT
通过 FFT,我们已经能在 \(O(n \log n)\) 的时间复杂度下将系数转化为点值,接下来,考虑如何将多项式的点值转化为系数。
对于前面的 FFT,其过程可以用矩阵乘法来表示:
\[\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega^{1}_n & \omega^{2}_n & \cdots & \omega^{n - 1}_n \\ 1 & \omega^{2}_n & \omega^{4}_n & \cdots & \omega^{2n - 2}_n \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega^{n - 1}_n & \omega^{2n - 2}_n & \cdots & \omega^{(n-1)^2}_n \end{bmatrix} \times \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ \vdots \\ a_{n - 1} \end{bmatrix} = \begin{bmatrix} A(a_0) \\ A(a_1) \\ A(a_2) \\ \vdots \\ A(a_{n - 1}) \end{bmatrix} \]所以构造出左边矩阵的逆矩阵即可,左边的矩阵满足 \(V_{i, j} = \omega^{ij}_n\),它的逆矩阵是 \(V^{-1}_{i, j} = \frac{\omega^{-ij}_n}{n}\)。所以就有:
\[\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega^{-1}_n & \omega^{-2}_n & \cdots & \omega^{-n + 1}_n \\ 1 & \omega^{-2}_n & \omega^{-4}_n & \cdots & \omega^{-2n + 2}_n \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega^{-n + 1}_n & \omega^{-2n + 2}_n & \cdots & \omega^{-(n-1)^2}_n \end{bmatrix} \times \begin{bmatrix} A(a_0) \\ A(a_1) \\ A(a_2) \\ \vdots \\ A(a_{n - 1}) \end{bmatrix} = n \times \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ \vdots \\ a_{n - 1} \end{bmatrix} \]相当于做 FFT 时将 \(\omega^k_n\) 变成 \(\omega^{-k}_n\),最后将结果除以 \(n\) 即可,时间复杂度也为 \(O(n \log n)\)。
有了 FFT 和 IFFT 就可以写出递归版本的多项式乘法,下面是洛谷 P3803 的代码:
#include <bits/stdc++.h>
using i64 = long long;
using namespace std::complex_literals;
using complex = std::complex<double>;
int n, _m;
std::vector<complex> a, b;
const double pi = acos(-1.0);
void FFT(std::vector<complex> &a, int coef) {
int n = a.size();
if (n == 1)
return ;
std::vector<complex> a1, a2;
for (int i = 0; i < n; ++i) {
(i & 1 ? a2 : a1).push_back(a[i]);
}
FFT(a1, coef);
FFT(a2, coef);
double theta = coef * 2 * pi / n;
complex wn{cos(theta), sin(theta)};
complex w = 1;
for (int k = 0; k < n / 2; ++k, w *= wn) {
a[k] = a1[k] + w * a2[k];
a[k + n / 2] = a1[k] - w * a2[k];
}
}
std::vector<int> multiply(std::vector<int> _a, std::vector<int> _b) {
std::vector<complex> a(_a.begin(), _a.end()), b(_b.begin(), _b.end());
int n = a.size(), _m = b.size(), len = n + _m - 2;
for (n += _m; n != (n & -n); ++n) ;
a.resize(n);
b.resize(n);
FFT(a, 1);
FFT(b, 1);
for (int i = 0; i < n; ++i)
a[i] = a[i] * b[i];
FFT(a, -1);
std::vector<int> c(len + 1);
for (int i = 0; i <= len; ++i) {
c[i] = (int)(a[i].real() / n + 0.5);
}
return c;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n, m;
std::cin >> n >> m;
++n, ++m;
std::vector<int> a, b;
while (n--) {
int x;
std::cin >> x;
a.push_back(x);
}
while (m--) {
int x;
std::cin >> x;
b.push_back(x);
}
auto ans = multiply(a, b);
for (auto x : ans)
std::cout << x << ' ';
}
交上去能够通过,但常数较大(最慢的点跑了 1.5s)。考虑优化常数,下面介绍小常数迭代写法。
蝴蝶变换
考虑 FFT 过程中下标的变化:
观察到后序列的下标其实就是原序列下标二进制的翻转,因为每次都把奇数放在右边,偶数放在左边。有了这个性质,我们可以预先把下标放在后序列对应的下标上,然后对区间进行合并即可。
代码:
#include <bits/stdc++.h>
using i64 = long long;
using namespace std::complex_literals;
using complex = std::complex<double>;
const double pi = acos(-1.0);
void FFT(std::vector<complex> &a, int coef, const std::vector<int> &rev) {
int n = a.size();
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k <<= 1) {
double theta = coef * pi / k;
complex wn{cos(theta), sin(theta)};
for (int i = 0; i < n; i += k * 2) {
complex w = 1;
for (int j = 0; j < k; ++j, w *= wn) {
auto x = a[i + j], y = a[i + j + k] * w;
a[i + j] = x + y;
a[i + j + k] = x - y;
}
}
}
}
std::vector<int> multiply(std::vector<int> _a, std::vector<int> _b) {
std::vector<complex> a(_a.begin(), _a.end()), b(_b.begin(), _b.end());
int n = a.size(), _m = b.size(), len = n + _m - 2;
for (n += _m; n != (n & -n); ++n) ;
a.resize(n);
b.resize(n);
int s = std::__lg(n);
std::vector<int> rev(n);
for (int i = 1; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (s - 1));
FFT(a, 1, rev);
FFT(b, 1, rev);
for (int i = 0; i < n; ++i)
a[i] = a[i] * b[i];
FFT(a, -1, rev);
std::vector<int> c(len + 1);
for (int i = 0; i <= len; ++i) {
c[i] = (int)(a[i].real() / n + 0.5);
}
return c;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n, m;
std::cin >> n >> m;
++n, ++m;
std::vector<int> a, b;
while (n--) {
int x;
std::cin >> x;
a.push_back(x);
}
while (m--) {
int x;
std::cin >> x;
b.push_back(x);
}
auto ans = multiply(a, b);
for (auto x : ans)
std::cout << x << ' ';
}
交上去最慢的点跑了 600ms,效率超过递归写法的两倍。
标签:std,end,int,FFT,笔记,学习,vector,omega From: https://www.cnblogs.com/CTHOOH/p/18184021