FFT学习笔记

重学$FFT$

  在$OI$中,随着时代的进步,我们常常会遇到要求两个多项式卷积的情况。显然,两个多项式卷积时间复杂度是$O(n^2)$,但这个速度远远不能满足我们的需求。所以此时,我们就会用到$FFT$。

FFT是什么

  $FFT$全称快速傅里叶变换。在$OI$中,通常用来在$O(nlogn)$时间里解决两个多项式卷积的问题。

多项式知识

什么是多项式

  和初中课本上的定义一样。

  我们由若干个单项式相加组成的代数式叫做多项式多项式中的每个单项式叫做多项式的项,这些单项式中的最高项次数,就是这个多项式的次数。其中多项式中不含字母的项叫做常数项。 ———摘自百度百科

如何确定一个多项式

  一个多项式有两种表示方法,一种是系数表示法,一种是点值表示法。

系数表示法

  系数表示法是我们最常用的表示方法。它定义$A(x)=\sum_{i = 0}^{n - 1} a_ix ^ i$。然后我们就可以唯一确定一个多项式。

点值表示法

  对于一个$n$次的多项式,我们发现将$n+1$个不同的数代入这个多项式,然后算出此时多项式的值。也可以确定这个多项式。可以感性理解一下,类比如何确定一个二次函数。

多项式间的运算

  多项式之间的加减法非常简单,就是对应次数的系数相加即可。都可以在$O(n)$时间内完成。

多项式卷积

  你可以把两个多项式卷积理解成是两个多项式在做乘法。设两个$n$次多项式$A,B$相乘,得到一个$2n$次多项式$C$,我们可以知道:

    $C(x)=\sum_{k=0}^{2n}(\sum_{k=i+j}A_iB_j)x^k$

$FFT$的出现

  上文中讲到两个用系数表示的多项式直接卷积,时间复杂度是$O(n^2)$。那么我们怎么优化这个复杂度呢?观察上文中的定义,我们发现当两个用点值表示的多项式同时代入同一个$x$时,对于每一位,结果多项式$C(x)$的点值就等于$A(x)$的点值乘上$B(x)$的点值。此时求出用点值表示的多项式$C$的时间复杂度显然是$O(n)$的。

  所以我们的问题就变成了如何快速实现系数表示法和点值表示法之间的转化。我们发现如果点值如果代入单位根的话就有很多很好的性质。

单位根

  在复平面上,作出单位圆。以原点为起点,单位圆的$n$等分点为终点,作出$n$个向量。将所得的幅角为正且最小的向量对应的复数称为$\omega_n$,也可以叫$n$次单位根。

单位根好的性质

  性质一:$\omega_{2n}^{2k}=\omega_n^k$

  这个感性理解一下就好了,因为这两个向量的终点显然相同。

  性质二:$\omega_{n}^{k+\frac{n}{2}}=-\omega_n^k$

  可以想象一下这两个向量在复平面上关于原点对称。

  性质三:$\omega_{n}^{n}=\omega_{n}^ {0} = 1$

$FFT$的过程

  我们将一个$n$个多项式$A$,按照次数分成奇和偶两组,并将一个$x$提出后扔掉,分别记为$A_1$和$A_2$。举一个例子:对于多项式$A(x)=2x^3+x^2+4x+3$,$A_1(x)=2x^2+4$,$A_2(x)=x+3$,所以我们发现$A(x)=A_1(x^2)+x×A_2(x^2)$。

  设$k<\frac{n}{2}$我们将$\omega_{n}^{k}$代入$A(x)$,可以得到如下式子:

    $A(\omega_{n}^{k}) \\= A_1(\omega_{n}^{2k})+\omega_{n}^{k}×A_2(\omega_{n}^{2k}) \\= A_1(\omega_{\frac{n}{2}}^{k})+\omega_{n}^{k}×A_2(\omega_{\frac{n}{2}}^{k})$

  我们再将$k+\frac{n}{2}$代入式子,可以得到:

   $
A(\omega_{n}^{k + \frac{n}{2}})\\=A_1(\omega_{n}^{2k + n})+\omega_{n}^{k + \frac{n}{2}} × A_2(\omega_{n}^{2k + n}) \\= A_1(\omega_{n}^{2k} × \omega_{n} ^ {n}) - \omega_{n}^{k} × A_2(\omega_{n}^{2k} × \omega_{n}^{n}) \\= A_1(\omega_{n}^{2k}) - \omega_{n}^{k} × A_2(\omega_{n}^{2k}) \\= A_1(\omega_{\frac{n}{2}} ^ {k}) - \omega_{n} ^ {k} × A_2(\omega_{\frac{n}{2}}^{k})$

  观察这两个式子,我们发现当我们将$k$取遍$[0,\frac{n}{2})$时,$k+n$取遍了$[\frac{n}{2}, n)$,所以当我们知道$A_1(x)$和$A_2(x)$在$\omega_{\frac{n}{2}} ^ {0}, \omega_{\frac{n}{2}} ^ {1},\omega_{\frac{n}{2}} ^ {2} \cdots \omega_{\frac{n}{2}} ^ {\frac{n}{2} - 1}$的取值时,我们就可以知道$A(x)$在$\omega_{n} ^ {0}, \omega_{n} ^ {1},\omega_{n} ^ {2} \cdots \omega_{n} ^ {n - 1}$的取值。求出$A(n)$时间复杂度是$O(n)$。

  上面的过程都可以很方便地用递归分治来实现。通过主定理我们知道总的时间复杂度是$O(nlogn)$的。

