由卷积定理和循环卷积得到,我们对同样长为 n 的数列 {bn} 作 DFT,并令数列 ck=DFT(a,n)kDFT(b,n)k,则有
IDFT(c,n)k=j+q≡k(modn)∑ajbq
于是我们可以利用该原理实现常规意义下的数列卷积:有长为 n 的数列 {an},长 m 的数列 {bm},则将 a,b 高位补 0,分别作 n+m−1 位的 DFT,点值相乘后 IDFT 即可得到 ck=∑i+j≡k(modn+m−1)aibj=∑i+j=k0≤i<n0≤j<maibj。
为行文方便,下文采用 DFT(a,n)k=∑i=0n−1aiωnik 的定义,其中 ωn=exp(2πi/n),即 n 阶单位根。
两种分治方式
DIT(Decimation in Time, 按时域抽取)
相信绝大部分 OIer 首次接触 FFT 时均学习的这种形式。如果 n=2l,l∈N+,那么根据 ωnk=ω2n2k,ω2nn=−1⇒ω2nk+n=−ω2nk(k<n),我们可以分治优化上述过程。由 i 的奇偶性不同(这也是其称为“按时域抽取”的原因),可以将原式划分作
DFT(a,n)k=i=0∑n−1aiωnik=i=0∑n/2−1a2iωn2ik+i=0∑n/2−1a2i+1ωn2ik+k=i=0∑n/2−1a2iωn/2ik+i=0∑n/2−1a2i+1ωn/2ikωnkk<n/2⟹k≥n/2⟹⟹DFT(a,n)k=DFT({a0,a2,⋯,an−2},n/2)k+DFT({a1,a3,⋯,an−1},n/2)kωnkωn/2k=ωn/2kmodn/2=ωn/2k−n/2,ωnk=−ωnk−n/2DFT(a,n)k=DFT({a0,a2,⋯,an−2},n/2)k−n/2−DFT({a1,a3,⋯,an−1},n/2)k−n/2ωnk−n/2
终止条件为 n=1=20 时,DFT(a,1)0=a0。易得时间复杂度为 O(nlogn);由 a 的各个元素在递归树上的走向可知,若将该过程改写作非递归版本(自底向顶计算),则需要一开始将 a 作蝴蝶变换。故而用线性变换的角度看待本算法流程,则其接受的输入是蝴蝶变换后的 a,输出是 a 的离散傅里叶变换(若将 a 视作 n−1 次多项式 f(x),则亦可称其为 x=ωnk,k=0,1,⋯,n−1 时的点值表示)。
DIF(Decimation in Frequency, 按频域抽取)
既然分治时可以按时域变量的二进制表示分治,那么按频域的亦可。现在考虑 k 的奇偶性:
DFT(a,n)k=i=0∑n−1aiωnik=i=0∑n/2−1aiωnik+ai+n/2ωn(i+n/2)k=i=0∑n/2−1(ai+ai+n/2(−1)k)ωnik2∣k⟹2∤k⟹DFT(a,n)kDFT(a,n)k====i=0∑n/2−1(ai+ai+n/2)ωn/2ik/2DFT({a0+an/2,a1+an/2+1,⋯,an/2−1+an−1},n/2)k/2i=0∑n/2−1(ai−ai+n/2)ωniωn/2i(k−1)/2DFT({(a0−an/2)ωn0,⋯,(an/2−1−an−1)ωnn/2−1},n/2)(k−1)/2
如果仍然“偶左奇右”处理后再分治,则与 DIT 同理,若输入序列 a,我们最终将得到蝴蝶变换后的 DFT(a,n)。
下面是 DIT 和 DIF 在 Z/pZ,p=998244353 上的朴素实现。模板参数 rev 指示是否作逆变换。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include<bits/stdc++.h>constexprint G =1<<21, P =998244353, gen =3;#defineinlinlineusingnamespace std;using ll =longlong;
inl ll fpow(ll a, ll b){
ll res =1; a %= P;for(; b; b >>=1){if(b &1)(res *= a)%= P;(a *= a)%= P;}return res;}
inl voidbutterfly(int f[],int l){staticint tr[G], last;if(last != l){ last = l;for(int i =1; i <1<<l;++i)
tr[i]= tr[i>>1]>>1|(i &1)*(1<<l-1);}for(int i =1; i <1<<l;++i)if(tr[i]< i)swap(f[tr[i]], f[i]);}
inl voidreverse(int f[],int l){const ll invl =fpow(1<<l, P-2);for(int i =0; i <1<<l;++i)
f[i]= invl * f[i]% P;reverse(f +1, f +(1<<l));}template<bool rev>
inl voidDIT(int f[],int l){butterfly(f, l);for(int len =2, j =0; len <=1<<l; len <<=1,++j){const ll w_n =fpow(gen,(P-1)/len);
ll g, h, w =1;for(int st =0; st <1<<l; st += len, w =1)for(int i = st; i < st + len/2;++i,(w *= w_n)%= P)
g = f[i], h = f[i + len/2]* w % P,
f[i]=(g + h)% P,
f[i + len/2]=(P + g - h)% P;}if(rev)reverse(f, l);}template<bool rev>
inl voidDIF(int f[],int l){for(int len =1<<l, q =0; len >1; len >>=1,++q){const ll w_n =fpow(gen,(P-1)/len);
ll g, h, w =1;for(int st =0; st <1<<l; st += len, w =1)for(int i = st; i < st + len/2;++i,(w *= w_n)%= P)
g = f[i], h = f[i + len/2],
f[i]=(g + h)% P,
f[i + len/2]=(P + g - h)* w % P;}butterfly(f, l);if(rev)reverse(f, l);}
我们用 DIF 实现 DFT,DIT 实现 IDFT 就可以免去蝴蝶变换。当然,这会略微增加一点编码时间;同时需注意 DIT 得出的点值序列是蝴蝶变换后的。
代码实现
可以按需将 x±y 结果的取模利用位运算优化。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>usingnamespace std;#defineinlinline/* 快读已省略。 */#definenewlputchar('\n')typedeflonglong ll;// typedef unsigned long long ull;// typedef __int128 lll;// typedef long double llf;typedef pair <int,int> pint;#definefstfirst#definescdsecond#defineall(p)begin(p),end(p)#defineempbemplace_backconstexprint G =1<<21, P =998244353, gen =3;int n, m, f[G], g[G], w[G];
inl ll fpow(ll a, ll b){
ll res =1; a %= P;for(; b; b >>=1){if(b &1)(res *= a)%= P;(a *= a)%= P;}return res;}
inl voidcalc_powg(){
w[0]=1; ll f;constint g =fpow(gen,(P-1)/G);for(int t =0;(1<<t+1)< G;++t){
f = w[1<<t]=fpow(g, G>>t+2);for(int x =1<<t; x <1<<t+1;++x)
w[x]= f * w[x -(1<<t)]% P;}}
inl voidDIT(int f[],int l){static ll g, h;for(int len =2; len <=1<<l; len <<=1)for(int st =0, t =0; st <1<<l; st += len,++t)for(int i = st; i < st + len/2;++i)
g = f[i], h = f[i + len/2],
f[i]=(g + h)% P,
f[i + len/2]=(P + g - h)* w[t]% P;const ll invl =fpow(1<<l, P-2);for(int i =0; i <1<<l;++i)
f[i]= invl * f[i]% P;reverse(f +1, f +(1<<l));}
inl voidDIF(int f[],int l){static ll g, h;for(int len =1<<l; len >1; len >>=1)for(int st =0, t =0; st <1<<l; st += len,++t)for(int i = st; i < st + len/2;++i)
g = f[i], h = f[i + len/2]*(ll) w[t]% P,
f[i]=(g + h)% P,
f[i + len/2]=(P + g - h)% P;}intmain(){/* */read(n, m);calc_powg();for(int i =0; i <= n;++i)read(f[i]);for(int i =0; i <= m;++i)read(g[i]);constint l =ceil(log2(n + m +1));DIF(f, l),DIF(g, l);for(int i =0; i <1<<l;++i)
f[i]=(ll) f[i]* g[i]% P;DIT(f, l);for(int i =0; i <= n + m;++i)print(f[i]),putc(' ');return0;}
两个月,反反复复看了8次才看明白
这个东西果然还是得想办法形象理解才行
一些细小的跳步也很容易导致看不懂
实际上是当时我写得不太好。
原课件里面使用“转置原理”解释了该算法,但我并没能理解,所以就用了很笨拙的方法(拆二进制位算贡献)证明了它的正确性。
现在理解了转置原理,近期有重写本篇的计划。
感谢!
已经把这个优化应用在三模 NTT 和多点求值里了
理论上也应该可以用于 FFT 和 MTT 吧?望讲解。
是这样的,上面的推导全是关于DIT这个线性变换本身的,跟在哪个数域(例如 $\mathbb{Z}/p\mathbb Z,\mathbb C$,前者上有NTT,后者上有FFT)上没有关系。稍后我会补全。