目录
理论
现有长度为 \(n=2^s\) 序列和相应的多项式为
\[\begin{align*} \pmb a&=(a_0,a_1,\ldots,a_{n-1})\\ f_{\pmb a}(x)&=\sum_{j=0}^{n-1}a_jx^j \end{align*} \]为方便,取符号 \(\omega_n=e^{-2\pi i/n}\),定义离散傅里叶变换(DFT)为
\[\begin{align*} \pmb y&=(y_0,y_1,\ldots,y_{n-1})\\ y_k&=\left.f_{\pmb a}(x)\right\vert_{x=\omega_n^k}=\sum_{j=0}^{n-1}a_j\omega_n^{kj}\\ \end{align*} \]如果采用每个 \(y_k\) 单独计算的方法,则需要进行 \(n\) 次 \(f_{\pmb a}(x)\) 的计算。如果采用霍纳法则进行计算,每个 \(f_{pmb a}(x)\) 需要进行 \(n-1\) 次加法和 \(n-1\) 次乘法,总体时间复杂度为 \(\Theta(n^2)\)。
如果采用快速傅里叶变换则速度将加快到 \(\Theta(n\lg n)\),这主要是因为利用了 \(\omega_n\) 的周期性。
\[\begin{align*} \omega_{dn}^{dk}&=\omega^k_n\\ \omega_n^{k+n}&=\omega_n^k \end{align*} \]将序列 \(\pmb a\) 的奇索引和偶索引分离
\[\begin{align*} \pmb a^{[0]}&=(a_0,a_2,\ldots,a_{n-2})\\ \pmb a^{[1]}&=(a_1,a_3,\ldots,a_{n-1})\\ \end{align*} \]得到相应的傅里叶变换,注意长度的变化,变为了原来的二分之一。
\[\begin{align*} f_{\pmb a^{[0]}}(x)&=\sum_{j=0}^{n/2-1}a_{2j}x^j\\ f_{\pmb a^{[1]}}(x)&=\sum_{j=0}^{n/2-1}a_{2j+1}x^j\\[1em] \pmb y^{[0]}&=(y_0,y_2,\ldots,y_{n-2})\\ \pmb y^{[1]}&=(y_1,y_3,\ldots,y_{n-1})\\ y^{[0]}_k&=f_{\pmb a^{[0]}}(\omega^k_{n/2})\\ y^{[1]}_k&=f_{\pmb a^{[1]}}(\omega^k_{n/2}) \end{align*} \]则
\[\begin{align*} f_{\pmb a}(x)&=f_{\pmb a^{[0]}}(x^2)+xf_{\pmb a^{[1]}}(x^2)\\ y_k&=f_{\pmb a}(\omega_n^k)=f_{\pmb a^{[0]}}(\omega_n^{2k})+\omega_n^kf_{\pmb a^{[1]}}(\omega_n^{2k})\\ &=f_{\pmb a^{[0]}}(\omega_{n/2}^k)+\omega_n^kf_{\pmb a^{[1]}}(\omega_{n/2}^k) \end{align*} \]但是因为 \(\pmb y^{[0,1]}\) 的长度只有 \(n/2\),所以只能计算前 \(n/2\) 个,还需要计算后 \(n/2\) 个的方法。
由于 \(\omega_{n/2}^{k+n/2}=\omega_{n/2}^k\),因此 \(f_{\pmb a^{[0,1]}}(\omega_{n/2}^{k+n/2})=f_{\pmb a^{[0,1]}}(\omega_{n/2}^k)\)。且 \(\omega_n^{k+n/2}=-\omega_n^k\)。因此有对于 \(k\in0..(n/2-1)\)
\[\begin{align*} y_k&=f_{\pmb a^{[0]}}(\omega_{n/2}^k)+\omega_n^kf_{\pmb a^{[1]}}(\omega_{n/2}^k)\\ &=y^{[0]}_k+\omega_n^ky_k^{[1]}\\ y_{k+n/2}&=f_{\pmb a^{[0]}}(\omega_{n/2}^{k+n/2})+\omega_n^{k+n/2}f_{\pmb a^{[1]}}(\omega_{n/2}^{k+n/2})\\ &=f_{\pmb a^{[0]}}(\omega_{n/2}^k)-\omega_n^kf_{\pmb a^{[1]}}(\omega_{n/2}^k)\\ &=y^{[0]}_k-\omega_n^ky_k^{[1]}\\ \end{align*} \]故 \(f_{\pmb a^{[0,1]}}(\omega_{n/2}^{k+n/2})\) 只需要计算 \(n/2\) 个,正好和序列 \(\pmb a^{[0,1]}\) 长度相同,也就是计算 \(\pmb y^{[0,1]}\) 即可。因此计算 \(f_{\pmb a^{[0,1]}}\) 的方法和 \(f_{\pmb a}\) 没什么不同,可以采用递归的方法。
这个递归函数为
\[\begin{align*} \pmb\omega_n&=\left(\omega_n^0,\omega_n^1,\ldots,\omega_n^{n/2-1}\right)\\ \pmb y_{0..(n/2-1)}&=f(\pmb a)=f(\pmb a^{[0]})+\pmb\omega_n\cdot f(\pmb a^{[1]})\\ \pmb y_{(n/2)..(n-1)}&=f(\pmb a)=f(\pmb a^{[0]})-\pmb\omega_n\cdot f(\pmb a^{[1]})\\ \end{align*} \]其中乘法为按位乘法。
注意到在前面的证明中只使用了 \(\omega_n\) 的周期性,显然这对于 \(\omega_n^{-1}\) 也是成立的。
而离散傅里叶逆变换(IDFT)为
\[\begin{align*} a_j&=\frac1n\sum_{k=0}^{n-1}y_k\omega_n^{-kj}=\frac1n\sum_{k=0}^{n-1}y_k(\omega_n^{-1})^{kj} \end{align*} \]只需要在上述证明和算法中将 \(\omega_n\) 全部替换为 \(\omega_n^{-1}\),再在最后除以 \(n\),则完成了逆运算的证明和算法。
\[\begin{align*} \pmb\omega_n&=\left(\omega_n^{-0},\omega_n^{-1},\ldots,\omega_n^{-(n/2-1)}\right)\\ n\pmb a_{0..(n/2-1)}&=f(\pmb y)=f(\pmb y^{[0]})+\pmb\omega_n\cdot f(\pmb y^{[1]})\\ n\pmb a_{(n/2)..(n-1)}&=f(\pmb y)=f(\pmb y^{[0]})-\pmb\omega_n\cdot f(\pmb y^{[1]})\\ \end{align*} \]时间复杂度为
\[\begin{align*} T(n)&=\begin{cases} \Theta(1)&n=1\\ 2T(n/2)+\Theta(n)&n=2^s,s\in\mathbb Z^+ \end{cases}\\ &=\Theta(n\lg n) \end{align*} \]实现
递归实现和迭代实现,只适用于 n
为 2 的幂的情况。
from itertools import count
from sympy import S
import numpy as np
def recursive_fft(a: np.ndarray, dtype=np.complex128) -> np.ndarray:
n = a.size
assert n & (n - 1) == 0 # n is power of 2
if a.dtype != dtype:
a = a.astype(dtype)
if n == 1:
return a.copy()
w_n = dtype(S('exp(-2 * pi * I / n)').subs('n', n))
w = w_n ** np.arange(n // 2, dtype=dtype)
a = a.reshape(-1, 2) # 用 reshape 不用新建数组
y0 = recursive_fft(a[:,0].reshape(-1), dtype)
y1 = recursive_fft(a[:,1].reshape(-1), dtype)
y1 *= w
y = np.hstack([
y0 + y1,
y0 - y1
])
return y
def recursive_ifft(y: np.ndarray, dtype=np.complex128) -> np.ndarray:
def _recursive_ifft(y):
n = y.size
if n == 1:
return y.copy()
w_n = dtype(S('exp(2 * pi * I / n)').subs('n', n))
w = w_n ** np.arange(n // 2, dtype=dtype)
y = y.reshape(-1, 2) # 用 reshape 不用新建数组
a0 = _recursive_ifft(y[:,0].reshape(-1))
a1 = _recursive_ifft(y[:,1].reshape(-1))
a1 *= w
a = np.hstack([
a0 + a1,
a0 - a1
])
return a
n = y.size
assert n & (n - 1) == 0 # n is power of 2
if y.dtype != dtype:
y = y.astype(dtype)
return _recursive_ifft(y) / n
def iterative_fft(a: np.ndarray, dtype=np.complex128) -> np.ndarray:
n = a.size
assert n & (n - 1) == 0 # n is power of 2
for s in count():
m = 1 << s
if m == n:
break
a = a.reshape(m, -1, 2).swapaxes(1, 2)
if a.dtype != dtype:
y = a.ravel().astype(dtype)
else:
y = a.flatten()
for s in count():
m = 1 << s
if m == n:
break
w_m = dtype(S('exp(-pi * I / m)').subs('m', m))
w = w_m ** np.arange(m, dtype=dtype)
y = y.reshape(-1, m << 1)
y0 = y[:,:m].copy()
y1 = y[:,m:] * w
y[:,:m] = y0 + y1
y[:,m:] = y0 - y1
return y
def iterative_ifft(y: np.ndarray, dtype=np.complex128) -> np.ndarray:
n = y.size
assert n & (n - 1) == 0 # n is power of 2
for s in count():
m = 1 << s
if m == n:
break
y = y.reshape(m, -1, 2).swapaxes(1, 2)
if y.dtype != dtype:
a = y.ravel().astype(dtype)
else:
a = y.flatten()
for s in count():
m = 1 << s
if m == n:
break
w_m = dtype(S('exp(pi * I / m)').subs('m', m))
w = w_m ** np.arange(m, dtype=dtype)
a = a.reshape(-1, m << 1)
a0 = a[:,:m].copy()
a1 = a[:,m:] * w
a[:,:m] = a0 + a1
a[:,m:] = a0 - a1
return a / n
在代码中使用了 sympy
的表达式,因此如果 dtype=object
那么可以获得结果的精确表达式,但不一定是最简式。
numpy
实现的要点
reshape()
和 swapaxes()
关于 a.reshape(m, -1, 2).swapaxes(1, 2)
的作用,参考递归实现,每次递归都分离了偶索引和奇索引。
关键在于递归和奇偶分离:
m
和reshape(m)
的作用在于指明递归次数,同时将每次递归分离到不同的坐标上,使每次递归之间(的坐标转换)互不影响。reshape(-1, 2).swapaxes(1, 2)
的作用在于奇偶分离,如果认为数据按先列后行的顺序排列,那么就是首先按奇偶分为两列,在旋转坐标轴将奇偶分为两行。
类型转换、一维化和坐标转换
为了提高效率,需要注意一下类型转换和坐标转换顺序。由于经过多次 reshape()
和 swapaxis()
,原有数组的内在坐标被转换了。而后续操作基本上是线性的、连续的操作,因此在转换完坐标后再进行类型转换和一维化 (注意 flatten()
和 ravel()
的使用),可以创建新的线性数组,方便后续操作,且遵循了操作不改变原数组的准则,时间复杂度也会减小。
slower
版本先进行类型转换且没有进行一维化,再进行坐标转换。faster
版本先进行坐标转换,再进行类型转换和一维化。
# slower
if y.dtype != dtype:
a = y.astype(dtype)
else:
a = y.copy()
for s in count():
m = 1 << s
if m == n:
break
a = a.reshape(m, -1, 2).swapaxes(1, 2)
# faster
for s in count():
m = 1 << s
if m == n:
break
y = y.reshape(m, -1, 2).swapaxes(1, 2)
if y.dtype != dtype:
a = y.ravel().astype(dtype)
else:
a = y.flatten()
数组的复制
还需要注意数组的复制。一般来说,使用 numpy
库复制数组之后时间会加快(在没有多余操作的前提下),具体看下列代码。
slower
版本的关键在于使用原数组记录相关变量,在这个过程中,没有使用新的变量,a0, a1
只是a
的视图而已,是a
的一部分。faster
版本使用copy()
和乘法,一个直接、一个间接地建立了新数组,因此速度加快了。
#slower
a = a.reshape(-1, m << 1)
a0 = a[:,:m]
a1 = a[:,m:]
a1 *= w
a0[:] += a1
a1 *= -2
a1[:] += a0
# faster
a = a.reshape(-1, m << 1)
a0 = a[:,:m].copy()
a1 = a[:,m:] * w
a[:,:m] = a0 + a1
a[:,m:] = a0 - a1
在递归的实现中,因为 hstack()
不改变原数组,所以传入参数时使用 reshape()
,而不是新建数组。在这里,选择了减小空间开销而不是减小时间开销。