前言
题意简述
给你一个 \(n \times m\) 的矩形 \(a\),其中已经有 \(q\) 个位置填上了数,你需要为剩下的位置分别填上一个非负整数,使得最终任意一个 \(2 \times 2\) 的子矩形内,左上角的数加右下角的数等于右上角的数加左下角的数,即 \(\forall i\in[1,n), j\in[1,m),\ a_{i,j}+a_{i+1,j+1}=a_{i,j+1}+a_{i+1,j}\)。
\(n, m, q \leq 10^5\)。
题目分析
第一步肯定是找性质,我们从 \(n = 2\) 开始观察。发现合法的方案满足 \(\forall j\in[1,m],\ a_{1,j}-a_{2,j}=k\),其中 \(k\) 为一常数。发现这是因为由 \(a_{i,j}+a_{i+1,j+1}=a_{i,j+1}+a_{i+1,j}\) 得到 \(a_{i,j}-a_{i+1,j}=a_{i,j+1}-a_{i+1,j+1}\),这 \(m-1\) 个等式可以等起来,也就得到 \(a_{1,j}-a_{2,j}\) 为一定值。
对其推广,对于 \(n \times m\) 的矩形,对于所有 \(i\in[1,n)\),满足 \((i, i+1)\) 两行对应位置值之差为一定值,说明合法。(对于列有相同结论,但是任意一种都是合法的充要条件,二者可以互推,所以这里不妨仅考虑行。)
我们发现,还可以进一步拓展结论:由于 \((i,i+1), (i+1,i+2)\) 满足结论,那么 \((1, i+2)\) 也满足结论,进一步,也就是对于任意两行 \((i, j)\),它们对应位置值之差是定值。
我们想到做前缀和后,变为差分约束搞,但是显然不太对。我们想到可以使用带权并查集来维护,这种套路在带权并查集是常见的,不做赘述。可以对于每一列,依次把相邻的两个有值的位置所在的行进行合并(这里相邻指的是中间没有其他有值的位置)。
除了合并的时候无解,我们注意到题目中还有非负整数的限制。我们对于每一列的每一个有值的位置,可以通过并查集中的边,推断出和它处在同一个联通块内,所有行对应位置的值,我们仅需要保证这些能被唯一确定的点值均非负即可,剩下的位置是不确定的。那么把每个连通块路径压缩成一个菊花后,我们只需要维护并查集联通块内,到父亲边权的最值即可,这个可以预处理。具体判断逻辑见代码。
于是本题做完了,时间复杂度带一个并查集的小常数。如果使用搜索可以避免,但是会进入复杂度陷阱,拥有一个大常数。可以对于每一列新建一个虚拟点,这样可以减少码量和常数,建议在理解基础算法后再尝试理解,两种代码都给出,供读者学习。
代码
普通实现
#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
#define BAD "No"
#define OK "Yes"
const int N = 1e5 + 10;
using lint = long long;
int n, m, q;
vector<pair<int, int>> vec[N];
int fa[N];
lint d[N], mi[N];
int get(int x) {
if (x == fa[x]) return x;
int f = get(fa[x]);
d[x] += d[fa[x]];
return fa[x] = f;
}
inline void merge(int x, int y, lint v) {
int a = get(x), b = get(y);
if (a == b) return d[x] - d[y] != v && (puts(BAD), exit(0), 0), void();
d[a] = d[y] - d[x] + v, fa[a] = b;
}
signed main() {
#ifndef XuYueming
freopen("zibi.in", "r", stdin);
freopen("zibi.out", "w", stdout);
#endif
scanf("%d%d%d", &n, &m, &q);
for (int x, y, v; q--; ) {
scanf("%d%d%d", &x, &y, &v);
vec[y].emplace_back(x, v);
}
for (int i = 1; i <= n; ++i) fa[i] = i, mi[i] = 0x3f3f3f3f3f3f3f3f;
for (int i = 1; i <= m; ++i) {
sort(vec[i].begin(), vec[i].end());
for (int lst = 0, lv = 0; auto [x, v] : vec[i]) {
if (lst) merge(x, lst, v - lv);
lst = x, lv = v;
}
}
for (int i = 1; i <= n; ++i) get(i), mi[fa[i]] = min(mi[fa[i]], d[i]);
for (int i = 1; i <= m; ++i) {
for (auto [x, v] : vec[i]) {
if (-d[x] + v + mi[fa[x]] < 0) return puts(BAD), 0;
}
}
puts(OK);
return 0;
}
虚拟点实现
#include <cstdio>
#include <iostream>
using namespace std;
#define BAD "No"
#define OK "Yes"
const int N = 1e5 + 10;
using lint = long long;
int n, m, q;
int fa[N << 1];
lint d[N << 1], mi[N << 1];
int get(int x) {
if (x == fa[x]) return x;
int f = get(fa[x]);
d[x] += d[fa[x]];
return fa[x] = f;
}
inline void merge(int x, int y, lint v) {
int a = get(x), b = get(y);
if (a > b) swap(a, b), swap(x, y), v = -v;
if (a == b) return d[x] - d[y] != v && (puts(BAD), exit(0), 0), void();
d[a] = d[y] - d[x] + v, fa[a] = b;
}
signed main() {
#ifndef XuYueming
freopen("zibi.in", "r", stdin);
freopen("zibi.out", "w", stdout);
#endif
scanf("%d%d%d", &n, &m, &q);
for (int i = 1; i <= n + m; ++i) fa[i] = i, mi[i] = 0x3f3f3f3f3f3f3f3f;
for (int x, y, v; q--; ) {
scanf("%d%d%d", &x, &y, &v);
merge(y + n, x, -v);
}
for (int i = 1; i <= n + m; ++i) get(i);
for (int i = 1; i <= n; ++i) mi[fa[i]] = min(mi[fa[i]], d[i]);
for (int i = 1; i <= m; ++i)
if (-d[i + n] + mi[fa[i + n]] < 0) return puts(BAD), 0;
puts(OK);
return 0;
}
卡常后最优解代码
#include <cstdio>
#include <cstdlib>
using namespace std;
const int MAX = 1 << 26;
char buf[MAX], *inp = buf;
template <typename T>
inline void read(T &x) {
x = 0; char ch = *inp++;
for (; ch < 48; ch = *inp++);
for (; ch >= 48; ch = *inp++) x = (x << 3) + (x << 1) + (ch ^ 48);
}
inline void swap(int &a, int &b) { a ^= b ^= a ^= b; }
#define BAD "No"
#define OK "Yes"
const int N = 1e5 + 10;
using lint = long long;
int n, m, q;
int fa[N << 1];
lint d[N << 1], mi[N << 1];
int get(int x) {
return x == fa[x] ? x : (get(fa[x]), d[x] += d[fa[x]], fa[x] = fa[fa[x]]);
}
inline void merge(int x, int y, lint v) {
int a = get(x), b = get(y);
if (a > b) swap(a, b), swap(x, y), v = -v;
if (a == b) return d[x] - d[y] != v && (puts(BAD), exit(0), 0), void();
d[a] = d[y] - d[x] + v, fa[a] = b;
}
signed main() {
fread(buf, 1, MAX, stdin), read(n), read(m), read(q);
for (int i = 1; i <= n + m; ++i) fa[i] = i, mi[i] = 0x3f3f3f3f3f3f3f3f;
for (int x, y, v; q--; ) read(x), read(y), read(v), merge(y + n, x, -v);
for (int i = 1; i <= n + m; ++i) get(i);
for (int i = 1; i <= n; ++i) mi[fa[i]] > d[i] && (mi[fa[i]] = d[i]);
for (int i = 1; i <= m; ++i) if (-d[i + n] + mi[fa[i + n]] < 0) return puts(BAD), 0;
puts(OK);
return 0;
}