总算把之前学多项式时摸鱼欠下的东西还上了一些……
常数应该不算特别大。
点击查看代码
namespace Polys {
// Polynomial toolkit over Z/998244353 built on the number-theoretic transform.
// A polynomial is a std::vector<int> of coefficients in [0, MOD); index i
// holds the coefficient of x^i.
//
// NOTE: Poly, ll, clr and cpy are preprocessor macros. Macros ignore
// namespaces, so they stay defined for the rest of the translation unit;
// they are kept because code after this namespace may rely on them.
#define Poly std::vector <int>
#define ll long long
const int G = 3, MOD = 998244353;  // G: primitive root used by the NTT

// Returns a^b mod MOD by binary exponentiation. With the default exponent
// MOD - 2 this is the modular inverse of a (Fermat's little theorem); the
// result is 0 when a is a multiple of MOD. Assumes b >= 0.
ll power(ll a, ll b = MOD - 2) {
    // Reduce first: an unreduced or negative base could overflow a * a or
    // make the result negative.
    a %= MOD;
    if (a < 0) a += MOD;
    ll ret = 1;
    for (; b; b >>= 1) {
        if (b & 1) ret = ret * a % MOD;
        a = a * a % MOD;
    }
    return ret;
}
const int invG = power(G);  // G^{-1}: root used by the inverse transform

// Bit-reversal table for transforms of length n (a power of two): r[i] is i
// with its log2(n) low bits reversed. Only rebuilt when n changes.
std::vector <int> r;
void initr(int n) {
    if ((int)r.size() == n) return;
    r.assign(n, 0);
    int cnt = 0;
    for (int i = 1; i < n; i <<= 1) cnt++;
    // n == 1 needs no permutation; returning here also avoids shifting by
    // cnt - 1 == -1, which is undefined behaviour.
    if (cnt == 0) return;
    for (int i = 1; i < n; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
    }
}

// Raw int-buffer helpers. Not used inside this namespace; kept in case code
// elsewhere in the file uses them. The size argument is parenthesised so
// that e.g. cpy(a, b, n + 1) copies n + 1 ints, and there is no embedded
// ';' so they behave like ordinary function calls (e.g. inside if/else).
#define clr(a, s) memset(a, 0, sizeof(int) * (s))
#define cpy(a, b, s) memcpy(a, b, sizeof(int) * (s))

// Pointwise product: a[i] = a[i] * b[i] mod MOD for i in [0, n).
void mul(int *a, int *b, int n) {
    for (int i = 0; i < n; i++) {
        a[i] = 1ll * a[i] * b[i] % MOD;
    }
}

// In-place NTT of p[0..n). n must be a power of two and at most 2^23
// (MOD - 1 = 119 * 2^23). type == 1: forward transform with root G.
// type == -1: inverse transform with root invG, including the 1/n scaling.
// Any other type uses invG without the scaling.
void NTT(int *p, int n, int type) {
    initr(n);
    for (int i = 0; i < n; i++) {
        if (i < r[i]) std::swap(p[i], p[r[i]]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        int half = len >> 1;
        // Primitive len-th root of unity (or its inverse).
        int base = power(type == 1 ? G : invG, (MOD - 1) / len);
        for (int i = 0; i < n; i += len) {
            int w = 1;
            for (int j = 0; j < half; j++, w = 1ll * w * base % MOD) {
                int u = p[i + j];
                int v = 1ll * p[i + j + half] * w % MOD;
                int s = u + v, d = u - v;  // both fit in int: |u|, |v| < MOD
                p[i + j] = s >= MOD ? s - MOD : s;
                p[i + j + half] = d < 0 ? d + MOD : d;
            }
        }
    }
    if (type == -1) {
        int invn = power(n);
        for (int i = 0; i < n; i++) p[i] = 1ll * p[i] * invn % MOD;
    }
}

// Coefficient-wise x - y mod MOD. The result has max(x.size(), y.size())
// coefficients; a missing coefficient counts as 0.
Poly operator - (const Poly &x, const Poly &y) {
    Poly ret(x.size() > y.size() ? x.size() : y.size());
    for (int i = 0; i < (int)ret.size(); i++) {
        int cur = (i < (int)x.size() ? x[i] : 0) - (i < (int)y.size() ? y[i] : 0);
        if (cur >= MOD) cur -= MOD;
        if (cur < 0) cur += MOD;
        ret[i] = cur;
    }
    return ret;
}

// Coefficient-wise x + y mod MOD. The result has max(x.size(), y.size())
// coefficients; a missing coefficient counts as 0.
Poly operator + (const Poly &x, const Poly &y) {
    Poly ret(x.size() > y.size() ? x.size() : y.size());
    for (int i = 0; i < (int)ret.size(); i++) {
        int cur = (i < (int)x.size() ? x[i] : 0) + (i < (int)y.size() ? y[i] : 0);
        if (cur >= MOD) cur -= MOD;
        if (cur < 0) cur += MOD;
        ret[i] = cur;
    }
    return ret;
}

// Multiplies every coefficient by the scalar c (any int, reduced mod MOD).
Poly operator * (const Poly &x, int c) {
    c %= MOD;
    if (c < 0) c += MOD;  // keeps results in [0, MOD) for negative scalars
    Poly ret(x.size());
    for (int i = 0; i < (int)ret.size(); i++) {
        ret[i] = 1ll * x[i] * c % MOD;
    }
    return ret;
}

// Full product x * y via NTT. The result has x.size() + y.size() - 1
// coefficients, or is empty if either operand is empty.
Poly operator * (const Poly &x, const Poly &y) {
    if (x.empty() || y.empty()) return Poly();
    int need = (int)x.size() + (int)y.size() - 1;
    int lim = 1;
    while (lim < need) lim <<= 1;
    // Zero-padded working copies sized to the transform length.
    Poly a(lim, 0), b(lim, 0);
    for (int i = 0; i < (int)x.size(); i++) a[i] = x[i];
    for (int i = 0; i < (int)y.size(); i++) b[i] = y[i];
    NTT(&a[0], lim, 1);
    NTT(&b[0], lim, 1);
    mul(&a[0], &b[0], lim);
    NTT(&a[0], lim, -1);
    a.resize(need);
    return a;
}

// Appends to b the first n coefficients of 1 / a (mod x^n). Expects b to be
// empty on the outermost call (Inv passes an empty vector). Requires
// a.size() >= n and a[0] invertible mod MOD.
void Getinv(const Poly &a, Poly &b, int n) {
    if (n <= 0) return;  // nothing to compute; also terminates the recursion
    if (n == 1) {
        b.push_back(power(a[0]));
        return;
    }
    if (n & 1) {
        // Odd length: get n - 1 terms, then solve
        // sum_{i=0}^{n} b[i] * a[n-i] = 0 for the new b[n].
        Getinv(a, b, --n);
        int sum = 0;
        for (int i = 0; i < n; i++) {
            sum += 1ll * b[i] * a[n - i] % MOD;
            if (sum >= MOD) sum -= MOD;
        }
        b.push_back(1ll * sum * power(MOD - a[0]) % MOD);  // * (-1 / a[0])
        return;
    }
    // Even length: Newton step b <- 2b - a * b^2 (mod x^n).
    Getinv(a, b, n >> 1);
    Poly tmp(a.begin(), a.begin() + n);
    b = b * 2 - tmp * b * b;
    b.resize(n);
}

// Returns 1 / x mod x^{x.size()} (empty for empty x). Requires x[0] != 0.
Poly Inv(const Poly &x) {
    Poly ret;
    Getinv(x, ret, (int)x.size());
    return ret;
}

// Formal derivative. The result keeps x.size() coefficients; the last is 0.
Poly Der(const Poly &x) {
    Poly ret(x.size());
    for (int i = 1; i < (int)x.size(); i++) {
        ret[i - 1] = 1ll * i * x[i] % MOD;
    }
    return ret;
}

// invs[i] = i^{-1} mod MOD (invs[0] = 0); grown on demand by initinv.
std::vector <int> invs;
// Ensures invs covers indices 0..n using inv(i) = -(MOD / i) * inv(MOD % i),
// O(1) per new entry (MOD % i < i is always already filled). Assumes n < MOD.
void initinv(int n) {
    if ((int)invs.size() > n) return;
    int cur = invs.size();
    invs.resize(n + 1);
    for (int i = cur; i <= n; i++) {
        if (i <= 1) invs[i] = i;  // 0 has no inverse: store 0; inv(1) = 1
        else invs[i] = 1ll * (MOD - MOD / i) * invs[MOD % i] % MOD;
    }
}

// Formal integral with zero constant term, truncated to x.size()
// coefficients (so x's last coefficient does not contribute).
Poly Inter(const Poly &x) {
    Poly ret(x.size());
    if (ret.empty()) return ret;  // ret[0] below would be out of bounds
    initinv((int)ret.size());
    for (int i = 1; i < (int)ret.size(); i++) {
        ret[i] = 1ll * x[i - 1] * invs[i] % MOD;
    }
    ret[0] = 0;
    return ret;
}

// Returns the integral of x' / x (zero constant term) mod x^{x.size()}. This
// is the formal ln(x) when x[0] == 1. Requires x[0] != 0 for the inverse.
Poly ln(const Poly &x) {
    Poly der = Der(x) * Inv(x);
    der = Inter(der);
    der.resize(x.size());
    return der;
}

// Appends to b the first n coefficients of exp(a) (mod x^n). Expects b to be
// empty on the outermost call (exp passes an empty vector). Requires
// a.size() >= n; assumes a[0] == 0, as the formal exp requires.
void Getexp(const Poly &a, Poly &b, int n) {
    if (n <= 0) return;  // nothing to compute; also terminates the recursion
    if (n == 1) {
        b.push_back(1);
        return;
    }
    if (n & 1) {
        // Odd length: from B' = A' * B,
        // (k + 1) b[k+1] = sum_{i=0}^{k} (i + 1) a[i+1] b[k-i], with k = n - 2.
        Getexp(a, b, n - 1);
        int k = n - 2;
        int sum = 0;
        for (int i = 0; i <= k; i++) {
            sum += 1ll * (i + 1) * a[i + 1] % MOD * b[k - i] % MOD;
            if (sum >= MOD) sum -= MOD;
        }
        initinv(k + 1);
        b.push_back(1ll * sum * invs[k + 1] % MOD);
        return;
    }
    // Even length: Newton step b <- b * (1 - ln b + a) (mod x^n).
    Getexp(a, b, n >> 1);
    Poly lnb = b;
    lnb.resize(n);
    lnb = ln(lnb);
    for (int i = 0; i < (int)lnb.size(); i++) {
        lnb[i] = a[i] - lnb[i];
        if (lnb[i] < 0) lnb[i] += MOD;
    }
    lnb[0]++;
    if (lnb[0] >= MOD) lnb[0] -= MOD;
    b = lnb * b;
    b.resize(n);
}

// Returns exp(x) mod x^{x.size()} (empty for empty x). Assumes x[0] == 0.
Poly exp(const Poly &x) {
    Poly ret;
    Getexp(x, ret, (int)x.size());
    return ret;
}
}