现在有两个整数多项式 \(A\),\(B\),\(0 \le a_i,b_i \le 10^9\), \(n \le 10^5\),求它们的卷积,同时系数对 \(p \le 10^9\) 取模。
我们会发现,最终的系数可能会达到 \(10^5 \times 10^9 \times 10^9=10^{23}\) 级别,FFT 会爆 long double 的精度;NTT,由于模数无特殊性质, 完全不能使用。
接下来介绍三种方法解决任意模数的卷积。前两种又称 \(MTT\)。
传统派——三模 NTT
虽然 NTT 受模数限制系数范围较小,但如果对几个不同的模数同时做 NTT,值域就能乘起来。
具体的,我们选三个 NTT 模数 \(A\),\(B\),\(C\),应有 \(ABC > 10^{23}\),然后做三次卷积,对于每一个系数,得到
\[\left \{ \begin{aligned} x \equiv x_1 \pmod A \\ x \equiv x_2 \pmod B \\ x \equiv x_3 \pmod C \\ \end{aligned} \right. \]这当然能用中国剩余定理求解,考虑迭代法。
令 \(k_1A+x_1=x\),
\(k_1A+x_1 \equiv x_2 \pmod B\)
\(k_1 \displaystyle\equiv \frac{(x_2-x_1)}{A} \pmod B\)
所以得到 \(x \equiv k_1A+x_1 \pmod {AB}\)
重复上述过程与 \(C\) 合并,即可得到 \(x \bmod ABC\),在最后一次合并时各项对 \(p\) 取模即可。
接下来我们考虑三个模数的选择,可以是 \(469762049=7 \times 2^{26}+1\),\(998244353=119 \times 2^{23}+1\),\(1004535809 = 479 \times 2^{21}+1\),它们原根都是 \(3\)(这玩意有点难记,建议记分解形式)。
三模 NTT 需要做九次 NTT 运算,而且一堆取模,常数较大。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1 << 22;
int n, m, len, rev[N], a[3][N], b[3][N];
int x[3][N];
int pow_mod(int a, int b, int p) {
int res = 1;
a %= p;
while (b) {
if (b & 1) res = res * a % p;
a = a * a % p;
b >>= 1;
}
return res;
}
const int p[3] = {7 * (1 << 26) + 1, 119 * (1 << 23) + 1, 479 * (1 << 21) + 1};
const int g = 3, gi[3] = {pow_mod(3, p[0] - 2, p[0]), pow_mod(3, p[1] - 2, p[1]), pow_mod(3, p[2] - 2, p[2])};
void init() {
for (int i = 0; i < len; i++) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) rev[i] |= len >> 1;
}
}
void NTT(int *a, bool t, int now) {
const int mod = p[now];
for (int i = 0; i < len; i++) {
if (rev[i] > i) swap(a[rev[i]], a[i]);
}
for (int i = 1; i < len; i <<= 1) {
int wn = pow_mod(t ? g : gi[now], (mod - 1) / (i << 1), mod);
for (int j = 0; j < len; j += (i << 1)) {
int wk = 1;
for (int k = 0; k < i; k++, wk = wk * wn % mod) {
int x = a[j + k], y = wk * a[i + j + k] % mod;
// if (now == 1) cout << wk << ' ' << now << endl;
a[j + k] = (x + y) % mod;
a[i + j + k] = (x - y + mod) % mod;
}
}
}
int inv = pow_mod(len, mod - 2, mod);
if (!t) {
for (int i = 0; i < len; i++) a[i] = a[i] * inv % mod;
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int P;
cin >> n >> m >> P;
for (int i = 0; i <= n; i++) cin >> a[0][i], a[1][i] = a[2][i] = a[0][i];
for (int i = 0; i <= m; i++) cin >> b[0][i], b[1][i] = b[2][i] = b[0][i];
len = 1ll << max(1ll, (int)(ceil(log2(n + m))));
init();
for (int i = 0; i < 3; i++) {
NTT(a[i], 1, i);
NTT(b[i], 1, i);
for (int j = 0; j < len; j++) x[i][j] = a[i][j] * b[i][j] % p[i];
NTT(x[i], 0, i);
}
for (int i = 0; i < n + m + 1; i++) {
int k1 = ((x[1][i] - x[0][i] % p[1]) + p[1]) % p[1] * pow_mod(p[0], p[1] - 2, p[1]) % p[1];
x[0][i] += k1 * p[0];
x[0][i] %= p[0] * p[1];
int k2 = ((x[2][i] - x[0][i]) % p[2] + p[2]) % p[2] * pow_mod(p[0] * p[1], p[2] - 2, p[2]) % p[2];
x[0][i] = (k2 % P * (p[0] * p[1] % P) + x[0][i] % P) % P;
cout << x[0][i] << ' ';
}
return 0;
}
速度派——拆系数 FFT
既然一次乘法会溢出,那么分多次乘法加上去就行了。
我们选一个在 \(\sqrt{p}\) 周围的数 \(M\),将多项式系数拆成 \(CM+D\) 的形式(其中 \(C,D\) 为拆出来的多项式),就有 \(AB=(C_AM+D_A)(C_BM+D_B)=C_AC_BM^2+(D_AC_B+C_AD_B)M+D_AD_B\)。对每一项都进行卷积,然后相加,每一次卷积的值域应该在 \(M^2N=10^{14}\) 左右,可以大力 FFT。
但是如果硬乘,我们要算 \(4 \times3=12\) 次 FFT,考虑到本质上只有四个多项式,我们预处理出它们的 DFT,然后在乘,优化到 \(7\) 次 FFT运算,其中 \(4\) 次 DFT对 \(C_A,C_B,D_A,D_B\),\(3\) 次 IDFT 对 \(C_AC_B,D_AC_B+C_AD_B,D_AD_B\)。
还是不够快?我们就要用到一些仙术了,合并DFT,可以通过一次 DFT 运算求解两个多项式的 DFT。
合并 DFT
以下的 \(i\) 均为虚数单位。
我们设多项式 \(P(x)=A(x)+iB(x)\),\(Q(x)=A(x)-iB(x)\)。
如果求解出 \({DFT(P)},{DFT(Q)}\),那么就能得到
\(DFT(A)=(DFT(P)+DFT(Q))/2\)
\(DFT(B)=(DFT(P)-DFT(Q))/2i\)
考虑 \(P(w_n^k)=A(w_n^k)+iB(w_n^k)=\displaystyle\sum_{j=0}^{n-1}(a_j+ib_j)w_n^{jk}\)
定理:若 \(a_1,b_1\) 为共轭复数,\(a_2,b_2\) 为共轭复数。则 \(a_1a_2,b_1b_2\) 也为共轭复数(共轭复数指实部相同,虚部相反,记为 \(conj\))。
这个是易证的,乘出来就行了。
有了这个,我们知道 \(a_j+ib_j\),\(a_j-ib_j\) 共轭,\(w_n^k\),\(w_n^{-k}\)共轭,那么就有 \(P(w_n^k)\) 与 \(Q(w_n^{-k})\) 共轭,换句话说, \(conj(DFT(P)[k])=DFT(Q)[n-k]\))。那么我们只要做一次 \(DFT\) 就可以了。
对于 \(IDFT\),有一个很奇怪神的做法,因为 \(IDFT\) 后得到的系数肯定是实数,如果乘个 \(i\),就全是虚数,那么如果把两个点值表示一个放实数,一个放虚数,就可以一次 \(IDFT\) 算出两个多项式的系数了,具体证明我也不会,直观理解吧。
这样最后就是 4 次 \(FFT\) 了,不用常数优化就能有较优秀的复杂度。
精度处理:使用long double;要多取模
#include <bits/stdc++.h>
#define int long long
#define double long double
using namespace std;
struct Complex{
double x, y;
Complex operator + (const Complex &b) const {return {x + b.x, y + b.y};}
Complex operator - (const Complex &b) const {return {x - b.x, y - b.y};}
Complex operator * (const Complex &b) const {return {x * b.x - y * b.y, x * b.y + y * b.x};}
Complex operator / (double k) const {return {x / k, y / k};}
Complex conj() {return {x, -y};}
Complex () {}
Complex(double a, double b) {x = a, y = b;}
};
const int N = 4e5 + 5, M = 1 << 15;
const double pi = acos(-1);
int a[N], b[N], rev[N], n, m, len, mod;
Complex pre[N];
int add(int x, int y) {
return (x + y) % mod;
}
int mult(int x, int y) {
return x * y % mod;
}
void init() {
for (int i = 0; i < len; i++) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) rev[i] |= len >> 1;
}
for (int i = 0; i < len; i++) pre[i] = {std::cos(2 * pi * i / len), std::sin(2 * pi * i / len)};
}
void FFT(Complex *a, bool t) {
for (int i = 0; i < len; i++) {
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int i = 1; i < len; i <<= 1) {
for (int j = 0; j < len; j += (i << 1)) {
for (int k = 0; k < i; k++) {
Complex x = a[j + k], y = a[i + j + k] * (t ? pre[k * len / (2 * i)] : pre[k * len / (2 * i)].conj());
a[j + k] = x + y, a[i + j + k] = x - y;
}
}
}
if (!t) for (int i = 0; i < len; i++) a[i] = a[i] / len;
}
Complex P1[N], P2[N];
Complex A1[N], A2[N], B1[N], B2[N];
void mult(int* a, int *b) {
for (int i = 0; i < len; i++) P1[i] = {(double)(a[i] / M), (double)(a[i] % M)};
for (int i = 0; i < len; i++) P2[i] = {(double)(b[i] / M), (double)(b[i] % M)};
FFT(P1, 1), FFT(P2, 1);
for (int i = 0; i < len; i++) {
auto Q = (i == 0 ? P1[0] : P1[len - i]).conj();
A1[i] = (P1[i] + Q) / 2;
A2[i] = (Q - P1[i]) * Complex(0, 1) / 2;
}
for (int i = 0; i < len; i++) {
auto Q = (i == 0 ? P2[0] : P2[len - i]).conj();
B1[i] = (P2[i] + Q) / 2;
B2[i] = (Q - P2[i]) * Complex(0, 1) / 2;
}
for (int i = 0; i < len; i++) {
auto tmp1 = A1[i] * B1[i], tmp2 = A2[i] * B1[i] + A1[i] * B2[i], tmp3 = A2[i] * B2[i];
A1[i] = tmp1 + tmp2 * Complex(0, 1);
A2[i] = tmp3;
}
FFT(A1, 0), FFT(A2, 0);
for (int i = 0; i < n + m + 1; i++) {
int a = round(A1[i].x), b = round(A1[i].y), c = round(A2[i].x);
cout << add(add(mult(mult(M, M), a % mod), mult(M, b % mod)), c % mod) << ' ';
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m >> mod;
for (int i = 0; i <= n; i++) cin >> a[i];
for (int i = 0; i <= m; i++) cin >> b[i];
len = 1 << max(1ll, (int)ceil(log2(n + m)));
init();
mult(a, b);
return 0;
}
力大砖飞派——int128
直接选一个大于 \(10^{23}\) 的模数,用 int128 做 NTT。
没错,就这样,但是毒点也挺多的,首先要找大模数,可以暴力枚举 \(r \times 2^k+1\) 找,然后用 Miller_rabin 判断素数,最后要 PR 分解质因数找原根……
挺烦的吧,不过,只要你把模数背出来了,上面问题就一个都没有!!!
但是还有几个问题;int128 龟速取模,以及乘法溢出,这对卡常选手自然是手到擒来,像我这种juruo根本玩不来……
这里提供一个很好记的模数(我用python打出来的):\(1234567*2^{88}+1\),原根为\(3\)。
总结
三种方法各有优劣,三模 NTT 稳定,但慢;拆系数 FFT 快,但会有精度问题;int128 看起来简单,实际上有一堆小细节……不过呢,这三种算法都有一个同样的特性:写起来很烦(笑)。
标签:const,FTT,int,DFT,NTT,模数,Complex From: https://www.cnblogs.com/Uuuuuur/p/18018274