原理
参考:
一小时学会快速傅里叶变换(Fast Fourier Transform)
实现
使用C++内置的complex
展开代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
using namespace std;
const double PI = acos(-1);
const int MAXL = 18, N = 1 << MAXL; // 序列周期大小
struct FastFourierTransform {
complex<double> omega[N], omegaInverse[N];
int range;
// 初始化频率
// 可以只进行一次初始化,只要大于可能出现的最高次幂即可,后面的转换就可以复用初始化
void init (const int& n) {
range = n;
for (int i = 0; i < n; ++i) {
omega[i] = complex<double>(cos(2 * PI / n * i), sin( 2 * PI / n * i )) ;
omegaInverse[i] = std::conj(omega[i]);
}
}
// Cooley-Tukey算法:O(n*logn)
void transform (complex<double> *a, const complex<double> *omega, const int &n) {
for (int i = 0, j = 0; i < n; ++i) {
if (i > j) std::swap (a[i], a[j]);
for(int l = n >> 1; ( j ^= l ) < l; l >>= 1);
}
for (int l = 2; l <= n; l <<= 1) {
int m = l / 2;
for (complex<double> *p = a; p != a + n; p += l) {
for (int i = 0; i < m; ++i) {
complex<double> t = omega[range / l * i] * p[m + i];
p[m + i] = p[i] - t;
p[i] += t;
}
}
}
}
// 时域转频域
void dft (complex<double> *a, const int& n) {
transform(a, omega, n);
}
// 频域转时域
void idft (complex<double> *a, const int& n) {
transform(a, omegaInverse, n);
for (int i = 0; i < n; ++i) a[i] /= n;
}
} fft ;
使用自定义的complex,两者在精度上有细微差别但影响不大
展开代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
using namespace std;
const double PI = acos(-1);
const int MAXL = 18, N = 1 << MAXL; // 序列周期大小
struct Complex {
double r, i ;
Complex ( ) { }
Complex ( double r, double i ) : r ( r ), i ( i ) { }
inline void real ( const double &x ) { r = x ; }
inline double real ( ) { return r ; }
inline void imag(const double &x) { i = x; }
inline double imag() { return i; }
inline Complex operator + ( const Complex& rhs ) const {
return Complex ( r + rhs.r, i + rhs.i ) ;
}
inline Complex operator - ( const Complex& rhs ) const {
return Complex ( r - rhs.r, i - rhs.i ) ;
}
inline Complex operator * ( const Complex& rhs ) const {
return Complex ( r * rhs.r - i * rhs.i, r * rhs.i + i * rhs.r ) ;
}
inline void operator /= ( const double& x ) {
r /= x, i /= x ;
}
inline void operator *= ( const Complex& rhs ) {
*this = Complex ( r * rhs.r - i * rhs.i, r * rhs.i + i * rhs.r ) ;
}
inline void operator += ( const Complex& rhs ) {
r += rhs.r, i += rhs.i ;
}
inline Complex conj ( ) {
return Complex ( r, -i ) ;
}
friend ostream& operator << (ostream& out, const Complex& cx) {
out << "(" << cx.r << ", " << cx.i << ")" << endl;
return out;
}
friend istream& operator >> (istream& in, Complex& cx) {
in >> cx.r >> cx.i;
return in;
}
} ;
struct FastFourierTransform {
Complex omega [N], omegaInverse [N] ;
int range;
// 初始化
void init ( const int& n ) {
range = n;
for ( int i = 0 ; i < n ; ++ i ) {
omega [i] = Complex ( cos ( 2 * PI / n * i), sin ( 2 * PI / n * i ) ) ;
omegaInverse [i] = omega [i].conj ( ) ;
}
}
// 迭代式二分转换
void transform ( Complex *a, const int& n, const Complex* omega ) {
for ( int i = 0, j = 0 ; i < n ; ++ i ) {
if ( i > j ) std :: swap ( a [i], a [j] ) ;
for( int l = n >> 1 ; ( j ^= l ) < l ; l >>= 1 ) ;
}
for ( int l = 2 ; l <= n ; l <<= 1 ) {
int m = l / 2;
for ( Complex *p = a ; p != a + n ; p += l ) {
for ( int i = 0 ; i < m ; ++ i ) {
Complex t = omega [range / l * i] * p [m + i] ;
p [m + i] = p [i] - t ;
p [i] += t ;
}
}
}
}
// 时域转频域
void dft ( Complex *a, const int& n ) {
transform ( a, n, omega ) ;
}
// 频域转时域
void idft ( Complex *a, const int& n ) {
transform ( a, n, omegaInverse ) ;
for ( int i = 0 ; i < n ; ++ i ) a [i] /= n ;
}
} fft ;
应用
加速多项式计算O(nlogn)
展开代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34int main() {
// 应用到加速多项式乘法
complex<double> a[500], b[500];
int n, m;
cin >> n >> m;
for (int i = 0; i < n; i++) cin >> a[i];
for (int i = 0; i < m; i++) cin >> b[i];
// a*b之后的真实系数
int r[20] = {0};
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
r[i + j] += a[i].real() * b[j].real();
cout << "真实系数:";
for (int i = 0; i < n + m - 1; i++)
cout << r[i] << " ";
cout << endl;
// 利用fft计算得出的系数
int p = 1;
while (p < n + m - 1) p <<= 1; // 只要p大于多项式结果中的最大次幂即可
fft.init(p);
fft.dft(a, p);
fft.dft(b, p);
complex<double> c[500];
for (int i = 0; i < p; i++) {
c[i] = a[i] * b[i];
}
fft.idft(c, p);
// 考虑到精度问题取实部+0.5向下取整
cout << "fft计算得出的系数:";
for (int i = 0; i < n + m - 1; i++) cout << int(c[i].real() + 0.5) << " ";
return 0;
}输入一组数据得到输出结果如下