$FFT$的优化

  如果用递归实现$FFT$的话,常数很大,容易$T$掉。我们发现分治到边界时下标等于原来下标的二进制位翻转。所以我们就可以用迭代实现这个过程。

$FFT$代码

// Author: 23forever
#include <bits/stdc++.h>
#define pb push_back
#define pii pair<int, int>
#define mp make_pair
#define fi first
#define se second
typedef long long LL;
const int MAXN = 4000000;
using namespace std;

typedef vector < double > poly;
typedef complex < double > comp;
typedef vector < comp > vc;

namespace polynomial {

const double PI = acos(-1);

int bitrev[MAXN + 5];

void FFT(vc &a, int opt) {
  int len = a.size();
  for (int i = 0; i < len; ++i) {
    if (i < bitrev[i]) swap(a[i], a[bitrev[i]]);
  }

  for (int i = 2; i <= len; i <<= 1) {
    comp wn(cos(2 * PI / i), opt * sin(2 * PI / i));

    for (int j = 0; j < len; j += i) {
      comp w(1, 0);

      for (int k = j; k < j + i / 2; ++k, w *= wn) {
        comp u = a[k], v = w * a[k + i / 2];
        a[k] = u + v, a[k + i / 2] = u - v;
      }
    }
  }
}

poly operator * (const poly vec_a, const poly vec_b) {
  int len = 1, cnt = 0, tot = vec_a.size() + vec_b.size() - 1;
  while (len < tot) len <<= 1, ++cnt;
  for (int i = 0; i < len; ++i) {
    bitrev[i] = (bitrev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
  }

  vc a(len, 0), b(len, 0);
  for (int i = 0; i < vec_a.size(); ++i) a[i] = vec_a[i];
  for (int i = 0; i < vec_b.size(); ++i) b[i] = vec_b[i];

  FFT(a, 1);
  FFT(b, 1);
  for (int i = 0; i < len; ++i) a[i] *= b[i];

  FFT(a, -1);

  poly ret(tot, 0);
  for (int i = 0; i < tot; ++i) ret[i] = a[i].real() / len;
  return ret;
}

}

int n, m;
poly a, b;

void init() {
  ios::sync_with_stdio(false);
  cin.tie(0);

  cin >> n >> m;
  a.resize(n + 1, 0), b.resize(m + 1, 0);

  for (int i = 0; i <= n; ++i) cin >> a[i];
  for (int i = 0; i <= m; ++i) cin >> b[i];
}

int main() {
#ifdef forever23
  freopen("test.in", "r", stdin);
  //freopen("test.out", "w", stdout);
#endif
  init();

  using namespace polynomial;

  poly c = a * b;
  for (int i = 0; i < c.size(); ++i) cout << int(round(c[i])) << ' ';
  cout << endl;
  return 0;
}

$NTT$代码

#include <bits/stdc++.h>
typedef long long LL;
const int P = 998244353;
const int MAXN = 400000;
const int MAXL = 20;
using namespace std;

int fastPow(int b, int p) {
  int ret = 1;

  while (p) {
    if (p & 1) ret = 1LL * ret * b % P;
    b = 1LL * b * b % P;
    p >>= 1;
  }

  return ret;
}

typedef vector < LL > poly;

namespace polynomial {

const int G = 3; 
LL bitrev[MAXN + 5], wn[MAXL + 5];

void ntt(poly &a) {
  int len = a.size();
  for (int i = 0; i < len; ++i) {
    if (i < bitrev[i]) swap(a[i], a[bitrev[i]]);
  }

  for (int i = 2, d = 1; i <= len; i <<= 1, ++d) {
    for (int j = 0; j < len; j += i) {
      LL w = 1;

      for (int k = j; k < j + i / 2; ++k, w = w * wn[d] % P) {
        LL u = a[k], v = 1LL * w * a[k + i / 2] % P;
        a[k] = (u + v) % P, a[k + i / 2] = (u - v + P) % P;
      }
    }
  }
}

poly operator * (poly a, poly b) {
  int len = 1, cnt = 0, tot = a.size() + b.size() - 1;
  while (len < tot) len <<= 1, ++cnt;
  for (int i = 0; i < len; ++i) {
    bitrev[i] = (bitrev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
  }

  for (int i = 0; i < MAXL; ++i) wn[i] = fastPow(G, (P - 1) >> i);
  a.resize(len, 0), b.resize(len, 0);

  ntt(a), ntt(b);
  for (int i = 0; i < len; ++i) a[i] = a[i] * b[i] % P;
  ntt(a);

  for (int i = 1; i < len / 2; ++i) swap(a[i], a[len - i]);
  LL inv = fastPow(len, P - 2);
  for (int i = 0; i < len; ++i) a[i] = a[i] * inv % P; 
  while (a.size() > tot) a.pop_back();
  return a;
}

}

int n, m;
poly a, b;

void init() {
  cin >> n >> m;

  a.resize(n, 0), b.resize(m, 0);
  for (int i = 0; i < n; ++i) cin >> a[i];
  for (int i = 0; i < m; ++i) cin >> b[i];
}


int main() {
#ifdef forever23
  freopen("test.in", "r", stdin);
#endif
  init();

  using namespace polynomial;
  poly c = a * b;
  for (int i = 0; i < c.size(); ++i) cout << c[i] << ' ';
  cout << endl;

  return 0;
}

本博客所有文章均采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处
本文链接:https://23forever.com/2019/07/17/FFT/