题解 P7468【[NOI Online 2021 提高组] 愤怒的小 N】
problem
首先是有一个字符串 \(S=\texttt{"0"}\),做无限次“将 \(S\) 的每一位取反接在 \(S\) 后面”的操作,形如 \(S=0110100110010110\cdots\)。
另外给一个 \(k-1\) 次多项式 \(f\),求 \(\sum_{i=0}^{n-1}S_if(i).\)
\(n\leq 2^{5\times 10^5}, k\leq 500\)。
solution 0
第一个观察是 \(S_i=parity(i)\),其中 \(parity(i)\) 表示 \(i\) 的二进制表示中 \(1\) 的个数的奇偶性。因为 \(S\) 的后半段是前半段取反,也就是说每把 \(i\) 的最高位拿掉一次,值就反转一次。
考虑 dp。\(dp(i, j, 0/1)\) 表示 \([0,2^i)\) 中 \(parity=0/1\) 的数字的 \(j\) 次方和。
转移
初值为 \(dp(0, j, 0)=[j=0]\) 表示只有 \(0\) 一个数字。
\([2^{i-1},2^i)\) 中的数可以写成 \(l+2^{i-1}\)(\(0\le l<2^{i-1}\)),多出的最高位使奇偶性翻转,所以要求 \(parity(l)\neq e\):
\[\begin{aligned} dp(i, j, e)&=dp(i-1, j, e)+\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}(l+2^{i-1})^j\\ &=dp(i-1, j, e)+\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}\sum_{t=0}^{j}\binom{j}{t}l^t(2^{i-1})^{j-t}\\ &=dp(i-1, j, e)+\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}l^t\\ &=dp(i-1, j, e)+\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}dp(i-1, t, e\oplus 1) \end{aligned} \]
统计答案
- 取出 \(2^T=lowbit(n)\),\(L=n-2^T\),本步处理区间 \([L,L+2^T)\)。
- 答案累加该区间中 \(S=1\) 的位置上 \(f\) 的和,即 \(\displaystyle\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}f(L+l)\)。注意 \(l<2^T\),而 \(L\) 的最低 \(T+1\) 位全是 \(0\),所以 \(l\) 与 \(L\) 相加不进位,\(parity(L+l)=parity(L)\oplus parity(l)\)。记 \(f(x)=\sum_{j=0}^{k-1}f_jx^j\),这玩意等于
  \[\begin{aligned} \sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}\sum_{j=0}^{k-1}f_j(l+L)^j &=\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jl^tL^{j-t}\\ &=\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jL^{j-t}\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}l^t\\ &=\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jL^{j-t}dp(T, t, parity(L)\oplus 1)\\ &=\sum_{t=0}^{k-1}dp(T, t, parity(L)\oplus 1)\sum_{j=t}^{k-1}\binom{j}{t} f_jL^{j-t} \end{aligned} \]
- 令 \(n:=L\);若 \(n>0\),回到第一步。
明显枚举了所有区间。
optimize
现在的复杂度是 \(O(k^2\log n)\)。
重量级结论是,\(i>j\) 时 \(dp(i, j, 0)=dp(i, j, 1)=\frac{1}{2}\sum_{l=0}^{2^i-1}l^j\)。(怎么证明呢,待补,关键是对 \(i-1\to i\) 归纳,用二项式定理展开,考察各项系数)
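换句话来说,对于 \(i>j\) 的区间 \([L,L+2^i)\),其中 \(S=1\) 的位置上 \(f\) 的和恰好是整段区间上 \(f\) 的和的一半,可以直接除以二得到。只看 \(i\geq k\) 的那些区间(此时对所有 \(j<k\) 都有 \(i>j\)):它们拼起来恰好是 \([0,lim)\),其中 \(lim\) 是把 \(n\) 的二进制表示中最低 \(k\) 位改成 \(0\) 得到的数,所以这部分答案为 \(\frac{1}{2}\sum_{x=0}^{lim-1}f(x)\)。观察到 \(f\) 的前缀和 \(F(X)=\sum_{x=0}^{X-1}f(x)\) 是 \(k\) 次多项式,直接用 \(k+1\) 个点值拉格朗日插值求出它的系数,预处理 \(O(k^2)\)、求值 \(O(n+k)\)(求 \(lim\bmod p\) 需要 \(O(n)\))即可完成这一部分。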
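剩下的是 \(i<k\) 的区间(这里可能出现 \(i\le j\),结论不再适用),至多 \(k\) 个,每个用 solution 0 的公式暴力计算,代价 \(O(k^2)\)。连同求出 \(dp(i,\cdot,\cdot)\)(\(i<k\))所需的 \(O(k^3)\),这部分是 \(O(k^3)\) 的。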
所以总的复杂度是 \(O(\log n+k^3)\)。就是将其中一个很大的 \(\log n\) 用结论打成 \(k\)。
code
#include <cstdio>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
// debug(...) prints to stderr only in local builds (compile with -DLOCAL);
// in every other build it expands to a no-op expression.
#if defined(LOCAL)
#  define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#  define debug(...) void(0)
#endif
// 64-bit signed integer shorthand.
using LL = long long;
/// Residue class modulo the compile-time constant P.
/// Invariant: `v` is always the canonical residue in [0, P).
/// Sums of two residues are formed in `unsigned`, so P must be below 2^31;
/// inv() uses Fermat's little theorem, so P must also be prime.
template <unsigned P>
struct modint {
  unsigned v;

  modint() : v(0) {}

  /// Converts any integer value (negative ones included) to its residue.
  template <class T>
  modint(T x) {
    x %= (int)P;
    if (x < 0)
      v = x + P;
    else
      v = x;
  }

  /// Raw residue as an int (safe because P < 2^31).
  friend int raw(const modint &self) { return self.v; }

  /// base^e by binary exponentiation; e must be a non-negative integer.
  template <class T>
  friend modint qpow(modint base, T e) {
    modint acc = 1;
    while (e) {
      if (e & 1) acc *= base;
      base *= base;
      e >>= 1;
    }
    return acc;
  }

  /// Multiplicative inverse; the value must be non-zero.
  modint inv() const {
    assert(v);
    return qpow(*this, P - 2);
  }

  modint operator+() const { return *this; }
  modint operator-() const {
    modint zero;
    return zero -= *this;
  }

  modint &operator+=(const modint &o) {
    v += o.v;
    if (v >= P) v -= P;
    return *this;
  }
  modint &operator-=(const modint &o) {
    v -= o.v;          // wraps around when o.v > v ...
    if (v >= P) v += P; // ... and wrapping back yields v - o.v + P
    return *this;
  }
  modint &operator*=(const modint &o) {
    v = static_cast<unsigned long long>(v) * o.v % P;
    return *this;
  }
  modint &operator/=(const modint &o) { return *this *= o.inv(); }

  friend modint operator+(modint a, const modint &b) { a += b; return a; }
  friend modint operator-(modint a, const modint &b) { a -= b; return a; }
  friend modint operator*(modint a, const modint &b) { a *= b; return a; }
  friend modint operator/(modint a, const modint &b) { a /= b; return a; }

  friend bool operator==(const modint &a, const modint &b) { return a.v == b.v; }
  friend bool operator!=(const modint &a, const modint &b) { return !(a == b); }
};
using mint = modint<1000000007>;
vector<mint> multiple(const vector<mint> &a, const vector<mint> &b) {
vector<mint> c(a.size() + b.size() - 1);
for (int i = 0; i < a.size(); i++) {
for (int j = 0; j < b.size(); j++) c[i + j] += a[i] * b[j];
}
return c;
}
vector<mint> addition(const vector<mint> &a, const vector<mint> &b) {
vector<mint> c(max(a.size(), b.size()));
for (int i = 0; i < a.size(); i++) c[i] += a[i];
for (int i = 0; i < b.size(); i++) c[i] += b[i];
return c;
}
vector<mint> divide(vector<mint> a, mint b1) {
vector<mint> res(a.size() - 1);
for (int i = (int) a.size() - 1; i >= 1; i--) {
mint coe = res[i - 1] = a[i];
a[i - 1] -= a[i] * b1;
}
return res;
}
vector<mint> numes[510];
mint idenos[510];
vector<mint> lagrange(const vector<mint> &a, const vector<mint> &b) {
assert(a.size() == b.size());
vector<mint> ans(a.size());
for (int i = 0; i < a.size(); i++) {
mint coe = b[i];
for (int j = 0; j < a.size(); j++) ans[j] += numes[i][j] * coe;
}
return ans;
}
// Evaluates the polynomial with coefficients `a` (a[i] multiplies x^i) at x
// using Horner's rule, highest coefficient first. An empty `a` evaluates to 0.
mint getValue(const vector<mint> &a, mint x) {
  mint acc = 0;
  for (auto it = a.rbegin(); it != a.rend(); ++it) acc = acc * x + *it;
  return acc;
}
// n: number of binary digits read by main() (the bit length of the input bound,
//    not the bound itself); k: number of coefficients of f (degree k-1).
int n, k;
// Binary digits of the input bound. After main() they are 0/1 values stored
// least-significant first, so a[i] is the coefficient of 2^i.
char a[1 << 19];
// f:       coefficients of f; f[j] multiplies x^j (as evaluated by getValue()).
// sumG[j]: intended to hold coefficients of the polynomial n -> sum_{i=0}^{n-1} i^j.
//          NOTE(review): not read by DP() or solve().
// sumF:    coefficients of X -> sum_{x=0}^{X-1} f(x), interpolated in init().
vector<mint> f, sumG[510], sumF;
// dp[i][j][e]: sum of l^j over l in [0, 2^i) whose popcount parity is e (see DP()).
// qp2[i]:      2^i modulo the prime.
// binom[i][j]: binomial coefficient C(i, j).
mint dp[510][510][2], qp2[1 << 19], binom[510][510];
// Modular inverse of 2, used by solve() to halve the parity-balanced high-bit part.
const mint inv2 = 1 / mint(2);
void init() {
for (int i = raw(qp2[0] = 1); i <= max(k * k, n); i++) qp2[i] = qp2[i - 1] + qp2[i - 1];
for (int i = 0; i < k; i++) {
binom[i][0] = 1;
for (int j = 1; j <= i; j++) binom[i][j] = binom[i - 1][j] + binom[i - 1][j - 1];
}
vector<mint> per = {};
for (int i = 1; i <= k + 1; i++) per.push_back(i);
vector<mint> ans(per.size()), product = {1};
for (int i = 0; i < per.size(); i++)
product = multiple(product, {-per[i], 1});
for (int i = 0; i < per.size(); i++) {
numes[i] = divide(product, -per[i]);
idenos[i] = 1;
for (int j = 0; j < per.size(); j++)
if (i != j) idenos[i] *= per[i] - per[j];
idenos[i] = 1 / idenos[i];
for (int j = 0; j < per.size(); j++) numes[i][j] *= idenos[i];
}
for (int j = 0; j < k; j++) {//这一段没用,,,
vector<mint> tmp = {};
for (int i = 1; i <= k + 1; i++) tmp.push_back(qpow(mint(i - 1), j));
for (int i = 1; i <= k; i++) tmp[i] += tmp[i - 1];
sumG[j] = lagrange(per, tmp);
}
{
vector<mint> tmp = {};
for (int i = 1; i <= k + 1; i++) tmp.push_back(getValue(f, i - 1));
for (int i = 1; i <= k; i++) tmp[i] += tmp[i - 1];
sumF = lagrange(per, tmp);
}
}
void DP() {
for (int j = 0; j < k; j++) dp[0][j][0] = !j;
for (int i = 1; i < min(n, k); i++) {
//for (int i = 1; i < n; i++) {
memcpy(dp[i], dp[i - 1], sizeof dp[i]);
for (int j = 0; j < k; j++) {
for (int e: {0, 1}) {
for (int t = 0; t <= j; t++) {
dp[i][j][e] += dp[i - 1][t][1 - e] * binom[j][t] * qp2[(i - 1) * (j - t)];
}
}
}
}
//forall i > j, dp[i][j][e] = sumg[j](2^i) / 2
}
// Returns sum_{x < N} S_x * f(x) (mod P), where N is the number whose bits are a[].
// [0, N) is cut into blocks [base, base + 2^b), one per set bit b, from high to low.
//  * Bits b >= k: each block splits the power sums of degree < k evenly between
//    the two parities, so together they contribute half of F(lim). Here lim is N
//    with its low k bits cleared, and F is the prefix-sum polynomial sumF.
//  * Bits b < k: expand f(base + l) by the binomial theorem and combine with
//    dp[b][t][*], selecting the l whose parity makes S_{base+l} == 1.
mint solve() {
  mint ans = 0;
  mint base = 0;      // value of the bits of N consumed so far
  bool par = false;   // popcount parity of base

  if (n > k) {
    mint lim = 0;
    for (int b = n - 1; b >= k; b--) {
      if (!a[b]) continue;
      lim += qp2[b];
      par = !par;
    }
    ans += getValue(sumF, lim) * inv2;
    base = lim;
  }

  for (int b = min(k, n) - 1; b >= 0; b--) {
    if (!a[b]) continue;
    for (int t = 0; t < k; t++) {
      // coe = sum_{j >= t} C(j, t) * f_j * base^(j - t)
      mint coe = 0, pw = 1;
      for (int j = t; j < k; j++) {
        coe += binom[j][t] * f[j] * pw;
        pw *= base;
      }
      // S_{base+l} = 1 exactly when parity(l) != parity(base).
      ans += dp[b][t][!par] * coe;
    }
    base += qp2[b];
    par = !par;
  }
  return ans;
}
// Reads the bound in binary (most-significant digit first) and k, then the
// k coefficients f_0 .. f_{k-1}. Prints sum_{x < bound} S_x * f(x) mod 1e9+7.
int main() {
  // The width limit keeps over-long input from overrunning `a`, which holds
  // 1 << 19 bytes including the terminating NUL.
  if (scanf("%524287s%d", a, &k) != 2) return 1;
  n = strlen(a);
  for (int i = 0; i < n; i++) a[i] -= '0';
  // Store least-significant bit first so that a[i] is the coefficient of 2^i.
  reverse(a, a + n);
  f = vector<mint>(k);
  for (int i = 0; i < k; i++) {
    unsigned coef = 0;
    if (scanf("%u", &coef) != 1) return 1;
    // Go through the modint constructor (instead of writing f[i].v directly)
    // so the stored value stays in [0, P) even for an unreduced coefficient.
    f[i] = coef;
  }
  init(), DP();
  printf("%d\n", raw(solve()));
  return 0;
}
标签:P7468,const,NOI,int,题解,sum,return,modint,size
From: https://www.cnblogs.com/caijianhong/p/solution-P7468.html