// 以下代码必须开 -O2 (The following code must be compiled with -O2.)
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// debug(...) forwards to fprintf(stderr, ...) only when LOCAL is defined;
// otherwise it expands to a no-op expression.
#if defined(LOCAL)
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
using LL = long long;  // shorthand for 64-bit signed arithmetic
// Residue class modulo P. The stored value v is always canonical, i.e. in [0, P).
template <unsigned P> struct modint {
  unsigned v;  // canonical representative in [0, P)

  modint() : v(0) {}
  // Implicit conversion from an integer; negative inputs wrap around into [0, P).
  template <class T> modint(T x) {
    x %= static_cast<int>(P);
    v = x < 0 ? x + P : x;
  }

  // Canonical value as a plain int.
  friend int raw(const modint &m) { return m.v; }

  // Binary exponentiation: base^exp for a non-negative integral exponent.
  template <class T> friend modint qpow(modint base, T exp) {
    modint acc = 1;
    while (exp) {
      if (exp & 1) acc *= base;
      exp >>= 1;
      base *= base;
    }
    return acc;
  }

  // Multiplicative inverse as v^(P-2) (Fermat); asserts v != 0.
  modint inv() const {
    assert(v);
    return qpow(*this, P - 2);
  }

  modint operator+() const { return *this; }
  modint operator-() const {
    modint zero;
    zero -= *this;
    return zero;
  }

  modint &operator+=(const modint &o) {
    v += o.v;
    if (v >= P) v -= P;
    return *this;
  }
  modint &operator-=(const modint &o) {
    v -= o.v;
    if (v >= P) v += P;  // an unsigned underflow lands above P; add P back
    return *this;
  }
  modint &operator*=(const modint &o) {
    v = static_cast<unsigned long long>(v) * o.v % P;  // 64-bit product avoids overflow
    return *this;
  }
  modint &operator/=(const modint &o) { return *this *= o.inv(); }

  friend modint operator+(modint a, const modint &b) { return a += b; }
  friend modint operator-(modint a, const modint &b) { return a -= b; }
  friend modint operator*(modint a, const modint &b) { return a *= b; }
  friend modint operator/(modint a, const modint &b) { return a /= b; }
  friend bool operator==(const modint &a, const modint &b) { return a.v == b.v; }
  friend bool operator!=(const modint &a, const modint &b) { return !(a == b); }
};
using mint = modint<998244353>;
// Smallest power of two strictly greater than x, e.g. glim(1) == 2, glim(4) == 8, glim(5) == 8.
// Non-positive x yields 1; this avoids __builtin_clz(0), which is undefined.
// x must be below 2^30 for the result to fit in int.
int glim(int x) { return x <= 0 ? 1 : 1 << (32 - __builtin_clz(x)); }
// Index of the lowest set bit of x (number of trailing zero bits). x must be non-zero.
const int bitctz(const int &x) {
  const unsigned bits = static_cast<unsigned>(x);
  return __builtin_ctz(bits);
}
const vector<mint> wns = []() -> vector<mint> {
vector<mint> wns = {};
for (int j = 1; j <= 23; j++)
wns.push_back(qpow(mint(3), (998244353 - 1) >> j));
return wns;
}();
void ntt(vector<mint> &a, const int &op) {
const int n = a.size();
for (int i = 1, r = 0; i < n; i++) {
r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
if (i < r) swap(a[i], a[r]);
}
vector<mint> w(n);
for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
const mint wn = wns[bitctz(k)];
for (int i = raw(w[0] = 1); i < k; i++) w[i] = w[i - 1] * wn;
for (int i = 0; i < n; i += len) {
for (int j = 0; j < k; j++) {
const mint x = a[i + j], y = a[i + j + k] * w[j];
a[i + j] = x + y, a[i + j + k] = x - y;
}
}
}
if (op == -1) {
mint iz = mint(1) / n;
for (int i = 0; i < n; i++) a[i] *= iz;
reverse(a.begin() + 1, a.end());
}
}
vector<mint> getInv(const vector<mint> &a, int lim) {
vector<mint> b = {1 / a[0]};
for (int len = 2; len <= glim(lim); len <<= 1) {
vector<mint> c(a.begin(), a.begin() + min(len, (int)a.size()));
b.resize(len << 1), ntt(b, 1);
c.resize(len << 1), ntt(c, 1);
for (int i = 0; i < len << 1; i++)
b[i] = b[i] * (2 - c[i] * b[i]);
ntt(b, -1), b.resize(len);
}
b.resize(lim);
return b;
}
vector<mint> multiple(vector<mint> a, vector<mint> b) {
int rLen = a.size() + b.size() - 1, len = glim(rLen);
a.resize(len), ntt(a, 1);
b.resize(len), ntt(b, 1);
for (int i = 0; i < len; i++) a[i] *= b[i];
ntt(a, -1), a.resize(rLen);
return a;
}
vector<mint> divide(vector<mint> f, vector<mint> g) {
if (f.size() < g.size()) return {};
int rLen = f.size() - g.size() + 1;
reverse(f.begin(), f.end());
reverse(g.begin(), g.end());
f = multiple(f, getInv(g, rLen));
f.resize(rLen), reverse(f.begin(), f.end());
return f;
}
vector<mint> modulo(vector<mint> f, vector<mint> g) {
int rLen = g.size() - 1;
vector<mint> q = multiple(g, divide(f, g));
q.resize(rLen), f.resize(rLen);
for (int i = 0; i < rLen; i++) f[i] -= q[i];
return f;
}
vector<mint> qpow(vector<mint> a, int b, vector<mint> m) {
vector<mint> r = {1};
for (; b; b >>= 1, a = modulo(multiple(a, a), m)) {
if (b & 1) r = modulo(multiple(r, a), m);
}
return r;
}
int main() {
int n, k;
scanf("%d%d", &n, &k);
vector<mint> m(k + 1), a(k);
m[k] = 1;
for (int i = k - 1, x; i >= 0; i--) scanf("%d", &x), m[i] = -x;
for (int i = 0, x; i < k; i++) scanf("%d", &x), a[i] = x;
vector<mint> b = qpow({0, 1}, n, m);
mint ans = 0;
for (int i = 0; i < k; i++) ans += b[i] * a[i];
printf("%d\n", raw(ans));
return 0;
}
// 标签 (tags): 取模 (modulo), return, int, rhs, vector, modint, const, 递推 (linear recurrence), 乘法 (multiplication)
// From: https://www.cnblogs.com/caijianhong/p/solution-p4723.html