挺有意思的题。
首先显然地,一个棋子不会走回头路。于是一个棋子沿着边走的效果就是区间异或。
更进一步,设 \(s_i\) 为 \(i-1 \to i\) 的边颜色与 \(i \to i+1\) 的边颜色是否相同(差分),相当于对于每个 \(i\) 都选择 \(s_{a_i}\) 和 \(s_{x_i}\),将它们异或上 \(1\)(\(x_i\) 任选),代价为 \(|a_i - x_i|\),最后要求恰好 \(k\) 个位置 \(t_1,t_2,...,t_k\) 为 \(1\),求最小代价。
\(s_{a_i}\) 的异或操作可以事先处理。操作后需要被 \(x_i\) 异或奇数次的位置可知。现在问题又变成了,数轴上有 \(n\) 个红点 \(a_1,a_2,...,a_n\) 和 \(m\) 个蓝点 \(b_1,b_2,...,b_m\),红点和红点可以匹配(随便选一个它们之间的点互相抵消),红点和黑点也可以匹配,要求每个黑点和每个红点都要被匹配,每对匹配的价值是坐标之差的绝对值,求最小代价。
考虑 \(n = m\) 怎么做。这个是个经典问题,\(a\) 和 \(b\) 排序后每对匹配相交一定不优,于是答案就是排序后 \(ans = \sum\limits_{i=1}^n |a_i - b_i|\)。
现在 \(n > m\)(如果 \(n < m\) 就无解),不妨沿用 \(n = m\) 时的结论。不难得出两个红点仅当它们在红点之中相邻才能形成匹配。发现数据范围允许 \(O(nm)\),大力 dp,设 \(f_{i,j}\) 为排序后前 \(i\) 个红点和前 \(j\) 个蓝点都能匹配的代价最小值。有转移:
\[f_{i,j} \gets \min(f_{i-1,j-1} + |a_i - b_j|, f_{i-2,j} + a_i - a_{i-1}) \]答案为 \(f_{n,m}\)。
code
// Problem: D - Moving Pieces on Line
// Contest: AtCoder - AtCoder Regular Contest 114
// URL: https://atcoder.jp/contests/arc114/tasks/arc114_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 5050;
ll n, m, a[maxn], b[maxn << 1], c[maxn], f[maxn][maxn];
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
b[i] = a[i];
}
for (int i = n + 1; i <= n + m; ++i) {
scanf("%lld", &b[i]);
}
m += n;
sort(a + 1, a + n + 1);
sort(b + 1, b + m + 1);
int tot = 0;
for (int i = 1, j = 1; i <= m; i = (++j)) {
while (j < m && b[j + 1] == b[i]) {
++j;
}
if ((j - i + 1) & 1) {
c[++tot] = b[i];
}
}
m = tot;
if (n < m) {
puts("-1");
return;
}
mems(f, 0x3f);
f[0][0] = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 0; j <= m; ++j) {
f[i][j] = f[i - 1][j - 1] + abs(a[i] - c[j]);
if (i >= 2) {
f[i][j] = min(f[i][j], f[i - 2][j] + a[i] - a[i - 1]);
}
}
}
printf("%lld\n", f[n][m]);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}