再探 FFT – DIT 与 DIF,另种推导和优化

本文将简单推导两种方式进行的离散傅里叶变换,用另种视角解释并优化算法。参考了 Seniorious yhx-12243 的 NTT 到底写了些什么(详细揭秘) 一文、OI Wiki 快速傅里叶变换 条目 和 rushcheyo 转置原理及其应用 讲稿。

离散傅里叶变换

我们计算数列 {an}\{a_n\} 的离散傅里叶变换 DFT(a,n)k=j=0n1aje2πinkj \newcommand\DFT{\operatorname{DFT}}\DFT(a,n)_k=\sum_{j=0}^{n-1}a_j\mathrm{e}^{\frac{-2\pi\mathrm{i}}{n}kj}

卷积定理循环卷积得到,我们对同样长为 nn 的数列 {bn}\{b_n\}DFT\DFT,并令数列 ck=DFT(a,n)kDFT(b,n)kc_k=\DFT(a,n)_k\DFT(b,n)_k,则有 IDFT(c,n)k=j+qk(modn)ajbq \newcommand\IDFT{\operatorname{IDFT}}\IDFT(c,n)_k=\sum_{j+q\equiv k\pmod n}a_jb_q

于是我们可以利用该原理实现常规意义下的数列卷积:有长为 nn 的数列 {an}\{a_n\},长 mm 的数列 {bm}\{b_m\},则将 a,ba,b 高位补 00,分别作 n+m1n+m-1 位的 DFT\DFT,点值相乘后 IDFT\IDFT 即可得到 ck=i+jk(modn+m1)aibj=i+j=k0i<n0j<maibjc_k=\sum_{i+j\equiv k\pmod{n+m-1}}a_ib_j=\sum_{\substack{i+j=k\\0\leq i<n\\0\leq j<m}}a_ib_j

为行文方便,下文采用 DFT(a,n)k=i=0n1aiωnik\DFT(a,n)_k=\sum_{i=0}^{n-1}a_i\omega_n^{ik} 的定义,其中 ωn=exp(2πi/n)\omega_n=\exp(2\pi\mathrm{i}/n),即 nn 阶单位根。

两种分治方式

DIT(Decimation in Time, 按时域抽取)

相信绝大部分 OIer 首次接触 FFT 时均学习的这种形式。如果 n=2l,lN+n=2^l,l\in\mathbb{N}^+,那么根据 ωnk=ω2n2k,ω2nn=1ω2nk+n=ω2nk  (k<n)\omega_n^k=\omega_{2n}^{2k},\omega_{2n}^n=-1\Rightarrow\omega_{2n}^{k+n}=-\omega_{2n}^k\ \ (k<n),我们可以分治优化上述过程。由 ii 的奇偶性不同(这也是其称为“按时域抽取”的原因),可以将原式划分作 DFT(a,n)k=i=0n1aiωnik=(i=0n/21a2iωn2ik)+(i=0n/21a2i+1ωn2ik+k)=(i=0n/21a2iωn/2ik)+(i=0n/21a2i+1ωn/2ik)ωnk\begin{aligned} \DFT(a,n)_k&=\sum_{i=0}^{n-1}a_i\omega_n^{ik}\\ &=\left(\sum_{i=0}^{n/2-1}a_{2i}\omega_n^{2ik}\right)+\left(\sum_{i=0}^{n/2-1}a_{2i+1}\omega_n^{2ik+k}\right)\\ &=\left(\sum_{i=0}^{n/2-1}a_{2i}\omega_{n/2}^{ik}\right)+\left(\sum_{i=0}^{n/2-1}a_{2i+1}\omega_{n/2}^{ik}\right)\omega_n^k \end{aligned} k<n/2DFT(a,n)k=DFT({a0,a2,,an2},n/2)k+DFT({a1,a3,,an1},n/2)kωnkkn/2ωn/2k=ωn/2kmodn/2=ωn/2kn/2,ωnk=ωnkn/2DFT(a,n)k=DFT({a0,a2,,an2},n/2)kn/2DFT({a1,a3,,an1},n/2)kn/2ωnkn/2\begin{aligned} k<n/2\Longrightarrow&\DFT(a,n)_k=\DFT(\{a_0,a_2,\cdots,a_{n-2}\},n/2)_k+\DFT(\{a_1,a_3,\cdots,a_{n-1}\},n/2)_k\omega_n^k\\ k\geq n/2\Longrightarrow&\omega_{n/2}^k=\omega_{n/2}^{k\bmod n/2}=\omega_{n/2}^{k-n/2},\omega_n^k=-\omega_n^{k-n/2}\\ \Longrightarrow&\DFT(a,n)_k=\DFT(\{a_0,a_2,\cdots,a_{n-2}\},n/2)_{k-n/2}-\DFT(\{a_1,a_3,\cdots,a_{n-1}\},n/2)_{k-n/2}\omega_{n}^{k-n/2} \end{aligned}

