决策单调性是 dp 转移方程的一种性质。一般而言,我们有许多方法优化一个具有决策单调性的转移方程。
这里主要讲解一种用分治解决决策单调性问题的方法。
引入
先看一道题:CF868F
我们可以想到一个 \(O(n^2k)\) 暴力。定义 \(dp_{i, j}\) 为令点 \(i\) 为第 \(j\) 段的最后一个点产生的最小代价。则可以得出:
\[dp_{i, j} = min(dp_{k, j - 1}+f(k + 1, i))(k<i) \]其中 \(f(k+1,i)\) 即为这一段的代价。答案即为 \(dp_{n, k}\)。这是好求的。
显然,我们需要砍掉一个 \(n\) 才能通过此题。然后我们通过打表得到了一个性质:对于 \(dp_{i, j}\) 与 \(dp{k,j}\) 的转移点(这两个状态由它们的转移点转移而来)\(I,K\),如果 \(i>k\),则 \(I \ge K\)。也就是说,这个转移具有决策单调性!
发现
也许已经有人想要大显身手了,但是我们先来关注一下我们是如何得出决策单调性的。
让我们了解一下四边形不等式。对于 \(dp_i = min(g_j+f(j, i))\),如果有四个点 \(a<b<c<d\) 且 \(f(a,d) + f(b,c) \ge f(a, c) + f(b,d)\),那么我们就说这个转移方程满足四边形不等式(最大也是可以的,反过来就行了)。这一性质也被称作“包含劣于相交”。
那么我们发现,如果一个形如 \(dp_i = min(g_j+f(j, i))\) 的转移方程满足四边形不等式,那么这个方程具有决策单调性。
考虑证明:考虑反证法,然后读者自证。
在实际做题中,我们通过感觉感应到四边形不等式的存在(有时也会直接感应到决策单调性),然后打表加以证明。
分治
我们并不知道这个代价函数除了决策单调性以外的性质,也就是说,尽管我们知道了点 \(i\) 的决策点 \(d_i\),对于大于点 \(i\) 的点 \(j\),我们仍然要暴力枚举转移点 \([a_i, j]\)。这很容易被卡掉。
考虑更好地排除转移点,对于需要求解的区间 \([l, r]\),我们取区间中点 \(mid\),并暴力计算出 \(mid\) 的转移点 \(d_mid\)。如果原区间转移范围在 \([L,R]\) 之间,那么我们可以分两半。\([l,mid]\) 转移范围在 \([L,d_mid]\),\([mid+1,r]\) 转移范围在 \([d_mid, R]\)。
考虑证明时间复杂度。对于分治的每一层,我们暴力计算的时间复杂度总和都是 \(O(n)\)。但由于分治树只有 \(\log n\) 层,所以总复杂度变为 \(n \log n\)。
修改
现在我们需要解决的问题,是我们无法快速求出 \(f(l, r)\)。
我们发现虽然我们无法快速求出 \(f(l, r)\),但我们可以快速地由 \(f(l+1,r)\) 转移到 \(f(l,r)\)。
这时我们就有一个很憨的想法:类似莫队,我们直接在上一次查询的基础上一个一个暴力转移。
很神奇的是,这个方法复杂度还是 \(O(n \log n)\) 的!考虑两个指针进行跳跃。对于从 \([l,d_mid]\) 到 \([d_mid, r]\) 的跳跃,时间复杂度不会超过区间长度。那么这一层的时间复杂度总和还是 \(O(n)\) 的。然后分治树只有 \(\log n\) 层,于是总时间复杂度还是 \(O(n \log n)\)。
实现
注意由于这里也是一个一个的跳跃,所以跳跃顺序和莫队一样有讲究。具体可以看代码。
#include <iostream>
#include <algorithm>
#include <string.h>
using namespace std;
const int N = 1e5 + 5, K = 21;
int n, k;
int a[N];
long long dp[N], g[N];
long long ton[N];
long long f(int L, int R) {
static long long l = 1, r = 0, ans = 0;
auto del = [&](int i) {
ton[a[i]]--;
ans -= ton[a[i]];
};
auto add = [&](int i) {
ans += ton[a[i]];
ton[a[i]]++;
};
while (l > L) add(--l);
while (r < R) add(++r);
while (l < L) del(l++);
while (r > R) del(r--);//important
return ans;
}
void solve(int l, int r, int dl, int dr) {
int mid = (l + r) >> 1, dec = 0;
for (int i = dl; i <= min(dr, mid - 1); ++i) {
if (dp[mid] > g[i] + f(i + 1, mid)) {
dp[mid] = g[i] + f(i + 1, mid);
dec = i;
}
}
if (l == r) return ;
solve(l, mid, dl, dec);
solve(mid + 1, r, dec, dr);
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> k;
for (int i = 1; i <= n; ++i) cin >> a[i];
long long tmp = 0;
for (int i = 0; i <= n; ++i) {
tmp += ton[a[i]];
dp[i] = tmp;
ton[a[i]]++;
}
for (int i = 1; i <= n; ++i) ton[i] = 0;
for (int lask = 1; lask <= k - 1; ++lask) {
for (int i = 1; i <= n; ++i) g[i] = dp[i], dp[i] = 2e13;
solve(lask + 1, n, lask, n - 1);
}
cout << dp[n] << endl;
return 0;
}