多点求值
问题:给定一个 \(n-1\) 次多项式 \(f(x)\),求在 \(a_0,a_2,...,a_{m-1}\) 处分别求得的点值。
\(n,m\le 10^5\)
首先我们先钦定 \(n=m\),否则也可以适当补,下文中用 \(n\) 来代替 \(m\)。
设 \(F=[f_0,f_1,...,f_{n-1}]\),\(A = \begin{bmatrix} a_0^0 & a_0^1 & ... & a_0^{n - 1} \\ a_1^0 & a_1^1 & ... & a_1^{n - 1} \\ ... & ... & ... & ... \\ a_{n - 1}^0 & a_{n - 1} ^ 1 & ... & a_{n - 1} ^ {n - 1}\end{bmatrix}\)。
那么我们要求的答案为 \(Y_i = \sum_{j = 0} ^ {n - 1} a_i^j f_j\),即 \(Y = AF\)。
考虑转置原理,我们考虑求 \(Y'=A^T F\),有 \(Y_i' = \sum_{j = 0} ^ {n - 1} a_j^i f_j\)。
考虑 \(Y_{0..n - 1}'\) 的生成函数:
\[\begin{aligned} \sum_{i = 0} x^i\sum_{j = 0} ^ {n - 1} a_j^i f_j &= \sum_{j=0} ^{n - 1} f_j\sum_{i = 0} x^ia_j^i \\ &= \sum_{j = 0} ^{n - 1} f_j\sum_{i = 0} (a_jx)^i \\ &= \sum_{j = 0} ^{n - 1} f_j \cdot \dfrac 1 {1 - a_jx} \\ \end{aligned}\]我们很容易把这个式子写成 \(\dfrac {P(x)}{Q(x)}\) 的形式。
具体的,我们分治计算,对于当前区间 \([l,r]\),合并左右 \(\dfrac{P_0(x)}{Q_0(x)}\) 和 \(\dfrac{P_1(x)}{Q_1(x)}\) 的结果为:
\[\dfrac {P_0(x)Q_1(x) + P_1(x)Q_0(x)}{Q_0(x)Q_1(x)} \]即 \(P(x) = P_0(x)Q_1(x) + P_1(x)Q_0(x), \ \ Q(x) = Q_0(x)Q_1(x)\)。
对于叶子 \(i\),\(P(x) = f_i, \ \ Q(x) = 1 - a_ix\)。
接下来我们需要把线性变换转置并倒着计算,就能得到答案。
注意到 \(Q(x)\) 和 \(F\) 无关,看作常量,我们无需关心。
对于 \(P(x)\) 我们把它看作 \(F\) 进行一些线性变换,以及一些相加后的结果。
不妨假设叶子 \(i\) 的线性变换为 \(t[i]\),其中 \(t[i]_{0,i} = 1\),其余位置为 \(0\),容易发现叶子的 \(P(x)\) 为 \(t[i]F\)。
只有叶子有贡献,那么我们最终的 \(Y'\) 就是所有叶子到根路径的所有矩阵的乘积对 \(F\) 施加变换后得到的向量的求和,最后再用总的 \(\dfrac 1{Q_{rt}(x)}\) 变换。
这里,我们的变换矩阵都是 \(Q(x)\),即为多项式乘法,转置后就变成了差卷积。
对于 \(t[i]\),转置后的 \(t[i]^T\) 为:只有 \(t[i]_{i,0}=1\),其余位置为 \(0\)。
至于倒着计算,我们先对 \(F\) 先与 \(\dfrac 1{Q_{rt}(x)}\) 进行差卷积。
然后分治,对于分治区间 \([l,r]\),我们传递进一个多项式,然后对左边差卷积 \(Q_1(x)\),对右边差卷积 \(Q_0(x)\),递归计算。
最后每个叶子都有一个 \(t[i]^T\) 的线性变换,注意到只有 \(t[i]_{i,0}=1\),意味这我们只能用该叶子传递进的多项式的 \(0\) 次项对 \(Y_i\) 做贡献。
换句话说:\(Y_i\) 就是该多项式的 \(0\) 次项。
至此,我们以 \(O(n\log^2n)\) 的复杂度完成了多项式多点求值操作。
点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define pir pair <ll, ll>
#define mkp make_pair
#define fi first
#define se second
#define pb push_back
using namespace std;
const ll maxn = 64010, mod = 998244353;
ll power(ll a, ll b = mod - 2) {
ll s = 1;
while(b) {
if(b & 1) s = s * a %mod;
a = a * a %mod; b >>= 1;
} return s;
}
ll pls(const ll x, const ll y) {return (x + y >= mod? x + y - mod : x + y);}
ll mus(const ll x, const ll y) {return (x < y? x + mod - y : x - y);}
void add(ll &x, const ll y) {x = (x + y >= mod? x + y - mod : x + y);}
void sub(ll &x, const ll y) {x = (x < y? x + mod - y : x - y);}
#define Poly vector <ll>
namespace Polygon {
ll rev[maxn << 2], tr;
void Getrev(ll n) {
if(tr == n) return;
tr = n;
for(ll i = 1; i < n; i++)
rev[i] = (rev[i >> 1] >> 1) | (i & 1? n >> 1 : 0);
}
void ntt(ll *a, ll n) {
Getrev(n);
for(ll i = 1; i < n; i++)
if(i < rev[i]) swap(a[i], a[rev[i]]);
for(ll i = 1; i < n; i <<= 1) {
ll g = power(3, (mod - 1) / (i << 1));
for(ll j = 0; j < n; j += (i << 1)) {
for(ll k = 0, t = 1; k < i; k++, t = t * g %mod) {
ll x = a[j|k], y = a[i|j|k] * t %mod;
a[j|k] = pls(x, y), a[i|j|k] = mus(x, y);
}
}
}
}
ll a[maxn << 2], b[maxn << 2];
void Mul(const Poly &A, const Poly &B, Poly &C) {
ll n = A.size(), m = B.size(), k = n + m - 1;
for(ll i = 0; i < n; i++) a[i] = A[i];
for(ll i = 0; i < m; i++) b[i] = B[i];
ll l = 1;
while(l < k) l <<= 1;
ntt(a, l), ntt(b, l);
for(ll i = 0; i < l; i++) a[i] = a[i] * b[i] %mod;
ntt(a, l);
reverse(a + 1, a + l);
ll Inv = power(l);
C.resize(k);
for(ll i = 0; i < l; i++) {
if(i < k) C[i] = a[i] * Inv %mod;
a[i] = b[i] = 0;
}
}
Poly Mul(const Poly &A, const Poly &B) {
Poly C;
Mul(A, B, C);
return C;
}
ll o[maxn], tt;
Poly tmp, fw;
void Getinv(const Poly &A, Poly &B) {
// puts("need inv");
// for(ll i = 0; i < A.size(); i++) printf("%lld ", A[i]);
// puts("");
ll n = A.size();
while(n > 1) {
o[++tt] = n;
n = (n + 1) >> 1;
}
tmp.resize(1);
tmp[0] = power(A[0]);
o[tt + 1] = 1;
for(ll i = tt; i; i--) {
fw = tmp;
Mul(fw, fw, fw);
Poly _A; _A.resize(o[i]);
for(ll j = 0; j < o[i]; j++) _A[j] = A[j];
Mul(fw, _A, fw);
fw.resize(o[i]);
tmp.resize(o[i]);
for(ll j = 0; j < o[i]; j++)
tmp[j] = mus(pls(tmp[j], tmp[j]), fw[j]);
}
B = tmp;
// puts("inv result");
// for(ll i = 0; i < B.size(); i++)
// printf("%lld ", B[i]); puts("");
}
Poly Getinv(const Poly &A) {
Poly B;
Getinv(A, B);
return B;
}
void Mus_Mul(const Poly &A, Poly B, Poly &C) {
reverse(B.begin(), B.end());
Mul(A, B, C);
ll n = A.size(), m = B.size();
for(ll i = 0; i < n; i++)
C[i] = C[i + m - 1];
C.resize(A.size());
}
Poly Mus_Mul(const Poly &A, const Poly &B) {
Poly C;
Mus_Mul(A, B, C);
return C;
}
};
using Polygon::Mul;
using Polygon::Mus_Mul;
using Polygon::Getinv;
ll n, m, a[maxn];
Poly Q[maxn << 4], f;
void calc_Q(ll p, ll l, ll r) {
if(l == r) {
Q[p].resize(2);
Q[p][0] = 1;
Q[p][1] = mod - a[l];
return;
}
ll mid = l + r >> 1;
calc_Q(p << 1, l, mid);
calc_Q(p << 1|1, mid + 1, r);
Mul(Q[p << 1], Q[p << 1|1], Q[p]);
}
void solve(ll p, ll l, ll r, Poly ret) {
ret.resize(r - l + 1);
if(l == r) {
if(l <= m) printf("%lld\n", ret[0]);
return;
}
ll mid = l + r >> 1;
solve(p << 1, l, mid, Mus_Mul(ret, Q[p << 1|1]));
solve(p << 1|1, mid + 1, r, Mus_Mul(ret, Q[p << 1]));
}
int main() {
scanf("%lld%lld", &n, &m);
++n;
f.resize(n);
for(ll i = 0; i < n; i++) {
scanf("%lld", &f[i]);
}
for(ll i = 1; i <= m; i++) {
scanf("%lld", a + i);
}
if(n < m) f.resize(m), n = m;
calc_Q(1, 1, n);
solve(1, 1, n, Mus_Mul(f, Getinv(Q[1])));
return 0;
}