再探 FFT – DIT 与 DIF,另种推导和优化

本文将简单推导两种方式进行的离散傅里叶变换,用另种视角解释并优化算法。参考了 Seniorious yhx-12243 的 NTT 到底写了些什么(详细揭秘) 一文、OI Wiki 快速傅里叶变换 条目 和 rushcheyo 转置原理及其应用 讲稿。

离散傅里叶变换

我们计算数列 $\{a_n\}$ 的离散傅里叶变换
$$
\newcommand\DFT{\operatorname{DFT}}\DFT(a,n)_k=\sum_{j=0}^{n-1}a_j\mathrm{e}^{\frac{-2\pi\mathrm{i}}{n}kj}
$$

卷积定理循环卷积得到,我们对同样长为 $n$ 的数列 $\{b_n\}$ 作 $\DFT$,并令数列 $c_k=\DFT(a,n)_k\DFT(b,n)_k$,则有
$$
\newcommand\IDFT{\operatorname{IDFT}}\IDFT(c,n)_k=\sum_{j+q\equiv k\pmod n}a_jb_q
$$

于是我们可以利用该原理实现常规意义下的数列卷积:有长为 $n$ 的数列 $\{a_n\}$,长 $m$ 的数列 $\{b_m\}$,则将 $a,b$ 高位补 $0$,分别作 $n+m-1$ 位的 $\DFT$,点值相乘后 $\IDFT$ 即可得到 $c_k=\sum_{i+j\equiv k\pmod{n+m-1}}a_ib_j=\sum_{\substack{i+j=k\\0\leq i<n\\0\leq j<m}}a_ib_j$。

为行文方便,下文采用 $\DFT(a,n)_k=\sum_{i=0}^{n-1}a_i\omega_n^{ik}$ 的定义,其中 $\omega_n=\exp(2\pi\mathrm{i}/n)$,即 $n$ 阶单位根。

两种分治方式

DIT(Decimation in Time, 按时域抽取)

相信绝大部分 OIer 首次接触 FFT 时均学习的这种形式。如果 $n=2^l,l\in\mathbb{N}^+$,那么根据 $\omega_n^k=\omega_{2n}^{2k},\omega_{2n}^n=-1\Rightarrow\omega_{2n}^{k+n}=-\omega_{2n}^k\ \ (k<n)$,我们可以分治优化上述过程。由 $i$ 的奇偶性不同(这也是其称为“按时域抽取”的原因),可以将原式划分作
$$\begin{aligned}
\DFT(a,n)_k&=\sum_{i=0}^{n-1}a_i\omega_n^{ik}\\
&=\left(\sum_{i=0}^{n/2-1}a_{2i}\omega_n^{2ik}\right)+\left(\sum_{i=0}^{n/2-1}a_{2i+1}\omega_n^{2ik+k}\right)\\
&=\left(\sum_{i=0}^{n/2-1}a_{2i}\omega_{n/2}^{ik}\right)+\left(\sum_{i=0}^{n/2-1}a_{2i+1}\omega_{n/2}^{ik}\right)\omega_n^k
\end{aligned}$$
$$\begin{aligned}
k<n/2\Longrightarrow&\DFT(a,n)_k=\DFT(\{a_0,a_2,\cdots,a_{n-2}\},n/2)_k+\DFT(\{a_1,a_3,\cdots,a_{n-1}\},n/2)_k\omega_n^k\\
k\geq n/2\Longrightarrow&\omega_{n/2}^k=\omega_{n/2}^{k\bmod n/2}=\omega_{n/2}^{k-n/2},\omega_n^k=-\omega_n^{k-n/2}\\
\Longrightarrow&\DFT(a,n)_k=\DFT(\{a_0,a_2,\cdots,a_{n-2}\},n/2)_{k-n/2}-\DFT(\{a_1,a_3,\cdots,a_{n-1}\},n/2)_{k-n/2}\omega_{n}^{k-n/2}
\end{aligned}$$

终止条件为 $n=1=2^0$ 时,$\DFT(a,1)_0=a_0$。易得时间复杂度为 $\newcommand\bigO{\operatorname{O}}\bigO(n\log n)$;由 $a$ 的各个元素在递归树上的走向可知,若将该过程改写作非递归版本(自底向顶计算),则需要一开始将 $a$ 作蝴蝶变换。故而用线性变换的角度看待本算法流程,则其接受的输入是蝴蝶变换后的 $a$,输出是 $a$ 的离散傅里叶变换(若将 $a$ 视作 $n-1$ 次多项式 $f(x)$,则亦可称其为 $x=\omega_n^k,k=0,1,\cdots,n-1$ 时的点值表示)。