终止条件为 n=1=20n=1=2^0 时,DFT(a,1)0=a0\DFT(a,1)_0=a_0。易得时间复杂度为 O(nlogn)\newcommand\bigO{\operatorname{O}}\bigO(n\log n);由 aa 的各个元素在递归树上的走向可知,若将该过程改写作非递归版本(自底向顶计算),则需要一开始将 aa蝴蝶变换。故而用线性变换的角度看待本算法流程,则其接受的输入是蝴蝶变换后的 aa输出aa 的离散傅里叶变换(若将 aa 视作 n1n-1 次多项式 f(x)f(x),则亦可称其为 x=ωnk,k=0,1,,n1x=\omega_n^k,k=0,1,\cdots,n-1 时的点值表示)。

DIF(Decimation in Frequency, 按频域抽取)

既然分治时可以按时域变量的二进制表示分治,那么按频域的亦可。现在考虑 kk 的奇偶性: DFT(a,n)k=i=0n1aiωnik=i=0n/21aiωnik+ai+n/2ωn(i+n/2)k=i=0n/21(ai+ai+n/2(1)k)ωnik\begin{aligned} \DFT(a,n)_k&=\sum_{i=0}^{n-1}a_i\omega_n^{ik}\\ &=\sum_{i=0}^{n/2-1}a_i\omega_n^{ik}+a_{i+n/2}\omega_{\color{blue}n}^{(i+{\color{blue}n/2})k}\\ &=\sum_{i=0}^{n/2-1}(a_i+a_{i+n/2}({\color{blue}-1})^k)\omega_n^{ik} \end{aligned} 2kDFT(a,n)k=i=0n/21(ai+ai+n/2)ωn/2ik/2=DFT({a0+an/2,a1+an/2+1,,an/21+an1},n/2)k/22kDFT(a,n)k=i=0n/21(aiai+n/2)ωniωn/2i(k1)/2=DFT({(a0an/2)ωn0,,(an/21an1)ωnn/21},n/2)(k1)/2\begin{aligned} 2\mid k\Longrightarrow&\DFT(a,n)_k&=&\sum_{i=0}^{n/2-1}(a_i+a_{i+n/2})\omega_{n/2}^{ik/2}\\ &&=&\DFT(\{a_0+a_{n/2},a_1+a_{n/2+1},\cdots,a_{n/2-1}+a_{n-1}\},n/2)_{k/2}\\ 2\nmid k\Longrightarrow&\DFT(a,n)_k&=&\sum_{i=0}^{n/2-1}(a_i-a_{i+n/2})\omega_n^i\omega_{n/2}^{i(k-1)/2}\\ &&=&\DFT(\{(a_0\mathbf{\color{red}-}a_{n/2})\omega_n^0,\cdots,(a_{n/2-1}\mathbf{\color{red}-}a_{n-1})\omega_n^{n/2-1}\},n/2)_{(k-1)/2}\\ \end{aligned}

如果仍然“偶左奇右”处理后再分治,则与 DIT 同理,若输入序列 aa,我们最终将得到蝴蝶变换后的 DFT(a,n)\DFT(a,n)

