记 \(b_i = \sum\limits_{j = 1}^m a_{i, j}, c_j = \sum\limits_{i = 1}^n a_{i, j}\)。
首先考虑这样一个事情,就是对于 \(b_i \le 0\) 的行有没有可能被选。如果选了它:
- 如果没有选任何列,选这一行肯定不优;
- 如果选了若干列,根据题目的要求,这若干列与这一行重叠的部分只可能是非负数。考虑看成是先选列再选行,那选这一行必然会造成负贡献,雪上加霜。
所以我们得到只选 \(b_i > 0\) 的行一定不劣。类似地,只选 \(c_j > 0\) 的列也一定不劣。
这种行列网格的题,可以考虑转化成二分图一类的问题。考虑直接硬冲一个最大费用最大流。下面设 \((u, v, x, y)\) 表示一条 \(u \to v\),容量 \(x\),费用 \(y\) 的边。行作为左部点,列作为右部点。
- 对于 \(S\) 到左部点的边,若 \(b_i \le 0\) 直接不管,否则我们希望有流量就产生 \(b_i\) 的贡献,而不管流量多少。因此连边 \((S, i, 1, b_i), (S, i, +\infty, 0)\)。
- 对于一个匹配,可以看成是这一行跟这一列都选。那么重叠的部分要减掉,即连边 \((i, j, 1, -a_{i, j})\)。
- 对于右部点到 \(T\) 的边,类似地,连边 \((j, T, 1, c_j), (j, T, +\infty, 0)\)。
最后答案就是最大费用。
但是理论复杂度是 \(O(n^5)\) 的,实际也会 T 得很惨。
类似 ABC214H 地考虑,不妨换种思路,计算最少损失(初始的答案是 \(\sum\limits_{i = 1}^n b_i + \sum\limits_{j = 1}^m c_j\))。考虑转化成最小割,这样建图(用 \((u, v, x)\) 表示一条 \(u \to v\),容量为 \(x\) 的边):
- 对于 \(b_i > 0\) 的行,连边 \((S, i, b_i)\);
- 对于 \(c_j > 0\) 的列,连边 \((j, T, c_j)\);
- 对于 \(a_{i, j} \ge 0\),连边 \((i, j, a_{i, j})\);
- 对于 \(a_{i, j} < 0\),连边 \((i, j, +\infty)\)。
这样保证了对于任意 \(S \to i \to j \to T\) 的路径:
- 要么割掉 \(S \to i\) 或 \(j \to T\) 的边,表示它们不同时被选;
- 要么割掉 \(i \to j\) 的边,表示它们同时被选,但是因为有重叠,所以产生了 \(-a_{i, j}\) 的负贡献;
- 如果 \(a_{i, j} < 0\),那么也能保证 \(i, j\) 不同时被选,因为 \(i \to j\) 的边权是 \(+\infty\),表示它不能被割,只能割 \(S \to i\) 或 \(j \to T\) 的边。
最后答案就是 \(\sum\limits_{i = 1}^n b_i + \sum\limits_{j = 1}^m c_j - \text{mincut} = \sum\limits_{i = 1}^n b_i + \sum\limits_{j = 1}^m c_j - \text{maxflow}\)。
因为这是二分图,所以时间复杂度是 \(O(n^{2.5})\)。
code
// Problem: G - Grid Card Game
// Contest: AtCoder - AtCoder Beginner Contest 259
// URL: https://atcoder.jp/contests/abc259/tasks/abc259_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 110;
const int maxm = 1000100;
const ll inf = 0x3f3f3f3f3f3f3f3fLL;
ll n, m, a[maxn][maxn], head[maxm], len = 1, ntot, S, T, b[maxn], c[maxn], id[maxn][2];
struct edge {
ll to, next, cap, flow;
} edges[maxm];
inline void add_edge(ll u, ll v, ll c, ll f) {
edges[++len].to = v;
edges[len].next = head[u];
edges[len].cap = c;
edges[len].flow = f;
head[u] = len;
}
struct Dinic {
ll d[maxm], cur[maxm];
bool vis[maxm];
inline void add(ll u, ll v, ll c) {
add_edge(u, v, c, 0);
add_edge(v, u, 0, 0);
}
bool bfs() {
for (int i = 1; i <= ntot; ++i) {
d[i] = -1;
vis[i] = 0;
}
queue<int> q;
q.push(S);
d[S] = 0;
vis[S] = 1;
while (q.size()) {
int u = q.front();
q.pop();
for (int i = head[u]; i; i = edges[i].next) {
edge &e = edges[i];
if (e.cap > e.flow && !vis[e.to]) {
vis[e.to] = 1;
d[e.to] = d[u] + 1;
q.push(e.to);
}
}
}
return vis[T];
}
ll dfs(ll u, ll a) {
if (u == T || !a) {
return a;
}
ll flow = 0, f;
for (ll &i = cur[u]; i; i = edges[i].next) {
edge &e = edges[i];
if (d[e.to] == d[u] + 1 && (f = dfs(e.to, min(a, e.cap - e.flow))) > 0) {
e.flow += f;
edges[i ^ 1].flow -= f;
flow += f;
a -= f;
if (!a) {
break;
}
}
}
return flow;
}
ll solve() {
ll flow = 0;
while (bfs()) {
for (int i = 1; i <= ntot; ++i) {
cur[i] = head[i];
}
ll t = dfs(S, inf);
flow += t;
}
return flow;
}
} solver;
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%lld", &a[i][j]);
b[i] += a[i][j];
c[j] += a[i][j];
}
}
ll s = 0;
for (int i = 1; i <= n; ++i) {
if (b[i] > 0) {
s += b[i];
}
}
for (int i = 1; i <= m; ++i) {
if (c[i] > 0) {
s += c[i];
}
}
S = ++ntot;
T = ++ntot;
for (int i = 1; i <= n; ++i) {
id[i][0] = ++ntot;
}
for (int i = 1; i <= m; ++i) {
id[i][1] = ++ntot;
}
for (int i = 1; i <= n; ++i) {
if (b[i] > 0) {
solver.add(S, id[i][0], b[i]);
}
}
for (int i = 1; i <= m; ++i) {
if (c[i] > 0) {
solver.add(id[i][1], T, c[i]);
}
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
if (a[i][j] >= 0) {
solver.add(id[i][0], id[j][1], a[i][j]);
} else {
solver.add(id[i][0], id[j][1], inf);
}
}
}
printf("%lld\n", s - solver.solve());
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}