NTT
前置知识:FFT
NTT,中文“快速数论变换”,是 FFT 在数论领域上的实现,比 FFT 更快,应用更广。
对于 FFT,因为其涉及到复数操作,对于某些需要取模的题不再适用。并且因为需要求正弦与余弦,使用时难以避免精度误差。这时就需要用到 NTT 来解决问题了。
我们知道 FFT 的实现是在复平面上找到了 \(n\) 个不同的点 \(\omega^0_n, \omega^1_n, \dots, \omega^{n - 1}_n\)。由于现在需要解决取模的问题,因此我们考虑在剩余系中找 \(n\) 个等价的点。
原根
我们可以使用原根找到这 \(n\) 个数。由于原根通常较小,所以可以暴力求原根。
假设要取模的奇素数为 \(p\),再设 \(p = qn + 1\) 且 \(n \mid (p - 1)\)。
设 \(g\) 为模 \(p\) 意义下的原根,根据费马小定理和欧拉定理得:\(g^{nq}\equiv g^{\varphi(p)}\equiv 1 \pmod p\)。其中 \(\varphi\) 为欧拉函数,由于 \(p\) 是奇数,故 \(\varphi(p) = p - 1\)。
为了验证原根与单位根的相似性质,不妨令 \(x^i_n = g^{iq} = \left(g^\frac{p - 1}{n}\right)^i\),看做和 \(\omega^i_n\) 等价。我们接下来验证 \(x^i_n\) 是否与 \(\omega^i_n\) 等价。
-
乘法:\(\omega^i_n \times \omega^j_n = \omega^{i + j}_n\)。对于 \(x^i_n\) 显然成立。
-
周期性:\(\omega^i_n = \omega^{i + n}_n\)。由于 \(x^{n}_n = g^{nq} \equiv 1 \pmod p\),所以 \(x^{i + n}_n = x^i_n \times x^n_n = x^i_n\),成立。
-
互异性:\(\omega^0_n, \omega^1_n, \dots, \omega^{n - 1}_n\) 互不相同。根据原根的性质可以很容易得到 \(x^0_n, x^1_n, \dots, x^{n - 1}_n\) 不同。
-
消去引理:\(\forall n \in \mathbb{N}, k \in \mathbb{N}, d \in \mathbb{N}^{*}\),有 \(\omega^{dk}_{dn} = \omega^{k}_n\)。证明:\(x^{dk}_{dn} = g^{dk \cdot \frac{p - 1}{dn}} = \left(g^{\frac{p - 1}{n}}\right)^k=x^k_n\)。
-
折半引理:\(\forall n \in \mathbb{N}, k \in \mathbb{N}\) 且 \(n\) 为偶数,有 \((\omega^{k + n/ 2}_n)^2 = \omega^k_{n / 2}\)。首先 \((x^{k + n/ 2}_n)^2 = x^{2k+n}_n\),因为有周期性 \(x^{2k+n}_n = x^{2k}_n\),再由消去引理得到 \(x^{2k}_n = x^{k}_{n/2}\)。
可以发现,因为 \(x^i_n\) 同样满足消去引理与折半引理,所以后面对多项式的推导中 \(x^i_n\) 与 \(\omega^i_n\) 完全等价。
模数
不是所有模数都可以使用 NTT(至少不能用朴素的 NTT)。
在 FFT 中,为了方便使用折半引理,我们强制将 \(n \leftarrow 2^{\left\lceil\log_2 n\right\rceil}\),这样的话我们的模数必须满足 \(p = q2^k + 1\),其中 \(q\) 是奇数,必须有 \(n \le 2^k\)。后面可能会收录一些 NTT 常用模数,平常可以用 \(998244353\)。
INTT
将 \(x^{-i}_n\) 换成 \(\frac{1}{x^i_n}\) 后用逆元即可,除以 \(n\) 时也用逆元。
代码:
template<class T>
T power(T a, i64 b) {
T res = 1;
for (; b; b >>= 1, a *= a) {
if (b & 1)
res *= a;
}
return res;
}
template<const i64 P>
class ModInt {
public:
i64 x;
static i64 Mod;
ModInt() : x{0} {}
ModInt(int _x) : x{(_x % getMod() + getMod()) % getMod()} {}
ModInt(i64 _x) : x{(_x % getMod() + getMod()) % getMod()} {}
static void setMod(i64 Mod_) {
Mod = Mod_;
}
static i64 getMod() {
return !P ? Mod : P;
}
explicit constexpr operator int() const {
return x;
}
ModInt &operator += (ModInt a) & {
x = x + a.x >= getMod() ? x + a.x - getMod() : x + a.x;
return (*this);
}
ModInt &operator -= (ModInt a) & {
x = x - a.x < 0 ? x - a.x + getMod() : x - a.x;
return (*this);
}
ModInt &operator *= (ModInt a) & {
(x *= a.x) %= getMod();
return (*this);
}
constexpr ModInt inv() {
return power((*this), getMod() - 2);
}
ModInt &operator /= (ModInt a) & {
return (*this) *= a.inv();
}
friend ModInt operator + (ModInt lhs, ModInt rhs) {
return lhs += rhs;
}
friend ModInt operator - (ModInt lhs, ModInt rhs) {
return lhs -= rhs;
}
friend ModInt operator * (ModInt lhs, ModInt rhs) {
return lhs *= rhs;
}
friend ModInt operator / (ModInt lhs, ModInt rhs) {
return lhs /= rhs;
}
friend std::istream &operator >> (std::istream &is, ModInt &p) {
return is >> p.x;
}
friend std::ostream &operator << (std::ostream &os, ModInt p) {
return os << p.x;
}
int operator !() {
return !x;
}
friend bool operator == (ModInt lhs, ModInt rhs) {
return lhs.x == rhs.x;
}
friend bool operator != (ModInt lhs, ModInt rhs) {
return lhs.x != rhs.x;
}
ModInt operator -() {
return ModInt(getMod() - x);
}
ModInt &operator ++() & {
++x;
return *this;
}
ModInt operator ++(int) {
ModInt temp = *this;
++*this;
return temp;
}
} ;
template<>
i64 ModInt<0>::Mod = 998244353;
const int P = 167772161, g = 3;
using Z = ModInt<P>;
struct Comb {
int n;
vector<Z> _fac;
vector<Z> _invfac;
vector<Z> _inv;
Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
Comb(int n) : Comb() {init(n);}
void init(int m) {
m = min<int>(m, Z::getMod() - 1);
if (m <= n) return;
_fac.resize(m + 1);
_invfac.resize(m + 1);
_inv.resize(m + 1);
for (int i = n + 1; i <= m; i++) _fac[i] = _fac[i - 1] * i;
_invfac[m] = _fac[m].inv();
for (int i = m; i > n; i--) {
_invfac[i - 1] = _invfac[i] * i;
_inv[i] = _invfac[i] * _fac[i - 1];
} n = m;
}
Z fac(int m) {if (m > n) init(2 * m); return _fac[m];}
Z invfac(int m) {if (m > n) init(2 * m); return _invfac[m];}
Z inv(int m) {if (m > n) init(2 * m); return _inv[m];}
Z binom(int n, int m) {return n < m || m < 0 ? 0 : fac(n) * invfac(m) * invfac(n - m);}
} comb;
std::vector<int> rev;
void ExtendRev(int n) {
int m = rev.size();
rev.resize(n);
int s = __builtin_ctz(n) - 1;
for (int i = 0; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << s);
}
template<const int P, const int g>
struct Poly : public std::vector<ModInt<P>> {
using V = ModInt<P>;
int invg = V(g).inv().x;
Poly() : std::vector<V>() {}
Poly(int n) : std::vector<V>(n, V {}) {}
Poly(int n, V x) : std::vector<V>(n, x) {}
Poly(std::vector<V> _a) : std::vector<V>(_a) {}
Poly(std::initializer_list<V> _a) : std::vector<V>(_a) {}
template<class InputIt, class = std::_RequireInputIter<InputIt>>
Poly(InputIt first, InputIt last) : std::vector<V>(first, last) {}
std::vector<V> trunc(int n) {
auto f = *this;
f.resize(n);
return f;
}
void dft(int Root = g) {
int n = this->size();
ExtendRev(n);
for (int i = 0; i < n; ++i)
if (i < rev[i]) {
std::swap((*this)[i], (*this)[rev[i]]);
}
for (int k = 1; k < n; k <<= 1) {
V wn = power(V(Root), (P - 1) / (2 * k));
for (int i = 0; i < n; i += 2 * k) {
V w = 1;
for (int j = 0; j < k; ++j, w *= wn) {
auto x = (*this)[i + j], y = (*this)[i + j + k] * w;
(*this)[i + j] = x + y;
(*this)[i + j + k] = x - y;
}
}
}
}
void idft() {
dft(invg);
V invn = V(int(this->size())).inv().x;
for (auto &x : *this)
x *= invn;
}
friend Poly operator * (Poly a, V b) {
for (auto &x : a)
x *= b;
return a;
}
friend Poly operator * (Poly a, Poly b) {
int m = a.size() + b.size() - 1, n = a.size() + b.size();
for (; n != (n & -n); ++n) ;
a.resize(n);
b.resize(n);
a.dft();
b.dft();
for (int i = 0; i < n; ++i)
a[i] *= b[i];
a.idft();
return a.trunc(m);
}
friend Poly operator + (Poly a, Poly b) {
if (a.size() < b.size())
std::swap(a, b);
for (int i = 0; i < b.size(); ++i)
a[i] += b[i];
return a;
}
Poly operator -() {
Poly a = *this;
for (int i = 0; i < a.size(); ++i)
a[i] = -a[i];
return a;
}
friend Poly operator - (Poly a, Poly b) {
return a + -b;
}
} ;
using Pol = Poly<P, g>;
标签:return,int,NTT,笔记,学习,Poly,operator,ModInt,omega
From: https://www.cnblogs.com/CTHOOH/p/18187174