FFT“三次转两次”优化

众所周知,用快速傅里叶变换求多项式乘积时,要先将两个多项式 $f(x),g(x)$ 分别 DFT,点值相乘后再 IDFT,共需三次 FFT,其常数是巨大的。但如果其系数均为实数,我们就可以通过构造多项式的方式,充分利用复数的虚部存储需要的信息,把三次 FFT 减少为两次。下文中默认 $f(x),g(x)$ 都已补零到 $n=2^s$ 项(即 $\deg f(x),\deg g(x)<n$,且 $n$ 不小于乘积的项数)。记 $n$ 次单位根为 $\omega_n$。

考虑构造两个多项式 $p(x)=f(x)+\mathrm{i}\,g(x)$,$q(x)=f(x)-\mathrm{i}\,g(x)$。若我们求出了 $p(x)$ 的点值表示,就会有 $\overline{p(\omega_n^k)}=q(\omega_n^{-k \bmod n})$。

证明:由于 $|\omega_n^k|=1$,其共轭即为其倒数,故对任意 $n$ 都有 $\overline{\omega_n^k}=\omega_n^{-k \bmod n}$。令 $\omega_n^k=a+b\mathrm{i}$,则 $\omega_n^{-k \bmod n}=a-b\mathrm{i}$。

考察 $p(\omega_n^k)$ 中每一项的贡献。设 $p(x)$ 第 $t$ 项的系数为 $c+d\mathrm{i}$(即 $c=[x^t]f(x)$,$d=[x^t]g(x)$),则第 $t$ 项的贡献为

$$
\begin{aligned}
&(c+d\mathrm{i})(a+b\mathrm{i})^t\\
=&(c+d\mathrm{i})\sum_{j=0}^{t}b^j\mathrm{i}^ja^{t-j}\binom{t}{j}\\
=&(c+d\mathrm{i})\left({\color{red}\sum_{j=1,\ 2\nmid j}^{t}b^j\mathrm{i}^ja^{t-j}\binom{t}{j}}+{\color{blue}\sum_{j=0,\ 2\mid j}^{t}b^j\mathrm{i}^ja^{t-j}\binom{t}{j}}\right)\\
=&(c+d\mathrm{i})\left({\color{red}e\mathrm{i}}{\color{blue}-f}\right),\quad e,f\in\mathbb{R}\\
=&-cf-de+\mathrm{i}(ce-df)
\end{aligned}
$$

容易发现,记蓝字部分之和为 $-f$、红字部分之和为 $e\mathrm{i}$。蓝字部分中 $2\mid j$,有 $(-b)^j=b^j$,故把 $b$ 换成 $-b$ 时 $f$ 不变。红字部分中 $2\nmid j$,有 $(-b)^j=-b^j$,故把 $b$ 换成 $-b$ 时 $e$ 变为 $-e$。

类似地,我们考察 $q(\omega_n^{-k \bmod n})$ 中每一项的贡献。第 $t$ 项的系数根据构造与 $[x^t]p(x)$ 共轭,即为 $c-d\mathrm{i}$,其贡献为

$$
\begin{aligned}
&(c-d\mathrm{i})(a-b\mathrm{i})^t\\
=&(c-d\mathrm{i})\sum_{j=0}^{t}(-b)^j\mathrm{i}^ja^{t-j}\binom{t}{j}\\
=&(c-d\mathrm{i})\left({\color{red}\sum_{j=1,\ 2\nmid j}^{t}(-b)^j\mathrm{i}^ja^{t-j}\binom{t}{j}}+{\color{blue}\sum_{j=0,\ 2\mid j}^{t}(-b)^j\mathrm{i}^ja^{t-j}\binom{t}{j}}\right)\\
=&(c-d\mathrm{i})\left({\color{red}-e\mathrm{i}}{\color{blue}-f}\right),\quad e,f\in\mathbb{R}\\
=&-cf-de+\mathrm{i}(-ce+df)
\end{aligned}
$$

将所有项的贡献累加,两者仍然实部相等、虚部互为相反数。故而 $\overline{p(\omega_n^k)}=q(\omega_n^{-k \bmod n})$。$\square$

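所以只需对 $p(x)$ 的点值取共轭(求共轭复数可用 `std::conj`),就能在线性时间内求出 $q(x)$ 的点值表示:$q(\omega_n^k)=\overline{p(\omega_n^{-k \bmod n})}$。知道了 $p(x),q(x)$ 的点值表示,解二元一次方程组即得 $f(\omega_n^k)=\frac{p(\omega_n^k)+q(\omega_n^k)}{2}$,$g(\omega_n^k)=\frac{p(\omega_n^k)-q(\omega_n^k)}{2\mathrm{i}}$。于是 $f(\omega_n^k)g(\omega_n^k)=\frac{\left(p(\omega_n^k)+q(\omega_n^k)\right)\left(p(\omega_n^k)-q(\omega_n^k)\right)}{4\mathrm{i}}$。之后对乘积的点值正常做一次 IDFT 即可,全程只需两次 FFT。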

洛谷 P3803 评测记录:R78318988

代码如下:
// Polynomial multiplication with the "three FFTs -> two" trick (Luogu P3803).
// Reads n, m, then the n+1 coefficients of f and the m+1 coefficients of g
// (lowest degree first) from stdin, and prints the n+m+1 coefficients of f*g,
// each followed by a space.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <complex>
#include <cstdio>
#include <iterator>
#include <utility>
#include <vector>
using namespace std;
#define inl inline

// Reads one integer from stdin, skipping leading non-digit characters; a '-'
// among them makes the result negative. Returns false (leaving x == 0) if EOF
// is reached before any digit.
template <typename T> inl bool read (T &x) {
    x = 0;
    int f = 1;
    // getchar() returns int. Stored in a char, EOF would either collide with
    // byte 0xFF (signed char) or never compare equal (unsigned char -> endless
    // loop at end of input), and isdigit would see negative values (UB).
    int c = getchar ();
    for (; c != EOF && !isdigit (c); c = getchar ())
        if (c == '-') f = -1;
    if (c == EOF) return false;
    for (; c != EOF && isdigit (c); c = getchar ())
        x = (x << 1) + (x << 3) + (c ^ 48);   // x * 10 + digit
    x *= f;
    return true;
}
// Reads several integers in order; stops at the first failure.
template <typename T, typename... Targs> inl bool read (T &x, Targs&... args) {
    return read (x) && read (args...);
}
// Writes one integer to stdout in decimal, with no separator.
template <typename T> void print (T x) {
    if (x < 0) putchar ('-'), x = -x;
    if (x > 9) print (x / 10);
    putchar ('0' + x % 10);
}
// Writes several integers separated by single spaces.
template <typename T, typename... Targs> inl void print (T x, Targs... args) {
    print (x), putchar (' '), print (args...);
}
#define reint register int
#define newl putchar('\n')
typedef long long ll;
// typedef unsigned long long ull;
// typedef __int128 lll;
// typedef long double llf;
typedef pair <int, int> pint;
#define fst first
#define scd second
#define all(p) begin (p), end (p)
using comp = complex <double>;
using vint = vector <int>;
using vcomp = vector <comp>;
constexpr int N = 1<<21, INF = 0x3f3f3f3f;
int n, m;
vint f, g;

// Applies the bit-reversal permutation to f[0 .. 2^l) in place (the input
// order required by the iterative radix-2 FFT). The permutation table is
// cached and rebuilt only when the transform size changes; it is a vector,
// so there is no fixed upper bound on l.
inl void henkan (vcomp &f, int l) {
    static vint tr{0};
    const int sz = 1 << l;
    if ((int)tr.size () != sz) {
        tr.assign (sz, 0);
        for (int x = 1; x < sz; ++x)
            tr[x] = tr[x >> 1] >> 1 | ((x & 1) << (l - 1));
    }
    for (int x = 1; x < sz; ++x)
        if (tr[x] < x) swap (f[tr[x]], f[x]);
}

// In-place iterative radix-2 FFT of length 2^l; f is zero-padded to that
// length first. The forward transform evaluates at powers of
// w = e^{+2*pi*i/2^l}. With rev == true it computes the inverse: the same
// transform, then division by 2^l and reversal of f[1 ..].
inl void FFT (vcomp &f, int l, bool rev) {
    const int sz = 1 << l;
    f.resize (sz, 0);
    henkan (f, l);
    // roots[j] = e^{2*pi*i*j/sz} for j < sz/2. Each root is evaluated directly
    // with std::polar instead of by repeated multiplication, so rounding error
    // does not accumulate along a level. Cached per transform size.
    static vcomp roots;
    if ((int)roots.size () != sz / 2) {
        const double pi = acos (-1.0);
        roots.resize (sz / 2);
        for (int j = 0; j < sz / 2; ++j)
            roots[j] = polar (1.0, 2.0 * pi * j / sz);
    }
    for (int len = 2; len <= sz; len <<= 1) {
        const int half = len / 2, step = sz / len;   // e^{2*pi*i*k/len} == roots[k*step]
        for (int st = 0; st < sz; st += len)
            for (int k = 0; k < half; ++k) {
                const comp u = f[st + k];
                const comp v = f[st + k + half] * roots[k * step];
                f[st + k] = u + v;
                f[st + k + half] = u - v;
            }
    }
    if (!rev) return;
    for (comp &x : f) x /= sz;
    reverse (begin (f) + 1, end (f));
}

// Multiplies two integer polynomials (coefficients, lowest degree first) using
// only two FFTs. Packs p = f + i*g and transforms it once. Recovers Q = F - iG
// from Q[k] = conj(P[-k]), which holds because f and g are real. Since
// F = (P+Q)/2 and G = (P-Q)/(2i), the product is F*G = (P+Q)(P-Q)/(4i); one
// inverse FFT of that gives f*g. Coefficients are rounded to the nearest
// integer. Returns an empty vector if either input is empty.
inl vint mul (vint f, vint g) {
    if (f.empty () || g.empty ()) return {};
    const int len = (int)(f.size () + g.size () - 1);
    int l = 0;
    while ((1 << l) < len) ++l;   // exact integer ceil(log2(len))
    const int sz = 1 << l;
    f.resize (sz, 0), g.resize (sz, 0);
    vcomp p (sz), prod (sz);
    for (int x = 0; x < sz; ++x) p[x] = comp (f[x], g[x]);
    FFT (p, l, false);
    const comp inv4i (0.0, -0.25);   // 1 / (4i), multiplied exactly
    for (int x = 0; x < sz; ++x) {
        const comp q = conj (p[x ? sz - x : 0]);   // Q(w^x) = F(w^x) - i*G(w^x)
        prod[x] = (p[x] + q) * (p[x] - q) * inv4i; // (2F)(2iG) / (4i) = F*G
    }
    FFT (prod, l, true);
    for (int x = 0; x < len; ++x) f[x] = (int)round (real (prod[x]));
    f.resize (len);
    return f;
}

int main () {
    read (n, m);
    f.resize (n + 1), g.resize (m + 1);
    for (int i = 0; i <= n; ++i) read (f[i]);
    for (int i = 0; i <= m; ++i) read (g[i]);
    for (const int x : mul (f, g)) print (x), putchar (' ');
    return 0;
}
  • 2022年7月3日