下面是 DIT 和 DIF Z/pZ,p=998244353\newcommand\ZmpZ{\mathbb Z/p\mathbb Z}\ZmpZ,p=998244353 上的朴素实现。模板参数 rev 指示是否作逆变换。

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
#include <bits/stdc++.h> constexpr int G = 1<<21, P = 998244353, gen = 3; #define inl inline using namespace std; using ll = long long; inl ll fpow (ll a, ll b) { ll res = 1; a %= P; for (; b; b >>= 1) { if (b & 1) (res *= a) %= P; (a *= a) %= P; } return res; } inl void butterfly (int f[], int l) { static int tr[G], last; if (last != l) { last = l; for (int i = 1; i < 1<<l; ++i) tr[i] = tr[i>>1]>>1|(i & 1) * (1<<l-1); } for (int i = 1; i < 1<<l; ++i) if (tr[i] < i) swap (f[tr[i]], f[i]); } inl void reverse (int f[], int l) { const ll invl = fpow (1<<l, P-2); for (int i = 0; i < 1<<l; ++i) f[i] = invl * f[i] % P; reverse (f + 1, f + (1<<l)); } template <bool rev> inl void DIT (int f[], int l) { butterfly (f, l); for (int len = 2, j = 0; len <= 1<<l; len <<= 1, ++j) { const ll w_n = fpow (gen, (P-1)/len); ll g, h, w = 1; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = st; i < st + len/2; ++i, (w *= w_n) %= P) g = f[i], h = f[i + len/2] * w % P, f[i] = (g + h) % P, f[i + len/2] = (P + g - h) % P; } if (rev) reverse (f, l); } template <bool rev> inl void DIF (int f[], int l) { for (int len = 1<<l, q = 0; len > 1; len >>= 1, ++q) { const ll w_n = fpow (gen, (P-1)/len); ll g, h, w = 1; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = st; i < st + len/2; ++i, (w *= w_n) %= P) g = f[i], h = f[i + len/2], f[i] = (g + h) % P, f[i + len/2] = (P + g - h) * w % P; } butterfly (f, l); if (rev) reverse (f, l); }

从另一视角考虑

ωnikai\omega_n^{ik}a_i 逐位分解

我们直接考察 aia_iDFT(a,n)k\DFT(a,n)_k 的贡献。显然有 DFT(a,n)kaiωnik\DFT(a,n)_k\leftarrow a_i\omega_n^{ik},不过若将 i,ki,k 写成二进制,由 ωx+y=ωxωy\omega^{x+y}=\omega^x\omega^yωnnk=1,kZ\omega_n^{nk}=1,k\in\mathbb Z 得到:

bit(x,k)=x2kmod2ωnik=j=0l1ω2l2jbit(i,j)kωnik=q=0l1ω2l2qbit(k,q)i=j=0l1ω2ljbit(i,j)(kmod2lj)=q=0l1ω2lqbit(k,q)(imod2lq)(2)=j=1lω2jbit(i,lj)(kmod2j)(1) \newcommand\bit{\operatorname{bit}}\bit(x,k)=\left\lfloor\frac{x}{2^k}\right\rfloor\bmod 2\\ \begin{aligned} \omega_n^{ik}&=\prod_{j=0}^{l-1}\omega_{2^l}^{2^j\bit(i,j)k}&\quad\quad&\omega_n^{ik}&=&\prod_{q=0}^{l-1}\omega_{2^l}^{2^q\bit(k,q)i}\\ &=\prod_{j=0}^{l-1}\omega_{2^{l-j}}^{\bit(i,j)(k\bmod 2^{l-j})}&\quad\quad&&=&\prod_{q=0}^{l-1}\omega_{2^{l-q}}^{\bit(k,q)(i\bmod{2^{l-q}})}&(2)\\ &=\prod_{j=1}^{l}\omega_{2^j}^{\bit(i,l-j)(k\bmod{2^j})}&(1) \end{aligned}

