比较安全的模板：传入的数组 \(g\) 即使有初值也没有问题，且求解过程中不会修改传入的 \(f\)。
#include <bits/stdc++.h>
using namespace std;
// Capacity of every global polynomial buffer (f, g, rev, iv and the scratch
// arrays). inv/sqrt/ln/exp run NTTs at lengths up to the smallest power of two
// >= 2n, so a series of length n fits only when n <= N / 2. 1 << 18 therefore
// supports n up to 131072; 1 << 17 overflowed the buffers for n > 65536.
const int N = 1 << 18;
// NTT-friendly prime: mod - 1 = 119 * 2^23, and ntt() uses 3 as the root.
const int mod = 998244353;
// Marker paired with mem2 (declared after the large globals); main prints the
// address distance between the two as a rough static-memory estimate.
bool mem1;
// Fast stdin reader: getchar() hands out bytes from buf, refilling it with up
// to IO_BUF_SIZE bytes via fread whenever it runs dry, and yields EOF once
// fread returns 0. The array is sized to exactly the chunk fread fills.
// Note: this macro replaces the standard getchar() for the rest of the file.
const int IO_BUF_SIZE = 1 << 21;
char buf[IO_BUF_SIZE], *p1 = buf, *p2 = buf;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, IO_BUF_SIZE, stdin), p1 == p2) ? EOF : *p1 ++)
// Reads one signed decimal integer using getchar(). Non-digit characters
// before the number are skipped; a '-' anywhere among them makes the result
// negative. Returns 0 if end of input is reached before any digit.
// (The previous version looped forever on EOF.) ch is kept as an int so EOF
// is distinguishable, and digits are range-checked directly instead of
// passing a possibly negative char to isdigit.
int read() {
    int s = 0, w = 1;
    int ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == EOF) return 0;
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}
// mul(x1, ..., xk): product of all arguments reduced modulo `mod`.
// Base case: a lone argument is returned as-is (converted to int, not reduced).
template <typename A> int mul(A x) { return x; }
// Recursive case: fold the tail first, then multiply in 64-bit arithmetic so
// the intermediate product of two residues cannot overflow.
template <typename A, typename ...B> int mul(A x, B ...args) {
    const long long tail = mul(args ...);
    return x * tail % mod;
}
// In-place modular addition: a <- (a + b) mod `mod`, for a, b in [0, mod).
// When the sum would reach mod, subtract (mod - b) instead of adding b.
void inc(int &a, int b) {
    if(a >= mod - b) a -= mod - b;
    else a += b;
}
// Modular exponentiation: returns a^b mod `mod` by repeated squaring.
// Returns 1 when b <= 0.
int ksm(int a, int b) {
    int acc = 1;
    for(int e = b, sq = a; e > 0; e >>= 1) {
        if(e & 1) acc = mul(acc, sq);
        sq = mul(sq, sq);
    }
    return acc;
}
// Input series (f) and output series (g) used by main.
int f[N];
int g[N];
// Bit-reversal permutation consumed by ntt(); rebuilt before each transform.
int rev[N];
// iv[i]: modular inverse of i (filled in main, used by inter()).
int iv[N];
// Modular inverses of 2 (used by sqrt) and of 3 (root for the inverse NTT).
int inv2 = ksm(2, mod - 2);
int inv3 = ksm(3, mod - 2);
// In-place iterative number-theoretic transform of f[0..len), len a power of two.
// op == 1 uses root 3 (forward); any other op uses inv3 (inverse direction),
// and op == -1 additionally scales every entry by len^{-1}.
// Requires rev[0..len) to hold the bit-reversal permutation for len.
void ntt(int *f, int op, int len) {
    for(int i = 0; i < len; ++ i) {
        if(i < rev[i]) swap(f[i], f[rev[i]]);
    }
    const int root = (op == 1) ? 3 : inv3;
    for(int step = 2; step <= len; step <<= 1) {
        const int half = step >> 1;
        const int wn = ksm(root, (mod - 1) / step);
        for(int blk = 0; blk < len; blk += step) {
            int w = 1;
            for(int k = 0; k < half; ++ k, w = mul(w, wn)) {
                int *lo = f + blk + k, *hi = lo + half;
                const int u = *lo, v = mul(w, *hi);
                *lo = (u + v) % mod;
                *hi = (u - v + mod) % mod;
            }
        }
    }
    if(op != -1) return;
    const int scale = ksm(len, mod - 2);
    for(int i = 0; i < len; ++ i) f[i] = mul(f[i], scale);
}
// Scratch buffers for inv(); must be all-zero on entry (globals start zeroed
// and inv() clears the part it used before returning).
int f_in[N], g_in[N];
// Power-series inverse by Newton iteration: writes g[0..n) such that
// f * g == 1 (mod x^n). Needs f[0] != 0 modulo mod.
// f is only read, at indices up to L/2 - 1 (may exceed n - 1; those extra
// entries do not affect g[0..n)). Initial contents of g do not matter.
// On return g[n..L) is zero, where L is the smallest power of two >= 2n;
// g[L..) is untouched. NTTs run at lengths up to L, and rev[] is overwritten.
void inv(int *f, int *g, int n) {
int len;
g[0] = ksm(f[0], mod - 2);
// Invariant: entering the iteration for `len`, g is the inverse mod x^(len/2)
// with g[len/2..len) zero; after it, g is the inverse mod x^len.
for(len = 1; len < (n << 1); len <<= 1) {
int lim = (len << 1);
for(int i = 0; i < len; ++ i) g_in[i] = g[i], f_in[i] = f[i];
// Bit-reversal permutation for transform length lim = 2 * len.
for(int i = 0; i < lim; ++ i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * len);
ntt(f_in, 1, lim), ntt(g_in, 1, lim);
// Newton step g <- g * (2 - f * g), done pointwise; the true product has
// degree < lim, so the length-lim cyclic convolution does not wrap.
for(int i = 0; i < lim; ++ i)
f_in[i] = mul(g_in[i], (2ll - mul(f_in[i], g_in[i]) + mod) % mod);
ntt(f_in, -1, lim);
for(int i = 0; i < lim; ++ i) g[i] = f_in[i];
// Keep only the len coefficients that are now correct.
for(int i = len; i < lim; ++ i) g[i] = 0;
}
for(int i = n; i < len; ++ i) g[i] = 0;
// Restore the scratch buffers to all-zero for the next call.
for(int i = 0; i < len; ++ i) f_in[i] = g_in[i] = 0;
}
// Scratch buffers for sqrt(); must be all-zero on entry (globals start zeroed
// and sqrt() clears the part it used before returning).
int f_sq[N], g_sq[N];
// Power-series square root by Newton iteration: writes g[0..n) such that
// g * g == f (mod x^n). Starts from g[0] = 1, so f[0] is expected to be 1
// (NOTE(review): not checked). f is only read, at indices up to L/2 - 1;
// initial contents of g do not matter. On return g[n..L) is zero, where L is
// the smallest power of two >= 2n. Calls inv() each step (which uses f_in,
// g_in and rev[]) and runs NTTs at lengths up to L.
void sqrt(int *f, int *g, int n) {
int len;
g[0] = 1;
// Invariant: entering the iteration for `len`, g is correct mod x^(len/2)
// with g[len/2..len) zero; after it, g is correct mod x^len.
for(len = 1; len < (n << 1); len <<= 1) {
int lim = (len << 1);
for(int i = 0; i < len; ++ i) f_sq[i] = f[i];
// g_sq = g^{-1} mod x^len; inv() also zeroes g_sq[len..lim).
inv(g, g_sq, len);
// inv() writes rev[] too; rebuild it for transform length lim.
for(int i = 0; i < lim; ++ i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * len);
ntt(f_sq, 1, lim), ntt(g_sq, 1, lim);
// f_sq = f * g^{-1}
for(int i = 0; i < lim; ++ i) f_sq[i] = mul(f_sq[i], g_sq[i]);
ntt(f_sq, -1, lim);
// Newton step g <- (g + f / g) / 2.
for(int i = 0; i < lim; ++ i) g[i] = mul((f_sq[i] + g[i]) % mod, inv2);
// Keep only the len coefficients that are now correct.
for(int i = len; i < lim; ++ i) g[i] = 0;
}
for(int i = n; i < len; ++ i) g[i] = 0;
// Restore the scratch buffers to all-zero for the next call.
for(int i = 0; i < len; ++ i) f_sq[i] = g_sq[i] = 0;
}
// Formal derivative: g[i] = (i + 1) * f[i + 1] for i in [0, n - 1), and
// g[n - 1] = 0, so g holds n coefficients. Only f[0..n) is read (the previous
// version also read f[n] and then discarded it). n <= 0 is a no-op (the
// previous version wrote g[-1]). Safe in place (g == f): each f[i + 1] is read
// before index i + 1 is written.
void deriv(int *f, int *g, int n) {
    if(n <= 0) return;
    for(int i = 0; i + 1 < n; ++ i)
        g[i] = mul(i + 1, f[i + 1]);
    g[n - 1] = 0;
}
// Formal integral with zero constant term: g[i] = f[i - 1] * iv[i] for
// i in [1, n), and g[0] = 0. Walks from the highest index down so that an
// in-place call (g == f) reads each f[i - 1] before overwriting it.
void inter(int *f, int *g, int n) {
    for(int i = n - 1; i >= 1; -- i)
        g[i] = mul(f[i - 1], iv[i]);
    g[0] = 0;
}
// Scratch buffers for ln(); must be all-zero on entry (globals start zeroed
// and ln() clears the part it used before returning).
int f_ln[N], g_ln[N];
// Power-series logarithm: writes g[0..n) = integral(f' / f) mod x^n with
// constant term 0, which equals ln(f) when f[0] == 1. f is only read; only
// g[0..n) is written. Overwrites rev[] (also via inv()).
void ln(int *f, int *g, int n) {
    deriv(f, f_ln, n);
    inv(f, g_ln, n);
    // Smallest power of two exceeding 2n - 2. f' has degree <= n - 2 and
    // f^{-1} has degree <= n - 1, so their product (degree <= 2n - 3) does
    // not wrap around in a cyclic convolution of this length.
    int lim = 1;
    while(lim <= 2 * n - 2) lim <<= 1;
    // Bit-reversal permutation for length lim. Multiplying by lim / 2 keeps
    // lim == 1 (n == 1) well-defined; the former 1 << (bit - 1) shifted by -1
    // there, which is undefined behavior.
    for(int i = 0; i < lim; ++ i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
    ntt(f_ln, 1, lim), ntt(g_ln, 1, lim);
    for(int i = 0; i < lim; ++ i)
        f_ln[i] = mul(f_ln[i], g_ln[i]);
    ntt(f_ln, -1, lim);
    for(int i = 0; i < n; ++ i) g[i] = f_ln[i];
    // Restore the scratch buffers to all-zero for the next call.
    for(int i = 0; i < lim; ++ i) f_ln[i] = g_ln[i] = 0;
    inter(g, g, n);
}
// Scratch buffers for exp(); must be all-zero on entry (globals start zeroed
// and exp() clears the part it used before returning).
int f_ex[N], g_ex[N];
// Power-series exponential by Newton iteration: writes g[0..n) with
// g == exp(f) (mod x^n). Starts from g[0] = 1, so f[0] is expected to be 0
// (NOTE(review): not checked). f is only read, at indices up to L/2 - 1 (may
// exceed n - 1; those entries do not affect g[0..n)). Initial contents of g
// do not matter. On return g[n..L) is zero, where L is the smallest power of
// two >= 2n. Calls ln() (and thus inv()) every step and runs NTTs at lengths
// up to L, so every global buffer needs at least L entries.
void exp(int *f, int *g, int n) {
int len;
g[0] = 1;
// Invariant: entering the iteration for `len`, g is correct mod x^(len/2)
// with g[len/2..len) zero; after it, g is correct mod x^len.
for(len = 1; len < (n << 1); len <<= 1) {
int lim = (len << 1);
// g_ex[0..len) = ln(g) mod x^len.
ln(g, g_ex, len);
// f_ex = g and g_ex = f - ln(g), then + 1 below: g_ex = 1 - ln(g) + f.
for(int i = 0; i < len; ++ i)
f_ex[i] = g[i], g_ex[i] = (f[i] - g_ex[i] + mod) % mod;
++ g_ex[0];
// ln() overwrote rev[]; rebuild it for transform length lim.
for(int i = 0; i < lim; ++ i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * len);
ntt(f_ex, 1, lim), ntt(g_ex, 1, lim);
// Newton step g <- g * (1 - ln(g) + f).
for(int i = 0; i < lim; ++ i) f_ex[i] = mul(f_ex[i], g_ex[i]);
ntt(f_ex, -1, lim);
for(int i = 0; i < lim; ++ i) g[i] = f_ex[i];
// Keep only the len coefficients that are now correct.
for(int i = len; i < lim; ++ i) g[i] = 0;
}
for(int i = n; i < len; ++ i) g[i] = 0;
// Restore the scratch buffers to all-zero for the next call.
for(int i = 0; i < len; ++ i) f_ex[i] = g_ex[i] = 0;
}
int n, m;
// Marker paired with mem1: declared after the large globals so that the
// address distance between the two roughly measures static memory use.
bool mem2;
// Reads n and the coefficients f[0..n), then prints the first n coefficients
// of exp(f), each followed by a space.
signed main() {
    // Debug aid: approximate static memory in MiB. Addresses are compared as
    // integers because subtracting pointers to unrelated objects is undefined
    // behavior; global layout is compiler-dependent, so this is only a hint.
    cerr << (reinterpret_cast<intptr_t>(&mem2) - reinterpret_cast<intptr_t>(&mem1)) / 1048576. << endl;
    // iv[i] = i^{-1} mod `mod` via i^{-1} = -(mod / i) * (mod % i)^{-1}.
    iv[0] = iv[1] = 1;
    for(int i = 2; i < N; ++ i)
        iv[i] = mul(mod - mod / i, iv[mod % i]);
    n = read();
    // exp() and the routines it calls transform at lengths up to the smallest
    // power of two >= 2n, which must fit in the N-sized global buffers.
    if(n < 0 || n > N / 2) {
        fprintf(stderr, "n = %d is out of range [0, %d]\n", n, N / 2);
        return 1;
    }
    for(int i = 0; i < n; ++ i) f[i] = read();
    exp(f, g, n);
    for(int i = 0; i < n; ++ i) printf("%d ", g[i]);
    return 0;
}
标签:int,多项式,全家,len,++,lim,mul,mod
From: https://www.cnblogs.com/wyb-sen/p/17662792.html