DIF(Decimation in Frequency, 按频域抽取)

既然分治时可以按时域变量的二进制表示分治,那么按频域的亦可。现在考虑 $k$ 的奇偶性:
$$\begin{aligned}
\DFT(a,n)_k&=\sum_{i=0}^{n-1}a_i\omega_n^{ik}\\
&=\sum_{i=0}^{n/2-1}a_i\omega_n^{ik}+a_{i+n/2}\omega_{\color{blue}n}^{(i+{\color{blue}n/2})k}\\
&=\sum_{i=0}^{n/2-1}(a_i+a_{i+n/2}({\color{blue}-1})^k)\omega_n^{ik}
\end{aligned}$$
$$\begin{aligned}
2\mid k\Longrightarrow&\DFT(a,n)_k&=&\sum_{i=0}^{n/2-1}(a_i+a_{i+n/2})\omega_{n/2}^{ik/2}\\
&&=&\DFT(\{a_0+a_{n/2},a_1+a_{n/2+1},\cdots,a_{n/2-1}+a_{n-1}\},n/2)_{k/2}\\
2\nmid k\Longrightarrow&\DFT(a,n)_k&=&\sum_{i=0}^{n/2-1}(a_i-a_{i+n/2})\omega_n^i\omega_{n/2}^{i(k-1)/2}\\
&&=&\DFT(\{(a_0\mathbf{\color{red}-}a_{n/2})\omega_n^0,\cdots,(a_{n/2-1}\mathbf{\color{red}-}a_{n-1})\omega_n^{n/2-1}\},n/2)_{(k-1)/2}\\
\end{aligned}$$

如果仍然“偶左奇右”处理后再分治,则与 DIT 同理,若输入序列 $a$,我们最终将得到蝴蝶变换后的 $\DFT(a,n)$。

下面是 DIT 和 DIF 在 $\newcommand\ZmpZ{\mathbb Z/p\mathbb Z}\ZmpZ,p=998244353$ 上的朴素实现。模板参数 rev 指示是否作逆变换。

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
#include <bits/stdc++.h> constexpr int G = 1<<21, P = 998244353, gen = 3; #define inl inline using namespace std; using ll = long long; inl ll fpow (ll a, ll b) { ll res = 1; a %= P; for (; b; b >>= 1) { if (b & 1) (res *= a) %= P; (a *= a) %= P; } return res; } inl void butterfly (int f[], int l) { static int tr[G], last; if (last != l) { last = l; for (int i = 1; i < 1<<l; ++i) tr[i] = tr[i>>1]>>1|(i & 1) * (1<<l-1); } for (int i = 1; i < 1<<l; ++i) if (tr[i] < i) swap (f[tr[i]], f[i]); } inl void reverse (int f[], int l) { const ll invl = fpow (1<<l, P-2); for (int i = 0; i < 1<<l; ++i) f[i] = invl * f[i] % P; reverse (f + 1, f + (1<<l)); } template <bool rev> inl void DIT (int f[], int l) { butterfly (f, l); for (int len = 2, j = 0; len <= 1<<l; len <<= 1, ++j) { const ll w_n = fpow (gen, (P-1)/len); ll g, h, w = 1; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = st; i < st + len/2; ++i, (w *= w_n) %= P) g = f[i], h = f[i + len/2] * w % P, f[i] = (g + h) % P, f[i + len/2] = (P + g - h) % P; } if (rev) reverse (f, l); } template <bool rev> inl void DIF (int f[], int l) { for (int len = 1<<l, q = 0; len > 1; len >>= 1, ++q) { const ll w_n = fpow (gen, (P-1)/len); ll g, h, w = 1; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = st; i < st + len/2; ++i, (w *= w_n) %= P) g = f[i], h = f[i + len/2], f[i] = (g + h) % P, f[i + len/2] = (P + g - h) * w % P; } butterfly (f, l); if (rev) reverse (f, l); }

从另一视角考虑

将 $\omega_n^{ik}a_i$ 逐位分解

我们直接考察 $a_i$ 对 $\DFT(a,n)_k$ 的贡献。显然有 $\DFT(a,n)_k\leftarrow a_i\omega_n^{ik}$,不过若将 $i,k$ 写成二进制,由 $\omega^{x+y}=\omega^x\omega^y$ 及 $\omega_n^{nk}=1,k\in\mathbb Z$ 得到:

