0%

[CTSC2010]性能优化

深知自己快退役了,所以这段时间打算把 FFT 原理的那一部分搞明白

首先,我们需要知道平常我们做的 FFT 实际上是循环卷积,循环卷积的长度相当于单位根的下指标,而平时这个值都是大于最后次数的,所以和普通卷积并无差别,具体来说我们要求的是

\[ c_k=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_ib_j[(i+j)\operatorname{mod}n=k] \]

由单位根反演可得 \[ \begin{aligned} c_k&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_ib_j\big(\frac{1}{k}\sum_{d=0}^{n-1}\omega_n^{(i+j-k)d}\big)\\\\&=\frac{1}{k}\sum_{d=0}^{n-1}\omega_n^{-kd}\big(\sum_{i=0}^{n-1}a_i\omega_n^{id}\big)\big(\sum_{j=0}^{n-1}b_j\omega_n^{jd}\big) \end{aligned} \]

我们发现前一部分是 IDFT 而后一部分是 DFT,合起来就是 FFT

但这就要求我们可以实现任意长度 FFT,然而对于 FFT 实际求解中最关键的一步却要求需要为 \(2\) 的整数次幂,具体过程如下

考虑当前多项式 \(A(x)\) 我们不妨设

\[ A^{[0]}(x)=\sum_{i\operatorname{mod}2=0}a_ix^{\frac{i}{2}} \]

以及

\[ A^{[1]}(x)=\sum_{i\operatorname{mod}2=1}a_ix^{\frac{i-1}{2}} \]

那么显然

\[ A(\omega_n^i)=A^{[0]}(w_n^{2i})+w_n^iA^{[1]}(\omega_n^{2i}) \]

由折半引理得,对于前一半

\[ A(\omega_n^i)=A^{[0]}(w_{\frac{n}{2}}^i)+w_n^iA^{[1]}(w_{\frac{n}{2}}^i) \]

对于后一半而言

\[ A(\omega_n^i)=A^{[0]}(w_{\frac{n}{2}}^i)-w_n^iA^{[1]}(w_{\frac{n}{2}}^i) \]

所以达到问题规模减小的目的

但对于一般的 \(n\) 而言我们需要找到它的因子 \(m\) 类比上方做如下变换,设

\[ A^{[j]}(x)=\sum_{i\operatorname{mod}m=j}a_ix^{\frac{i-j}{m}} \]

而最终的

\[ \begin{aligned} A(w_n^i)&=\sum_{j=0}^{m-1}\omega_n^{ij}A^{[j]}(\omega_n^{im})\\\\&=\sum_{j=0}^{m-1}\omega_n^{ij}A^{[j]}(\omega_{\frac{n}{m}}^i) \end{aligned} \]

这道题由于保证 \(n\) 可以表示为若干不超过 \(10\) 的正整数乘积,所以复杂度得到保证,更一般的我们可以使用 Bluestein's Algorithm

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <cstring>
#include <cctype>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long i64;
inline int read(int f = 1, int x = 0, char ch = ' ')
{
while(!isdigit(ch = getchar())) if(ch == '-') f = -1;
while(isdigit(ch)) x = x*10+ch-'0', ch = getchar();
return f*x;
}
const int N = 5e5+5;
int n, C, m, P, gn, lim, d[N]; i64 g[N];
i64 qpow(i64 a, int b) { i64 ret = 1; for( ; b; b >>= 1, a = a*a%P) if(b&1) ret = ret*a%P; return ret; }
void prepare()
{
vector<int> p;
for(int i = 2, n = P-1; i <= 7; ++i)
if(n%i == 0)
{
p.push_back(i);
while(n%i == 0) n /= i, d[++m] = i;
}
for(gn = 1; ; ++gn)
{
int i = 0;
for( ; i < p.size()&&qpow(gn, (P-1)/p[i]) != 1; ++i);
if(i == p.size()) break;
}
g[0] = 1; for(int i = 1; i < n; ++i) g[i] = g[i-1]*gn%P;
}
struct Poly
{
vector<int> A;
int& operator [] (const int i) { return A[i]; }
int ti() { return A.size()-1; }
void set(int ti) { A.resize(ti+1); }
}A, B;
void _FFT(Poly &A, int _)
{
if(!A.ti()) return; int n = A.ti()+1, m = d[_];
vector<Poly> B; B.resize(m);
for(int i = 0; i < m; ++i) B[i].set(n/m-1);
for(int i = 0; i <= n-1; ++i) B[i%m][(i-(i%m))/m] = A[i];
for(int i = 0; i < m; ++i) _FFT(B[i], _+1);
for(int i = 0, q = lim/n; i <= n-1; ++i)
{
A[i] = 0;
for(int j = 0, p = 0; j < m; ++j, p = (p+i)%n)
A[i] = (A[i]+g[p*q]*B[j][i%(n/m)]%P)%P;
}
}
void FFT(Poly &A, int t)
{
if(!t) _FFT(A, 1);
else
{
reverse(++A.A.begin(), A.A.end()), _FFT(A, 1); i64 v = qpow(n, P-2);
for(int i = 0; i < lim; ++i) A[i] = A[i]*v%P;
}
}
int main()
{
n = read(), C = read(), A.set(n-1), B.set(n-1), lim = n, P = n+1, prepare();
for(int i = 0; i < n; ++i) A[i] = read();
for(int i = 0; i < n; ++i) B[i] = read();
FFT(A, 0), FFT(B, 0);
for(int i = 0; i < n; ++i) A[i] = A[i]*qpow(B[i], C)%P;
FFT(A, 1);
for(int i = 0; i < n; ++i) printf("%d\n", A[i]);
return 0;
}