现在再来观察 DIF 的过程。对于第 q (l>q0)q\ (l>q\geq 0) 层而言,共有 2q2^q 个相互独立的分治序列,每一序列含 2lq2^{l-q} 个元素,其中每个元素均为 2q2^q 个不同的 iiaia_i 的带权和。对于一对固定的 i,ki,kaia_i 之于 DFT(a,n)k\DFT(a,n)_k 的贡献即等于其在分治树上走到 DFT(a,n)k\DFT(a,n)_k 的权值之积。

考察在第 qq 层,向 DFT(a,n)k\DFT(a,n)_k 贡献的 aia_i 将乘上的权值。若 bit(k,q)=1\bit(k,q)=1,则 aia_i 走右侧分支,且乘上 ω2lq?\omega_{2^{l-q}}^?;否则乘上 11 并走左侧分支。?? 只与 aia_i 在该层分治序列中的实际位置有关,而观察可得它的实际位置正是 imod2lqi\bmod {2^{l-q}}!又由于当 bit(i,lq1)=1    imod2lq2lq1\bit(i,l-q-1)=1\iff i\bmod{2^{l-q}}\geq 2^{l-q-1}(在该层序列右侧)时,有 ω2lqimod2lq=ω2lqimod2lq2lq1\omega_{2^{l-q}}^{i\bmod{2^{l-q}}}={\color{red}-}\omega_{2^{l-q}}^{i\bmod{2^{l-q}}-2^{l-q-1}}(标红的负号可以在 DIF 的推导中找到)。这正好与 (2)(2) 式完全对应!同理可以得到,DIT 的过程与 (1)(1) 式完全对应(j,1jlj,1\leq j\leq l 为分治树自底向顶的第 jj 层,分治序列长为 2j2^j)。

优化单位根之幂移动次数

不难从上述示例代码中看出,我们在每一层分治中都将 ω2lq\omega_{2^{l-q}} 之幂移动了 O(n)\bigO(n) 次;总移动次数为 O(nlogn)\bigO(n\log n) 次。通常情况下这已经足够快速了;不过对于时限异常紧张的题目,或者虽然并非正解,但有望通过 DFT 优化暴力获取更多部分分(甚至满分!)的题目,我们还想再快一些。

仍然考虑上述两种实现的解释。我们用 (2)(2) 式解释了 DIF。能否用 (1)(1) 解释呢?不难注意到,对于向 DFT(a,n)k\DFT(a,n)_k 贡献的 aia_i,在第 q+1q+1 层时所处的分治序列是从左往右第 rev(kmod2q+1)\newcommand\rev{\operatorname{rev}}\rev(k\bmod{2^{q+1}}) 个(第 q+1q+1 层有 2q+12^{q+1} 个序列;rev(x)\rev(x) 表示将 xx 的无前导零二进制表示翻转);在第 qq 层的序列内,aia_i 处在左半还是右半是由 bit(i,lq1)\bit(i,l-q-1) 位决定的。回到 (1)(1) 式,如果令 j=q+1j=q+1,这似乎正与它所表达的相契合!于是,将分治树视为 0/1 Trie\text{0/1 Trie},将 aia_iDFT(a,n)k\DFT(a,n)_k 贡献的过程视作遍历 rev(k)\rev(k) 二进制表示对应的路径,并在过程中以 kk 长为 q+1q+1 的前缀与 bit(i,lq1)\bit(i,l-q-1) 之积作为指数,乘上 ω2q+1\omega_{2^{q+1}} 的次幂,就是我们所要的权值!在这个角度下,每个分治序列长为 2lq12^{l-q-1},就是因为对 ii 的最高 qq 位毫不关注;按低位排序是为快速分裂出 qq 更大时的分治序列。

在这种视角下,输入系数序列、输出蝴蝶变换后的点值序列的目的保持不变;但对于第 q+1q+1 层,对于同一分治序列而言,rev(kmod2q+1)\rev(k\bmod{2^{q+1}}) 是不变的。这意味着我们对于同一序列采用了同一单位根之幂!故而原根的移动次数为 O(q=0l12q)=O(n)\bigO(\sum_{q=0}^{l-1}2^q)=\bigO(n),有了大幅提升。DIT 之于 (2)(2) 式的解释同理。