$$
\newcommand\bit{\operatorname{bit}}\bit(x,k)=\left\lfloor\frac{x}{2^k}\right\rfloor\bmod 2\\
\begin{aligned}
\omega_n^{ik}&=\prod_{j=0}^{l-1}\omega_{2^l}^{2^j\bit(i,j)k}&\quad\quad&\omega_n^{ik}&=&\prod_{q=0}^{l-1}\omega_{2^l}^{2^q\bit(k,q)i}\\
&=\prod_{j=0}^{l-1}\omega_{2^{l-j}}^{\bit(i,j)(k\bmod 2^{l-j})}&\quad\quad&&=&\prod_{q=0}^{l-1}\omega_{2^{l-q}}^{\bit(k,q)(i\bmod{2^{l-q}})}&(2)\\
&=\prod_{j=1}^{l}\omega_{2^j}^{\bit(i,l-j)(k\bmod{2^j})}&(1)
\end{aligned}
$$

现在再来观察 DIF 的过程。对于第 $q\ (l>q\geq 0)$ 层而言,共有 $2^q$ 个相互独立的分治序列,每一序列含 $2^{l-q}$ 个元素,其中每个元素均为 $2^q$ 个不同的 $i$ 的 $a_i$ 的带权和。对于一对固定的 $i,k$,$a_i$ 之于 $\DFT(a,n)_k$ 的贡献即等于其在分治树上走到 $\DFT(a,n)_k$ 的权值之积。

考察在第 $q$ 层,向 $\DFT(a,n)_k$ 贡献的 $a_i$ 将乘上的权值。若 $\bit(k,q)=1$,则 $a_i$ 走右侧分支,且乘上 $\omega_{2^{l-q}}^?$;否则乘上 $1$ 并走左侧分支。$?$ 只与 $a_i$ 在该层分治序列中的实际位置有关,而观察可得它的实际位置正是 $i\bmod {2^{l-q}}$!又由于当 $\bit(i,l-q-1)=1\iff i\bmod{2^{l-q}}\geq 2^{l-q-1}$(在该层序列右侧)时,有 $\omega_{2^{l-q}}^{i\bmod{2^{l-q}}}={\color{red}-}\omega_{2^{l-q}}^{i\bmod{2^{l-q}}-2^{l-q-1}}$(标红的负号可以在 DIF 的推导中找到)。这正好与 $(2)$ 式完全对应!同理可以得到,DIT 的过程与 $(1)$ 式完全对应($j,1\leq j\leq l$ 为分治树自底向顶的第 $j$ 层,分治序列长为 $2^j$)。

优化单位根之幂移动次数

不难从上述示例代码中看出,我们在每一层分治中都将 $\omega_{2^{l-q}}$ 之幂移动了 $\bigO(n)$ 次;总移动次数为 $\bigO(n\log n)$ 次。通常情况下这已经足够快速了;不过对于时限异常紧张的题目,或者虽然并非正解,但有望通过 DFT 优化暴力获取更多部分分(甚至满分!)的题目,我们还想再快一些。

仍然考虑上述两种实现的解释。我们用 $(2)$ 式解释了 DIF。能否用 $(1)$ 解释呢?不难注意到,对于向 $\DFT(a,n)_k$ 贡献的 $a_i$,在第 $q+1$ 层时所处的分治序列是从左往右第 $\newcommand\rev{\operatorname{rev}}\rev(k\bmod{2^{q+1}})$ 个(第 $q+1$ 层有 $2^{q+1}$ 个序列;$\rev(x)$ 表示将 $x$ 的无前导零二进制表示翻转);在第 $q$ 层的序列内,$a_i$ 处在左半还是右半是由 $\bit(i,l-q-1)$ 位决定的。回到 $(1)$ 式,如果令 $j=q+1$,这似乎正与它所表达的相契合!于是,将分治树视为 $\text{0/1 Trie}$,将 $a_i$ 向 $\DFT(a,n)_k$ 贡献的过程视作遍历 $\rev(k)$ 二进制表示对应的路径,并在过程中以 $k$ 长为 $q+1$ 的前缀与 $\bit(i,l-q-1)$ 之积作为指数,乘上 $\omega_{2^{q+1}}$ 的次幂,就是我们所要的权值!在这个角度下,每个分治序列长为 $2^{l-q-1}$,就是因为对 $i$ 的最高 $q$ 位毫不关注;按低位排序是为快速分裂出 $q$ 更大时的分治序列。

