考虑第一问,设一个区间的价值 \(g(l, r)\) 为 \(f(l, r) - a_r + a_{l - 1}\),其中 \(a_i = \sum\limits_{j = 1}^i c_j\),\(f(l, r)\) 为 \([l, r]\) 中最大的 \(k\) 个 \(b_i\) 的和,设 \(p_i\) 为以 \(i\) 为右端点,区间价值最大的左端点,那么 \(p_i\) 满足决策单调性,也就是 \(p_i \le p_{i + 1}\)。
证明即证:
\[g(a, c) + g(b, d) \ge g(a, b) + g(c, d) \]其中 \(a \le b \le c \le d\)。化简得:
\[f(a, c) + f(b, d) \ge f(a, b) + f(c, d) \]即证:
\[f(a, b) + f(a + 1, b + 1) \ge f(a, b + 1) + f(a + 1, b) \]考虑 \([a, b + 1]\) 相当于是 \([a + 1, b]\) 的加入两个单点 \(a\) 和 \(b\),它们能替换 \([a + 1, b]\) 中前 \(k\) 大的最小值和次小值,但是如果换成 \([a, b]\) 和 \([a + 1, b + 1]\),那么 \(a\) 和 \(b\) 只能分别替换 \([a + 1, b]\) 的最小值,所以 \([a, b], [a + 1, b + 1]\) 一定不劣。
所以第一问就可以直接基于这个利用分治算,\(g(l, r)\) 可以维护 \(l, r\) 的指针然后移动指针的同时用树状数组维护。
对于第二问,对于每个点求最大的 \(p_i\),然后把所有 \(g(p_i, i)\) 等于答案的 \([p_i, i]\) 从左往右扫,那么右端点为 \(i\),左端点在 \(l \sim p_i\) 之间(其中 \(l\) 为上一个 \(g(p_j, j)\) 等于答案的 \(p_j\))的区间会对第二问的答案有贡献。就相当于现在有一些形如 \((l, r, k)\) 的操作,意义是把 \(l \sim r\) 且 \(b_i \ge k\) 的 \(i\) 的答案设为 \(1\),那么直接对 \(k\) 扫描线,链表或并查集维护全部 \(b_i \ge k\) 且还没被覆盖过的点即可。
总时间复杂度 \(O(n \log^2 n)\)。
code
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;
const int maxn = 250100;
ll n, m, a[maxn], b[maxn], lsh[maxn], tot, ans = -1e18;
vector<pii> S, vc[maxn];
vector<int> ap[maxn];
struct BIT {
ll c[maxn];
inline void init() {
mems(c, 0);
}
inline void update(int x, ll d) {
x = tot - x + 1;
for (int i = x; i <= tot; i += (i & (-i))) {
c[i] += d;
}
}
inline ll query(int x) {
x = tot - x + 1;
ll res = 0;
for (int i = x; i; i -= (i & (-i))) {
res += c[i];
}
return res;
}
inline int kth(int k) {
int s = 0, p = 0;
for (int i = 18; ~i; --i) {
if (p + (1 << i) <= tot && s + c[p + (1 << i)] < k) {
s += c[p += (1 << i)];
}
}
return tot - p;
}
} t1, t2;
int l = 1, r;
inline ll calc(int L, int R) {
while (l > L) {
t1.update(b[--l], 1);
t2.update(b[l], lsh[b[l]]);
}
while (r < R) {
t1.update(b[++r], 1);
t2.update(b[r], lsh[b[r]]);
}
while (l < L) {
t1.update(b[l], -1);
t2.update(b[l], -lsh[b[l]]);
++l;
}
while (r > R) {
t1.update(b[r], -1);
t2.update(b[r], -lsh[b[r]]);
--r;
}
int p = t1.kth(m);
return t2.query(p + 1) + lsh[p] * (m - t1.query(p + 1)) - a[R] + a[L - 1];
}
void dfs(int l, int r, int pl, int pr) {
if (l > r || pl > pr) {
return;
}
int mid = (l + r) >> 1, p = -1;
ll res = -1e18;
for (int i = pl; i <= min(pr, mid - (int)m + 1); ++i) {
ll val = calc(i, mid);
if (val >= res) {
res = val;
p = i;
}
}
if (res > ans) {
ans = res;
vector<pii>().swap(S);
S.pb(p, mid);
} else if (res == ans) {
S.pb(p, mid);
}
dfs(l, mid - 1, pl, p);
dfs(mid + 1, r, p, pr);
}
int fa[maxn];
bool f[maxn];
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
a[i] += a[i - 1];
}
for (int i = 1; i <= n; ++i) {
scanf("%lld", &b[i]);
lsh[++tot] = b[i];
}
sort(lsh + 1, lsh + tot + 1);
tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
for (int i = 1; i <= n; ++i) {
b[i] = lower_bound(lsh + 1, lsh + tot + 1, b[i]) - lsh;
ap[b[i]].pb(i);
}
dfs(m, n, 1, n);
sort(S.begin(), S.end());
int i = 1;
for (pii p : S) {
while (1) {
if (calc(i, p.scd) == ans) {
vc[t1.kth(m)].pb(i, p.scd);
}
if (i == p.fst) {
break;
}
++i;
}
}
for (int i = 1; i <= n + 1; ++i) {
fa[i] = i;
}
for (int i = 1; i <= tot; ++i) {
for (pii p : vc[i]) {
for (int j = find(p.fst); j <= p.scd; j = find(j + 1)) {
f[j] = 1;
fa[j] = j + 1;
}
}
for (int j : ap[i]) {
fa[j] = j + 1;
}
}
printf("%lld\n", ans);
for (int i = 1; i <= n; ++i) {
putchar('0' + f[i]);
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}