好题。同时考查了 slope trick 和选手的计数能力,不愧是 AGC E。
先考虑问题的第一部分。
你现在有一个初始全为 \(0\) 的序列 \(b\)。你每次可以给 \(b\) 单点 \(\pm 1\),代价为 \(1\),或区间 \(\pm 1\),代价为 \(m\)。求把 \(b\) 变成给定序列 \(a\) 的最小代价。
考虑先执行区间加减操作,设操作完的序列为 \(c\),那么之后单点加减操作的代价就是 \(\sum\limits_{i = 1}^n |c_i - a_i|\)。对 \(c\) 差分,可以发现最少的区间操作次数就是差分后数组的正数项的值的和。所以总代价就是 \(\sum\limits_{i = 1}^n |c_i - a_i| + m \sum\limits_{i = 1}^n \max(c_i - c_{i - 1}, 0)\)。
考虑 dp,设 \(f_{i, j}\) 为 \(c_i = j\) 的最小代价(显然最优时 \(c_i \ge 0\))。初始有 \(\forall i \ge 0, f_{0, i} = mi\)。转移枚举 \(c_{i - 1}\),可得:
\[f_{i, j} = |a_i - j| + \min\limits_k f_{i - 1, k} + m \max(j - k, 0) \]看到转移式里面有计算绝对值,联想到 slope trick。发现 \(f_i\) 是凸函数,并且每段的斜率在 \([-1, m + 1]\) 之间。容易归纳,我们初始的图像是一条斜率为 \(m\) 的射线,然后我们先进行 \(f_{i, j} \gets \min\limits_k f_{i - 1, k} + m \max(j - k, 0)\) 的更新,这个更新产生的影响就是,斜率为 \(-1\) 的段被拍平(斜率变成 \(0\)),斜率为 \(m + 1\) 的段斜率变成 \(m\)。然后我们给这个图像整体加上 \(|a_i - j|\) 的分段函数,也就是把 \(\le a_i\) 的段斜率减少 \(1\),\(> a_i\) 的段斜率增加 \(1\)。
考虑运用 slope trick,用 multiset
维护这个分段函数。回忆一下在 slope trick 中,我们维护的是分段函数图像变化的断点,并且一个断点代表斜率增加 \(1\)。那么在这题中,我们初始往 multiset
添加 \(m\) 个 \(0\) 代表断点为 \(0\) 且斜率为 \(m\),然后当 \(i = 1\) 时,因为图像不存在斜率为 \(-1\) 或 \(m + 1\) 的段,因此我们不需要进行删除操作,直接添加两个 \(a_i\) 表示 \(a_i\) 处斜率变化为 \(2\)。当 \(i > 1\) 时,我们先删除 multiset
中的最小值和最大值表示这两个断点被拍平了,不存在了,再添加两个 \(a_i\)。
至于统计答案,我们在每次添加两个 \(a_i\) 后统计,此时 multiset
中的最小值就是斜率 \(-1 \to 0\) 变化的断点,因此我们把答案累加 \(a_i - p\),其中 \(p\) 为 multiset
中的最小值(不用加绝对值是因为此时加入 \(a_i\) 后 \(p\) 一定 \(\le a_i\))。
于是我们现在可以 \(O(n \log n)\) 求解这个问题了。
考虑问题的第二部分,即统计所有可能的 \(a_i\) 对应的答案之和。
考虑我们上面的算法流程。
初始往
multiset
中添加 \(m\) 个 \(0\)。\(i = 1\) 时,往multiset
中添加 \(2\) 个 \(a_i\),然后计算 \(a_i - p\),其中 \(p\) 为multiset
中最小值;\(i > 1\) 时,先删除multiset
中的最小值和最大值,然后往其中添加 \(2\) 个 \(a_i\),再计算 \(a_i - p\)。
\(a_i\) 部分的贡献系数是容易统计的,就是 \(K^{n - 1}\)(选定 \(a_i\) 后其他的可以任意选,都能产生贡献)。问题还剩下统计所有 \(p\) 的和。
我们枚举 \(nK\) 个可能的 \(p\),分别计算最小值 \(< p\) 和最小值 \(\le p\) 的方案数,二者差分一下就是 \(p\) 的贡献系数。
直接做不好维护 multiset
,但是如果 \(a_i \in \{0, 1\}\),我们就能维护 \(1\) 的个数来表示整个 multiset
了(非常经典的套路:任意值转 \(01\))。我们不妨让 \(a_i \gets [a_i \ge p]\),这样最小值 \(< p\) 等价于最小值 \(= 0\)。
发现只有 \(1\) 的个数是 \(m + 2\) 时,multiset
中的最小值才不是 \(0\)。因此考虑一个容斥,总方案数减去最小值为 \(1\) 的方案数。总方案数显然是 \(nK^n\)(一共 \(n\) 轮,\(a\) 数组有 \(K^n\) 种产生方式),如果我们设 \(f_{i, j}\) 为进行到第 \(i\) 轮,multiset
中有 \(j\) 个 \(1\) 的方案数,那么最小值为 \(1\) 的方案数就是 \(\sum\limits_{i = 1}^n K^{n - i} f_{i, m + 2}\)(乘上 \(K^{n - i}\) 是因为第 \(i + 1 \sim n\) 轮中 \(a_i\) 的选择都不影响第 \(i\) 轮的最小值是 \(1\))。
现在考虑 \(f_{i - 1} \to f_i\)。对于一个 \(f_{i - 1, j}\),我们先进行删除操作,即 \(j \gets j - [j > 0] - [j = m + 2]\),然后我们考虑选择 \(a_i\),设 \(t = \sum\limits_{j = 1}^K [b_{i, j} \ge p]\),也就是能选的 \(1\) 的个数,那么 \(f_{i, j} \gets (K - t) f_{i - 1, j}\),\(f_{i, j + 2} \gets t f_{i - 1, j}\)。记得特判 \(i = 1\)。
至此我们终于以 \(O(n^3K)\) 的时间复杂度完成了这题。
code
// Problem: E - Increment Decrement
// Contest: AtCoder - AtCoder Grand Contest 049
// URL: https://atcoder.jp/contests/agc049/tasks/agc049_e
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 55;
const ll mod = 1000000007;
ll n, m, K, a[maxn][maxn], lsh[maxn * maxn], tot, pw[maxn], f[maxn][maxn];
inline void upd(ll &x, ll y) {
x += y;
(x >= mod) && (x -= mod);
}
inline ll calc(ll x) {
mems(f, 0);
ll cnt = 0;
for (int i = 1; i <= K; ++i) {
cnt += (a[1][i] >= x);
}
f[1][0] = K - cnt;
f[1][2] = cnt;
for (int i = 2; i <= n; ++i) {
cnt = 0;
for (int j = 1; j <= K; ++j) {
cnt += (a[i][j] >= x);
}
for (int j = 0; j <= m + 2; ++j) {
if (!f[i - 1][j]) {
continue;
}
int nj = j - (j > 0) - (j == m + 2);
upd(f[i][nj], f[i - 1][j] * (K - cnt) % mod);
upd(f[i][nj + 2], f[i - 1][j] * cnt % mod);
}
}
ll ans = n * pw[n] % mod;
for (int i = 1; i <= n; ++i) {
ans = (ans - f[i][m + 2] * pw[n - i] % mod + mod) % mod;
}
return ans;
}
void solve() {
scanf("%lld%lld%lld", &n, &m, &K);
pw[0] = 1;
for (int i = 1; i <= n; ++i) {
pw[i] = pw[i - 1] * K % mod;
}
ll ans = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= K; ++j) {
scanf("%lld", &a[i][j]);
lsh[++tot] = a[i][j];
ans = (ans + a[i][j]) % mod;
}
}
sort(lsh + 1, lsh + tot + 1);
tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
ans = ans * pw[n - 1] % mod;
for (int i = 1; i <= tot; ++i) {
ans = (ans - lsh[i] * (calc(lsh[i] + 1) - calc(lsh[i]) + mod) % mod + mod) % mod;
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}