任意模数下快速数论变换的两种实现
我们实现NTT时,总是在 整数模质数 域上进行的。原因很简单:为了在 上使用原根以套用复数域 上“单位根”的概念,其阶必然也应当是 的倍数才行。这就为我们带来了很多不便。假如现在给定一质数 且不保证有 ,我们就需要另辟蹊径完成卷积。
实现一:中国剩余定理
中国剩余定理:若数 的质因数分解为 ,有整数模 加法群(或者环,随便你怎么叫吧)
或者通俗地讲,若 两两互质,则线性同余方程组 在 意义下有唯一解。
于是,如果在卷积过程中不对 取模,我们会得到上限约为 的系数,其中 为多项式的次数。在实际应用中,大概为 。故而我们择取 个容易实现NTT的质数 (我常用 ,其三者的最小原根均为 ,偶因数均为 ),将原式的系数分别对其取模后在整数模 乘法群下做卷积,最后将得到三个系数 ,分别对应在模 意义下的实际系数 。
则根据裴蜀定理和中国剩余定理,又由于在整数模质数域上每个非零元素均存在逆元,我们对这些方程组两两合并:
就可以求出在模 意义下的系数了。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
#include <bits/stdc++.h> using namespace std; /* 快读已省略。 */ #define inl inline #define reint register int #define newl putchar('\n') typedef long long ll; // typedef unsigned long long ull; // typedef __int128 lll; // typedef long double llf; typedef pair <int, int> pint; #define fst first #define scd second #define all(p) begin (p), end (p) using vint = vector <int>; constexpr int N = 1<<18, INF = 0x3f3f3f3f, gen = 3, P1 = 998244353, P2 = 1004535809, P3 = 985661441, invP1P2 = 669690699, invP1P2P3 = 401569863; int n, m, P; vint f, g, pr1, pr2, pr3; inl ll fpow (ll a, ll b, ll mod) { ll res = 1; a %= mod; for (; b; b >>= 1) { if (b & 1) (res *= a) %= mod; (a *= a) %= mod; } return res; } inl void henkan (vint &f, int l) { static int tr[N], lst = tr[0] = 0; if (lst != l) { lst = l; for (int x = 1; x < 1<<l; ++x) tr[x] = tr[x>>1]>>1|((1<<l-1) * (x & 1)); } for (int x = 1; x < 1<<l; ++x) if (tr[x] < x) swap (f[tr[x]], f[x]); } #define clog2(x) ceil (log2 (x)) #define tomod(x) if (mod < INF) x.resize (mod ,0) inl ll inv (ll x, int mod) { return fpow (x, mod - 2, mod); } template <int P, int gen> inl void NTT (vint &f, int l, bool rev) { f.resize (1<<l, 0); henkan (f, l); for (int len = 2; len <= 1<<l; len <<= 1) { const ll w_n = fpow (gen, (P - 1)/len, P); ll w = 1, g, h; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = 0; i < len/2; ++i, (w *= w_n) %= P) g = f[i + st], h = f[i + st + len/2] * w % P, f[i + st] = (g + h) % P, f[i + st + len/2] = (g + P - h) % P; } if (!rev) return; const ll p = inv (1<<l, P); for (int x = 0; x < 1<<l; ++x) f[x] = f[x] * p % P; reverse (begin (f) + 1, end (f)); } template <int P, int gen> inl vint mul (vint f, vint g, int mod = INF) { int len = f.size () + g.size () - 1, l = clog2 (len); NTT <P, gen> (f, l, 0), NTT <P, gen> (g, l, 0); for (int x = 0; x < 1<<l; ++x) f[x] = 1ll * f[x] * g[x] % P; NTT <P, gen> (f, l, 1); return f.resize (min (mod, len)), f; } int main () { /* */ read (n, m, P); f.resize (n + 1), g.resize (m + 1); for (int x = 0; x <= n; ++x) read (f[x]); for (int x = 0; x <= m; ++x) read (g[x]); pr1 = mul <P1, gen> (f, g); pr2 = mul <P2, gen> (f, g); pr3 = mul <P3, gen> (f, g); for (int i = 0; i <= n + m; ++i) { ll x4 = (pr1[i] + 1ll * P1 * (ll (pr2[i] - pr1[i] + P2) % P2 * invP1P2 % P2)) % (1LL * P1 * P2), k4 = (pr3[i] - x4 % P3 + P3 + 0ll) % P3 * invP1P2P3 % P3; print ((x4 + k4 * P1 % P * P2 % P) % P), putchar (' '); } return 0; }
实现二:拆系数FFT
如你所见,如果我们无脑不取模直接卷积,造出来的系数在 级别。如果使用FFT,在IDFT的过程中还要乘上 ——也就是 ,甚至更大。就算是 long double
也承受不了,况且还要考虑浮点误差。
因此我们将一个多项式拆成两个多项式分别相乘。现有常数 ( 常取 或者 ),我们对 做卷积,则令
这样一来,四个多项式的系数均在 以下。应用“三次转两次”优化提到的办法,我们应用两次FFT就可以求出其四者的点值表示。将其一一相乘后求得 的点值表示,(在 意义下)依次乘上 的系数后相加就是实际系数。这样我们应用了 次FFT,但仍然要使用 long double
(别忘了,IDFT完成之前系数乘 ——大约为 级别,而 double
的有效数字位(fraction)仅有 位),实际运行效率不比方法一更优。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
#include <bits/stdc++.h> using namespace std; /* 快读已省略。 */ #define inl inline #define reint register int #define newl putchar('\n') typedef long long ll; // typedef unsigned long long ull; // typedef __int128 lll; typedef long double llf; typedef pair <int, int> pint; #define fst first #define scd second #define all(p) begin (p), end (p) using comp = complex <llf>; using vcomp = vector <comp>; using vint = vector <int>; constexpr int N = 1<<18, INF = 0x3f3f3f3f, M = 3.2e4; int n, m, num, P; vint f, g; inl void henkan (vcomp &f, int l) { static int tr[N], lst = tr[0] = 0; if (lst != l) { lst = l; for (int x = 1; x < 1<<l; ++x) tr[x] = tr[x>>1]>>1|((1<<l-1) * (x & 1)); } for (int x = 1; x < 1<<l; ++x) if (tr[x] < x) swap (f[tr[x]], f[x]); } #define clog2(x) ceil (log2 (x)) #define tomod(p) if (mod < INF) p.resize (mod, 0) inl void FFT (vcomp &f, int l, bool rev) { f.resize (1<<l, 0); henkan (f, l); for (int len = 2; len <= 1<<l; len <<= 1) { const comp w_n = comp (cos (M_PI/len*2.0l), sin (M_PI/len*2.0l)); comp w = 1, g, h; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = st; i < st + len/2; ++i, w *= w_n) g = f[i], h = f[i + len/2] * w, f[i] = g + h, f[i + len/2] = g - h; } if (!rev) return; for (int x = 0; x < 1<<l; ++x) f[x] /= 1<<l; reverse (begin (f) + 1, end (f)); } inl void pair_DFT (vcomp &f, vcomp &g, int l) { vcomp p (1<<l); comp _q; for (int x = 0; x < 1<<l; ++x) p[x] = f[x] + 1il * g[x]; FFT (p, l, 0); // p(x)=f(x)+i g(x), q(x)=f(x)-i g(x) for (int x = 0; x < 1<<l; ++x) _q = conj (p[x ? (1<<l) - x : 0]), f[x] = (p[x] + _q) / 2.0l, g[x] = (p[x] - _q) / 2il; } inl vint mul (vint f, vint g, int mod = INF) { int len = f.size () + g.size () - 1, l = clog2 (len); vcomp f1 (1<<l), g1 (1<<l), f0 (1<<l), g0 (1<<l), p (1<<l), q (1<<l), t (1<<l); f.resize (1<<l, 0), g.resize (1<<l, 0); for (int x = 0; x < 1<<l; ++x) f1[x] = f[x] % M, f0[x] = f[x] / M, g1[x] = g[x] % M, g0[x] = g[x] / M; pair_DFT (f0, g0, l), pair_DFT (f1, g1, l); for (int x = 0; x < 1<<l; ++x) p[x] = f0[x] * g0[x], q[x] = f1[x] * g0[x] + f0[x] * g1[x], t[x] = f1[x] * g1[x]; FFT (p, l, 1), FFT (q, l, 1), FFT (t, l, 1); #define coef(x) ((ll) round (real (x)) % P) for (int x = 0; x < 1<<l; ++x) f[x] = (1ll * M * M % P * coef (p[x]) % P + 1ll * M * coef (q[x]) % P + coef (t[x])) % P; return f.resize (len), f; } int main () { /* */ read (n, m, P); f.resize (n + 1); g.resize (m + 1); for (int x = 0; x <= n; ++x) read (f[x]); for (int x = 0; x <= m; ++x) read (g[x]); for (const int x : mul (f, g)) print (x), putchar (' '); return 0; }
1 Response
[…] 时间复杂度 $operatorname{O}(tABlog (AB))$,常数极大(需采用任意模数下的数论变换)。 […]