在这种视角下,输入系数序列、输出蝴蝶变换后的点值序列的目的保持不变;但对于第 $q+1$ 层,对于同一分治序列而言,$\rev(k\bmod{2^{q+1}})$ 是不变的。这意味着我们对于同一序列采用了同一单位根之幂!故而原根的移动次数为 $\bigO(\sum_{q=0}^{l-1}2^q)=\bigO(n)$,有了大幅提升。DIT 之于 $(2)$ 式的解释同理。

在实现时,如果顺次遍历第 $q$ 层的 $2^q$ 个序列,采用的单位根之幂即分别为 $\omega_{2^{q+1}}^{\rev(0)},\omega_{2^{q+1}}^{\rev(1)},\cdots,\omega_{2^{q+1}}^{\rev(2^q-1)}$。我们注意到,将 $2^q$ 阶、$2^{q+1}$ 阶单位根之幂的前半部分序列分别作蝴蝶变换,前者是后者的前缀(不难由 $\omega_{2n}^{2k}=\omega_n^k$ 证明),因此我们以 $\bigO(n)$ 时间预处理出 $2^l$ 阶单位根 $\omega_{2^l}^0,\cdots,\omega_{2^l}^{2^{l-1}-1}$ 的序列蝴蝶变换后的结果,即可快速调用。

省去蝴蝶变换

我们用 DIF 实现 $\DFT$,DIT 实现 $\IDFT$ 就可以免去蝴蝶变换。当然,这会略微增加一点编码时间;同时需注意 DIT 得出的点值序列是蝴蝶变换后的。

代码实现

可以按需将 $x\pm y$ 结果的取模利用位运算优化。

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
  57. 57
  58. 58
  59. 59
  60. 60
  61. 61
  62. 62
  63. 63
  64. 64
  65. 65
  66. 66
  67. 67
  68. 68
  69. 69
  70. 70
  71. 71
  72. 72
  73. 73
  74. 74
#include <bits/stdc++.h> using namespace std; #define inl inline /* 快读已省略。 */ #define newl putchar('\n') typedef long long ll; // typedef unsigned long long ull; // typedef __int128 lll; // typedef long double llf; typedef pair <int, int> pint; #define fst first #define scd second #define all(p) begin (p), end (p) #define empb emplace_back constexpr int G = 1<<21, P = 998244353, gen = 3; int n, m, f[G], g[G], w[G]; inl ll fpow (ll a, ll b) { ll res = 1; a %= P; for (; b; b >>= 1) { if (b & 1) (res *= a) %= P; (a *= a) %= P; } return res; } inl void calc_powg () { w[0] = 1; ll f; const int g = fpow (gen, (P-1)/G); for (int t = 0; (1<<t+1) < G; ++t) { f = w[1<<t] = fpow (g, G>>t+2); for (int x = 1<<t; x < 1<<t+1; ++x) w[x] = f * w[x - (1<<t)] % P; } } inl void DIT (int f[], int l) { static ll g, h; for (int len = 2; len <= 1<<l; len <<= 1) for (int st = 0, t = 0; st < 1<<l; st += len, ++t) for (int i = st; i < st + len/2; ++i) g = f[i], h = f[i + len/2], f[i] = (g + h) % P, f[i + len/2] = (P + g - h) * w[t] % P; const ll invl = fpow (1<<l, P-2); for (int i = 0; i < 1<<l; ++i) f[i] = invl * f[i] % P; reverse (f + 1, f + (1<<l)); } inl void DIF (int f[], int l) { static ll g, h; for (int len = 1<<l; len > 1; len >>= 1) for (int st = 0, t = 0; st < 1<<l; st += len, ++t) for (int i = st; i < st + len/2; ++i) g = f[i], h = f[i + len/2] * (ll) w[t] % P, f[i] = (g + h) % P, f[i + len/2] = (P + g - h) % P; } int main () { /* */ read (n, m); calc_powg (); for (int i = 0; i <= n; ++i) read (f[i]); for (int i = 0; i <= m; ++i) read (g[i]); const int l = ceil (log2 (n + m + 1)); DIF (f, l), DIF (g, l); for (int i = 0; i < 1<<l; ++i) f[i] = (ll) f[i] * g[i] % P; DIT (f, l); for (int i = 0; i <= n + m; ++i) print (f[i]), putc (' '); return 0; }
  • 2023年2月23日
  • 4