任意模数下快速数论变换的两种实现

我们实现 NTT 时,总是在整数模质数 $p=2^kq+1$ 的剩余类域上进行的。原因很简单:为了在 $\mathbb{Z}/p\mathbb{Z}$ 上使用原根以套用复数域 $\mathbb{C}$ 上“单位根”的概念,乘法群 $(\mathbb{Z}/p\mathbb{Z})^\times$ 的阶 $p-1$ 必然应当是变换长度 $2^k$ 的倍数才行。这就为我们带来了很多不便。假如现在给定一质数 $p'$ 且不保证有 $p'=2^kq+1$,我们就需要另辟蹊径来完成卷积。

实现一:中国剩余定理

中国剩余定理:若正整数 $n$ 的质因数分解为 $n=\prod_{i=1}^{k}p_i^{e_i}$,则整数模 $n$ 的剩余类环(或者,随便你怎么叫吧)有如下同构:

$$\mathbb{Z}/n\mathbb{Z}\cong \mathbb{Z}/p_1^{e_1}\mathbb{Z}\times \mathbb{Z}/p_2^{e_2}\mathbb{Z}\times \cdots \times\mathbb{Z}/p_k^{e_k}\mathbb{Z}$$

或者通俗地讲,若 $m_1, m_2, \cdots, m_k$ 两两互质,则线性同余方程组 $$\left\{\begin{aligned} x &\equiv x_1\pmod{m_1}\\ x &\equiv x_2\pmod{m_2}\\ &\quad \vdots\\ x &\equiv x_k\pmod{m_k}\end{aligned}\right.$$ 在 $\bmod \prod_{i=1}^{k}m_i$ 意义下有唯一解。

于是,如果在卷积过程中不对 $p'$ 取模,我们会得到上限约为 $p'^2n$ 的系数,其中 $n$ 为多项式的次数。在实际应用中($p'\approx10^9$,$n\approx10^5$),这大约为 $10^{23}$。故而我们择取 $3$ 个容易实现 NTT 的质数 $p_1,p_2,p_3$。我常用 $998244353,1004535809,985661441$:三者的最小原根均为 $3$,且 $p_i-1$ 均被 $2^{21}$ 整除,即 $998244353=119\cdot2^{23}+1$,$1004535809=479\cdot2^{21}+1$,$985661441=235\cdot2^{22}+1$。其乘积约为 $9.9\times10^{26}$,足以容纳真实系数。将原式的系数分别对其取模后,在整数模 $p_i$ 的剩余类环上做卷积,最后将得到三个系数 $x_1,x_2,x_3$,分别为真实系数 $x$ 在模 $p_i$ 意义下的值。

则根据裴蜀定理与中国剩余定理,又由于在整数模质数域上每个非零元素均存在逆元(下文的除法即乘以模意义下的逆元),我们对这些方程组两两合并: $$\begin{aligned} x_1+k_1p_1&\equiv x_2+k_2p_2\equiv x&\pmod{p_1p_2}\\ x_1+k_1p_1&\equiv x_2&\pmod{p_2}\\ k_1&\equiv \dfrac{x_2-x_1}{p_1}&\pmod{p_2}\\ x_4&\equiv x_1+k_1p_1&\pmod{p_1p_2}\\ x_4+k_4p_1p_2&\equiv x_3+k_3p_3\equiv x&\pmod{p_1p_2p_3}\\ x_4+k_4p_1p_2&\equiv x_3&\pmod{p_3}\\ k_4&\equiv\dfrac{x_3-x_4}{p_1p_2}&\pmod{p_3} \end{aligned}$$

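取 $k_1\in[0,p_2)$、$k_4\in[0,p_3)$,则 $x=x_4+k_4p_1p_2$ 恰为 $[0,p_1p_2p_3)$ 内的唯一解。由于真实系数小于 $p_1p_2p_3$,它就是真实系数。再对 $p'$ 取模,就可以求出在模 $p'$ 意义下的系数了。($x$ 本身可能超出 64 位整数的范围,所以代码中直接计算 $(x_4+k_4p_1p_2)\bmod p'$。)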

洛谷题库 P4245 R78201636 记录详情

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
  57. 57
  58. 58
  59. 59
  60. 60
  61. 61
  62. 62
  63. 63
  64. 64
  65. 65
  66. 66
  67. 67
  68. 68
  69. 69
  70. 70
  71. 71
  72. 72
  73. 73
  74. 74
  75. 75
  76. 76
  77. 77
  78. 78
  79. 79
  80. 80
  81. 81
  82. 82
  83. 83
  84. 84
  85. 85
  86. 86
  87. 87
  88. 88
  89. 89
  90. 90
  91. 91
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <vector>
using namespace std;

typedef long long ll;
using vint = vector <int>;

// Three NTT-friendly primes, all used with generator 3:
//   P1 = 119 * 2^23 + 1, P2 = 479 * 2^21 + 1, P3 = 235 * 2^22 + 1,
// so every transform length up to 2^21 (and in particular up to N) works.
// With inputs reduced into [0, P), every exact product coefficient is below
// P^2 * N < 2^80, far under P1 * P2 * P3 (about 9.9e26), so CRT is exact.
constexpr int N = 1 << 18, INF = 0x3f3f3f3f, gen = 3,
    P1 = 998244353, P2 = 1004535809, P3 = 985661441,
    invP1P2 = 669690699,    // inverse of P1 modulo P2 (not of P1 * P2)
    invP1P2P3 = 401569863;  // inverse of P1 * P2 modulo P3
int n, m, P;
vint f, g, pr1, pr2, pr3;

// Minimal stand-ins for the fast-I/O routines omitted from the listing.
// Reads one (optionally negative) decimal integer from stdin.
inline void read (int &x) {
    int c = getchar (), sign = 1;
    x = 0;
    while (c != EOF && c != '-' && !isdigit (c)) c = getchar ();
    if (c == '-') sign = -1, c = getchar ();
    for (; c != EOF && isdigit (c); c = getchar ()) x = x * 10 + (c - '0');
    x *= sign;
}
// Reads several integers in order: read (a, b, c).
template <class... Rest>
inline void read (int &x, int &y, Rest &... rest) { read (x); read (y, rest...); }
// Writes one integer in decimal to stdout (no separator).
inline void print (ll x) {
    if (x < 0) putchar ('-'), x = -x;
    if (x > 9) print (x / 10);
    putchar (char ('0' + x % 10));
}

// a^b mod `mod` by binary exponentiation.
inline ll fpow (ll a, ll b, ll mod) {
    ll res = 1;
    for (a %= mod; b; b >>= 1, a = a * a % mod)
        if (b & 1) res = res * a % mod;
    return res;
}

// Modular inverse via Fermat's little theorem; `mod` must be prime.
inline ll inv (ll x, int mod) { return fpow (x, mod - 2, mod); }

// Smallest l with 2^l >= x (for x >= 1); integer-only, unlike ceil(log2(x)).
inline int ceil_log2 (int x) {
    int l = 0;
    while ((1 << l) < x) ++l;
    return l;
}

// Bit-reversal permutation of f[0 .. 2^l). The permutation table is cached
// between calls with the same l.
inline void henkan (vint &f, int l) {
    static int tr[N], lst = 0;
    assert ((1 << l) <= N);  // tr has only N slots
    if (lst != l) {
        lst = l;
        for (int x = 1; x < 1 << l; ++x)
            tr[x] = tr[x >> 1] >> 1 | ((x & 1) << (l - 1));
    }
    for (int x = 1; x < 1 << l; ++x)
        if (tr[x] < x) swap (f[tr[x]], f[x]);
}

// In-place iterative NTT of length 2^l modulo P (f is zero-padded/resized).
// With rev = true this computes the inverse transform: forward transform,
// scale by 1/2^l, then reverse f[1..] (omega^{-k} == omega^{2^l - k}).
template <int P, int gen>
inline void NTT (vint &f, int l, bool rev) {
    const int sz = 1 << l;
    f.resize (sz, 0);
    henkan (f, l);
    for (int len = 2; len <= sz; len <<= 1) {
        const int half = len >> 1;
        const ll w_n = fpow (gen, (P - 1) / len, P);  // primitive len-th root
        for (int st = 0; st < sz; st += len) {
            ll w = 1;
            for (int i = st; i < st + half; ++i, w = w * w_n % P) {
                const ll a = f[i], b = f[i + half] * w % P;
                f[i] = int ((a + b) % P);
                f[i + half] = int ((a + P - b) % P);
            }
        }
    }
    if (!rev) return;
    const ll s = inv (sz, P);
    for (int x = 0; x < sz; ++x) f[x] = int (f[x] * s % P);
    reverse (begin (f) + 1, end (f));
}

// Product of f and g modulo the NTT prime P, truncated to min(mod, len)
// coefficients (`mod` is a length cap, INF means "keep everything").
template <int P, int gen>
inline vint mul (vint f, vint g, int mod = INF) {
    if (f.empty () || g.empty ()) return vint ();
    const int len = int (f.size () + g.size () - 1), l = ceil_log2 (len);
    NTT <P, gen> (f, l, false);
    NTT <P, gen> (g, l, false);
    for (int x = 0; x < 1 << l; ++x) f[x] = int (1ll * f[x] * g[x] % P);
    NTT <P, gen> (f, l, true);
    f.resize (min (mod, len));
    return f;
}

int main () {
    read (n, m, P);
    f.resize (n + 1), g.resize (m + 1);
    // Reduce inputs into [0, P) so the exact coefficients stay below P1*P2*P3.
    for (int &v : f) { read (v); v = int (((ll) v % P + P) % P); }
    for (int &v : g) { read (v); v = int (((ll) v % P + P) % P); }
    pr1 = mul <P1, gen> (f, g);
    pr2 = mul <P2, gen> (f, g);
    pr3 = mul <P3, gen> (f, g);
    for (int i = 0; i <= n + m; ++i) {
        // Merge mod P1 and mod P2: x4 = pr1 + k1 * P1 with k1 in [0, P2),
        // hence 0 <= x4 < P1 * P2 (no further reduction needed).
        const ll k1 = (pr2[i] - pr1[i] + (ll) P2) % P2 * invP1P2 % P2;
        const ll x4 = pr1[i] + (ll) P1 * k1;
        // Merge with mod P3: exact x = x4 + k4 * P1 * P2, k4 in [0, P3).
        const ll k4 = (pr3[i] - x4 % P3 + P3) % P3 * invP1P2P3 % P3;
        // x itself may overflow 64 bits, so reduce it modulo P piecewise.
        print ((x4 + k4 * P1 % P * P2 % P) % P), putchar (' ');
    }
    return 0;
}

实现二:拆系数FFT

如你所见,如果我们无脑不取模直接卷积,造出来的系数在 $10^{23}$ 级别。如果使用 FFT,在 IDFT 的过程中(除以 $n$ 之前)系数还要乘上 $n$——也就是 $10^{28}$,甚至更大。就算是 long double 也承受不了,况且还要考虑浮点误差。

因此我们将一个多项式拆成两个多项式分别相乘。现有常数 $M$($M$ 常取 $4\times 10^4$ 或者 $\lceil\sqrt{p'}\rceil$,下面的代码取 $3.2\times10^4$),我们对 $f(x), g(x)$ 做卷积,则令 $f(x)=Mf_0(x)+f_1(x)$,$g(x)=Mg_0(x)+g_1(x)$。这样一来(系数已对 $p'$ 取模且 $M^2\ge p'$ 时),四个多项式的系数均在 $M$ 以下。应用“三次转两次”优化提到的办法,我们应用两次 FFT 就可以求出其四者的点值表示。将其一一相乘后,求得 $f_0(x)g_0(x)$,$f_0(x)g_1(x)+f_1(x)g_0(x)$,$f_1(x)g_1(x)$ 的点值表示。再经三次 IDFT 还原出系数,(在 $\bmod p'$ 意义下)依次乘上 $M^2, M, 1$ 后相加就是实际系数。这样我们应用了 $5$ 次 FFT,但仍然要使用 long double。别忘了,IDFT 完成之前系数要乘 $n$,大约为 $10^{19}$ 级别,而 double 的有效数字位(fraction)仅有 $52$ 位。因此,实际运行效率不比方法一更优。

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
  57. 57
  58. 58
  59. 59
  60. 60
  61. 61
  62. 62
  63. 63
  64. 64
  65. 65
  66. 66
  67. 67
  68. 68
  69. 69
  70. 70
  71. 71
  72. 72
  73. 73
  74. 74
  75. 75
  76. 76
  77. 77
  78. 78
  79. 79
  80. 80
  81. 81
  82. 82
  83. 83
  84. 84
  85. 85
  86. 86
  87. 87
  88. 88
  89. 89
  90. 90
  91. 91
  92. 92
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>
using namespace std;

typedef long long ll;
typedef long double llf;
using comp = complex <llf>;
using vcomp = vector <comp>;
using vint = vector <int>;

// M is the split base: each coefficient c (already reduced into [0, P)) is
// written as c = M * hi + lo with lo < M; hi < M as well provided P <= M * M
// (= 1.024e9). For larger P the high part can exceed M and the precision
// margin shrinks.
constexpr int N = 1 << 18, INF = 0x3f3f3f3f, M = 32000;
// pi to long-double precision (M_PI is only a double constant, which would
// throw away the extra precision long double is used for).
const llf PI = std::acos (-1.0L);
int n, m, P;
vint f, g;

// Minimal stand-ins for the fast-I/O routines omitted from the listing.
// Reads one (optionally negative) decimal integer from stdin.
inline void read (int &x) {
    int c = getchar (), sign = 1;
    x = 0;
    while (c != EOF && c != '-' && !isdigit (c)) c = getchar ();
    if (c == '-') sign = -1, c = getchar ();
    for (; c != EOF && isdigit (c); c = getchar ()) x = x * 10 + (c - '0');
    x *= sign;
}
// Reads several integers in order: read (a, b, c).
template <class... Rest>
inline void read (int &x, int &y, Rest &... rest) { read (x); read (y, rest...); }
// Writes one integer in decimal to stdout (no separator).
inline void print (ll x) {
    if (x < 0) putchar ('-'), x = -x;
    if (x > 9) print (x / 10);
    putchar (char ('0' + x % 10));
}

// Smallest l with 2^l >= x (for x >= 1); integer-only, unlike ceil(log2(x)).
inline int ceil_log2 (int x) {
    int l = 0;
    while ((1 << l) < x) ++l;
    return l;
}

// Bit-reversal permutation of f[0 .. 2^l). The permutation table is cached
// between calls with the same l.
inline void henkan (vcomp &f, int l) {
    static int tr[N], lst = 0;
    assert ((1 << l) <= N);  // tr has only N slots
    if (lst != l) {
        lst = l;
        for (int x = 1; x < 1 << l; ++x)
            tr[x] = tr[x >> 1] >> 1 | ((x & 1) << (l - 1));
    }
    for (int x = 1; x < 1 << l; ++x)
        if (tr[x] < x) swap (f[tr[x]], f[x]);
}

// In-place iterative FFT of length 2^l with twiddles e^{+2*pi*i*k/len}
// (f is zero-padded/resized). With rev = true this computes the inverse:
// forward transform, divide by 2^l, then reverse f[1..].
inline void FFT (vcomp &f, int l, bool rev) {
    const int sz = 1 << l;
    f.resize (sz, 0);
    henkan (f, l);
    // root[k] = e^{2*pi*i*k/sz}, each evaluated directly so rounding error
    // does not accumulate the way repeated multiplication by w_n does.
    vcomp root (sz >> 1);
    for (int k = 0; k < sz >> 1; ++k) root[k] = polar ((llf) 1, 2 * PI * k / sz);
    for (int len = 2; len <= sz; len <<= 1) {
        const int half = len >> 1, step = sz / len;  // e^{2*pi*i*j/len} == root[j*step]
        for (int st = 0; st < sz; st += len)
            for (int i = 0; i < half; ++i) {
                const comp a = f[st + i], b = f[st + i + half] * root[i * step];
                f[st + i] = a + b;
                f[st + i + half] = a - b;
            }
    }
    if (!rev) return;
    for (comp &c : f) c /= (llf) sz;
    reverse (begin (f) + 1, end (f));
}

// Transforms two real sequences f, g (each of size >= 2^l) with one FFT:
// P = DFT(f + i g); with Q[x] = conj(P[-x]), DFT(f) = (P + Q) / 2 and
// DFT(g) = (P - Q) / (2i). Results overwrite f and g.
inline void pair_DFT (vcomp &f, vcomp &g, int l) {
    const int sz = 1 << l;
    const comp imag_unit (0, 1);
    vcomp p (sz);
    for (int x = 0; x < sz; ++x) p[x] = f[x] + imag_unit * g[x];
    FFT (p, l, false);
    for (int x = 0; x < sz; ++x) {
        const comp q = conj (p[(sz - x) & (sz - 1)]);  // index -x mod sz
        f[x] = (p[x] + q) / (llf) 2;
        g[x] = (p[x] - q) / (imag_unit * (llf) 2);
    }
}

// Product of f and g modulo the global P via split-coefficient FFT,
// truncated to min(mod, len) coefficients (`mod` is a length cap, INF means
// "keep everything"). Assumes coefficients are already in [0, P).
inline vint mul (vint f, vint g, int mod = INF) {
    if (f.empty () || g.empty ()) return vint ();
    const int len = int (f.size () + g.size () - 1), l = ceil_log2 (len), sz = 1 << l;
    vcomp f1 (sz), g1 (sz), f0 (sz), g0 (sz), p (sz), q (sz), t (sz);
    f.resize (sz, 0), g.resize (sz, 0);
    for (int x = 0; x < sz; ++x) {
        f1[x] = f[x] % M, f0[x] = f[x] / M;  // f = M * f0 + f1
        g1[x] = g[x] % M, g0[x] = g[x] / M;  // g = M * g0 + g1
    }
    pair_DFT (f0, g0, l), pair_DFT (f1, g1, l);
    for (int x = 0; x < sz; ++x) {
        p[x] = f0[x] * g0[x];                  // weight M^2
        q[x] = f1[x] * g0[x] + f0[x] * g1[x];  // weight M
        t[x] = f1[x] * g1[x];                  // weight 1
    }
    FFT (p, l, true), FFT (q, l, true), FFT (t, l, true);
    // Round an (ideally integral) real part and reduce it into [0, P).
    const auto coef = [] (const comp &c) {
        return (std::llround (real (c)) % P + P) % P;
    };
    const ll MM = 1ll * M * M % P;
    for (int x = 0; x < sz; ++x)
        f[x] = int ((MM * coef (p[x]) % P + 1ll * M * coef (q[x]) % P + coef (t[x])) % P);
    f.resize (min (mod, len));
    return f;
}

int main () {
    read (n, m, P);
    f.resize (n + 1), g.resize (m + 1);
    // Reduce inputs into [0, P) so the split parts stay below M.
    for (int &v : f) { read (v); v = int (((ll) v % P + P) % P); }
    for (int &v : g) { read (v); v = int (((ll) v % P + P) % P); }
    for (const int x : mul (f, g)) print (x), putchar (' ');
    return 0;
}
  • 2022年7月3日