在实现时,如果顺次遍历第 qq 层的 2q2^q 个序列,采用的单位根之幂即分别为 ω2q+1rev(0),ω2q+1rev(1),,ω2q+1rev(2q1)\omega_{2^{q+1}}^{\rev(0)},\omega_{2^{q+1}}^{\rev(1)},\cdots,\omega_{2^{q+1}}^{\rev(2^q-1)}。我们注意到,将 2q2^q 阶、2q+12^{q+1} 阶单位根之幂的前半部分序列分别作蝴蝶变换,前者是后者的前缀(不难由 ω2n2k=ωnk\omega_{2n}^{2k}=\omega_n^k 证明),因此我们以 O(n)\bigO(n) 时间预处理出 2l2^l 阶单位根 ω2l0,,ω2l2l11\omega_{2^l}^0,\cdots,\omega_{2^l}^{2^{l-1}-1} 的序列蝴蝶变换后的结果,即可快速调用。

省去蝴蝶变换

我们用 DIF 实现 DFT\DFTDIT 实现 IDFT\IDFT 就可以免去蝴蝶变换。当然,这会略微增加一点编码时间;同时需注意 DIT 得出的点值序列是蝴蝶变换后的。

代码实现

可以按需将 x±yx\pm y 结果的取模利用位运算优化。

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
  57. 57
  58. 58
  59. 59
  60. 60
  61. 61
  62. 62
  63. 63
  64. 64
  65. 65
  66. 66
  67. 67
  68. 68
  69. 69
  70. 70
  71. 71
  72. 72
  73. 73
  74. 74
#include <bits/stdc++.h> using namespace std; #define inl inline /* 快读已省略。 */ #define newl putchar('\n') typedef long long ll; // typedef unsigned long long ull; // typedef __int128 lll; // typedef long double llf; typedef pair <int, int> pint; #define fst first #define scd second #define all(p) begin (p), end (p) #define empb emplace_back constexpr int G = 1<<21, P = 998244353, gen = 3; int n, m, f[G], g[G], w[G]; inl ll fpow (ll a, ll b) { ll res = 1; a %= P; for (; b; b >>= 1) { if (b & 1) (res *= a) %= P; (a *= a) %= P; } return res; } inl void calc_powg () { w[0] = 1; ll f; const int g = fpow (gen, (P-1)/G); for (int t = 0; (1<<t+1) < G; ++t) { f = w[1<<t] = fpow (g, G>>t+2); for (int x = 1<<t; x < 1<<t+1; ++x) w[x] = f * w[x - (1<<t)] % P; } } inl void DIT (int f[], int l) { static ll g, h; for (int len = 2; len <= 1<<l; len <<= 1) for (int st = 0, t = 0; st < 1<<l; st += len, ++t) for (int i = st; i < st + len/2; ++i) g = f[i], h = f[i + len/2], f[i] = (g + h) % P, f[i + len/2] = (P + g - h) * w[t] % P; const ll invl = fpow (1<<l, P-2); for (int i = 0; i < 1<<l; ++i) f[i] = invl * f[i] % P; reverse (f + 1, f + (1<<l)); } inl void DIF (int f[], int l) { static ll g, h; for (int len = 1<<l; len > 1; len >>= 1) for (int st = 0, t = 0; st < 1<<l; st += len, ++t) for (int i = st; i < st + len/2; ++i) g = f[i], h = f[i + len/2] * (ll) w[t] % P, f[i] = (g + h) % P, f[i + len/2] = (P + g - h) % P; } int main () { /* */ read (n, m); calc_powg (); for (int i = 0; i <= n; ++i) read (f[i]); for (int i = 0; i <= m; ++i) read (g[i]); const int l = ceil (log2 (n + m + 1)); DIF (f, l), DIF (g, l); for (int i = 0; i < 1<<l; ++i) f[i] = (ll) f[i] * g[i] % P; DIT (f, l); for (int i = 0; i <= n + m; ++i) print (f[i]), putc (' '); return 0; }
  • 2023年2月23日