提供一个 dp 思路。
下文设串长为 \(n\),串中 \(1\) 个数为 \(m\)。
考虑如何求 \(d(s, t)\)。设 \(s\) 的 \(1\) 位置分别为 \(a_1, a_2, ..., a_m\),\(t\) 的 \(1\) 位置分别为 \(b_1, b_2, ..., b_m\)。那么 \(d(s, t) = \sum\limits_{i=1}^m |a_i - b_i|\)。
更进一步地,对于串 \(s'\),设 \(s'\) 的 \(1\) 位置分别为 \(c_1, c_2, ..., c_m\),那么 \(\forall i \in [1, m], c_i \in [\min(a_i, b_i), \max(a_i, b_i)]\) 是 \(d(s, s') = d(s', t)\) 的充要条件。设 \(l_i = \min(a_i, b_i), r_i = \max(a_i, b_i)\)。
考虑转化题中的 "beauty",相当于最小化极长同字符连续段个数。
考虑 dp。设 \(f_i\) 表示填完前 \(i\) 个 \(1\) 的最小极长连续段个数。枚举 \(j\) 表示第 \(j \sim i\) 个 \(1\) 放一起,合法当且仅当 \(l_i - r_j + 1 \le i - j + 1 \le r_i - l_j + 1\),此时有转移 \(f_i = f_{j-1} + 2\),表示先加一段 \(0\) 再加一段 \(1\)。
注意处理开头一段是 \(1\) 和末尾一段是 \(1\) 的情况。
感性理解,对于每个 \(i\),合法的 \(j\) 一定形成一个区间,因为不可能 \(j_1 \sim i\) 能放一起但是 \(j_2 \sim i\) 不能(\(j_1 < j_2\))。可以二分得出左端点,然后使用线段树优化。可以做到 \(O(n \log n)\),已经可以通过。
更进一步,发现对于每个 \(i\),最左合法转移点单调不降。并且由于这个性质,\(f_i\) 单调不降。可以双指针维护每个 \(i\) 的最左合法转移点,然后 \(O(1)\) 转移。总时间复杂度降至 \(O(n)\)。
$O(n \log n)$ 的 code
// Problem: D - Between Two Binary Strings
// Contest: AtCoder - AtCoder Regular Contest 132
// URL: https://atcoder.jp/contests/arc132/tasks/arc132_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 300100;
const int inf = 0x3f3f3f3f;
int n, m, f[maxn];
char s[maxn], t[maxn];
struct node {
int l, r;
node(int a = 0, int b = 0) : l(a), r(b) {}
} a[maxn];
namespace SGT {
int tree[maxn << 2];
inline void init() {
mems(tree, 0x3f);
}
inline void pushup(int x) {
tree[x] = min(tree[x << 1], tree[x << 1 | 1]);
}
void update(int rt, int l, int r, int x) {
if (l == r) {
tree[rt] = f[x];
return;
}
int mid = (l + r) >> 1;
(x <= mid) ? update(rt << 1, l, mid, x) : update(rt << 1 | 1, mid + 1, r, x);
pushup(rt);
}
int query(int rt, int l, int r, int ql, int qr) {
if (ql > qr) {
return inf;
}
if (ql <= l && r <= qr) {
return tree[rt];
}
int mid = (l + r) >> 1, res = inf;
if (ql <= mid) {
res = min(res, query(rt << 1, l, mid, ql, qr));
}
if (qr > mid) {
res = min(res, query(rt << 1 | 1, mid + 1, r, ql, qr));
}
return res;
}
}
void solve() {
scanf("%d%d%s%s", &n, &m, s + 1, t + 1);
if (!n || !m) {
printf("%d\n", n + m - 1);
return;
}
n += m;
vector<int> va, vb;
for (int i = 1; i <= n; ++i) {
if (s[i] == '1') {
va.pb(i);
}
if (t[i] == '1') {
vb.pb(i);
}
}
for (int i = 0; i < m; ++i) {
int x = va[i], y = vb[i];
if (x > y) {
swap(x, y);
}
a[i + 1] = node(x, y);
// printf("l, r: %d %d\n", x, y);
}
SGT::init();
mems(f, 0x3f);
f[0] = 0;
SGT::update(1, 0, m, 0);
for (int i = 1; i <= m; ++i) {
if (a[1].l == 1 && a[i].l <= i && i <= a[i].r) {
f[i] = 1;
}
if (a[i].l - a[1].r + 1 <= i && i <= a[i].r - a[1].l + 1) {
f[i] = min(f[i], 2);
}
int l = 2, r = i, pos = i + 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (a[i].l - a[mid].r + 1 <= i - mid + 1 && i - mid + 1 <= a[i].r - a[mid].l + 1) {
pos = mid;
r = mid - 1;
} else {
l = mid + 1;
}
}
f[i] = min(f[i], SGT::query(1, 0, m, pos - 1, i - 1) + 2);
// for (int j = i; j > 1; --j) {
// if (a[i].l - a[j].r + 1 <= i - j + 1 && i - j + 1 <= a[i].r - a[j].l + 1) {
// f[i] = min(f[i], f[j - 1] + 2);
// } else {
// break;
// }
// }
SGT::update(1, 0, m, i);
}
int ans = f[m] + 1;
for (int i = m; i; --i) {
if (a[m].r == n && a[i].l <= n - (m - i) && n - (m - i) <= a[i].r) {
ans = min(ans, f[i - 1] + 2);
}
}
printf("%d\n", n - ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}
$O(n)$ 的 code
// Problem: D - Between Two Binary Strings
// Contest: AtCoder - AtCoder Regular Contest 132
// URL: https://atcoder.jp/contests/arc132/tasks/arc132_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 300100;
const int inf = 0x3f3f3f3f;
int n, m, f[maxn];
char s[maxn], t[maxn];
struct node {
int l, r;
node(int a = 0, int b = 0) : l(a), r(b) {}
} a[maxn];
void solve() {
scanf("%d%d%s%s", &n, &m, s + 1, t + 1);
if (!n || !m) {
printf("%d\n", n + m - 1);
return;
}
n += m;
vector<int> va, vb;
for (int i = 1; i <= n; ++i) {
if (s[i] == '1') {
va.pb(i);
}
if (t[i] == '1') {
vb.pb(i);
}
}
for (int i = 0; i < m; ++i) {
int x = va[i], y = vb[i];
if (x > y) {
swap(x, y);
}
a[i + 1] = node(x, y);
}
mems(f, 0x3f);
f[0] = 0;
for (int i = 1, j = 1; i <= m; ++i) {
if (a[1].l == 1 && a[i].l <= i && i <= a[i].r) {
f[i] = 1;
}
if (a[i].l - a[1].r + 1 <= i && i <= a[i].r - a[1].l + 1) {
f[i] = min(f[i], 2);
}
while (j <= i && !(a[i].l - a[j].r + 1 <= i - j + 1 && i - j + 1 <= a[i].r - a[j].l + 1)) {
++j;
}
if (j <= i) {
f[i] = min(f[i], f[j - 1] + 2);
}
}
int ans = f[m] + 1;
for (int i = m; i; --i) {
if (a[m].r == n && a[i].l <= n - (m - i) && n - (m - i) <= a[i].r) {
ans = min(ans, f[i - 1] + 2);
}
}
printf("%d\n", n - ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}