首先考虑一个经典的套路:转 \(01\)。具体而言,我们考虑若值域是 \([0, 1]\) 怎么做。
发现可以很容易地判定一个 \(A\) 是否合法。设矩阵第 \(i\) 行的和为 \(r_i\),第 \(j\) 列的和为 \(c_j\),那么合法当且仅当 \(A\) 的 \(\{r_i\}\) 和 \(\{c_j\}\)(可重集)分别与 \(B\) 的 \(\{r_i\}\) 和 \(\{c_j\}\) 相同。并且 \(r_i, c_j\) 的每一种不同的排列方案都恰好对应一个可以被操作成 \(B\) 的 \(A\)。
那么值域为 \([0, 1]\) 时答案就是 \(\{r_i\}\) 和 \(\{c_j\}\) 的可重集排列数相乘。
考虑值域为 \([0, 9]\) 的情况,考虑枚举 \(k \in [0, 8]\),把 \(\le k\) 的值赋成 \(0\),\(> k\) 赋成 \(1\)。那么 \(A\) 合法等价于,对于每个 \(k \in [0, 8]\),都存在两个排列 \(p_k(i), q_k(j)\),使得 \(A_{i, j} \le k \Longleftrightarrow B_{p_k(i), q_k(j)} \le k\)。那么一堆排列 \((p_0, q_0, p_1, q_1, \ldots, p_8, q_8)\) 可以唯一确定一个 \(A\),但是因为 \(\{r_i\}, \{c_j\}\) 是可重集,所以一个 \(A\) 实际上会对应 \(\prod\limits_{k = 0}^m (\sum\limits_{i = 1}^n [r_i = k])! \times \prod\limits_{k = 0}^n (\sum\limits_{i = 1}^m [c_i = k])!\) 堆排列。最后除一下即可。
于是考虑对这堆排列 \((p_0, q_0, p_1, q_1, \ldots, p_8, q_8)\) 计数。条件 \(A_{i, j} \le k \Longleftrightarrow B_{p_k(i), q_k(j)} \le k\) 里面有 \(A\),不妨把 \(A\) 扔掉,根据 \(A_{i, j} \le k \Longrightarrow A_{i, j} \le k + 1\) 有 \(B_{p_k(i), q_k(j)} \le k \Longrightarrow B_{p_{k + 1}(i), q_{k + 1}(j)} \le k + 1\)。
发现我们实际上只关心 \(p_{k + 1} \circ p_k^{-1}\) 和 \(q_{k + 1} \circ q_k^{-1}\)。于是条件可以被改写成 \(B_{i, j} \le k \Longrightarrow B_{p_k(i), q_k(j)} \le k + 1\)。
考察 \([B_{i, j} \le k]\) 的杨表结构。设 \(a_i = \sum\limits_{j = 1}^m [B_{i, j} \le k], b_j = \sum\limits_{i = 1}^n [B_{i, j} \le k + 1]\),可以发现 \(a, b\) 单调不升。若把 \(B\) 旋转 \(180°\),那么条件可以转化为 \(j \le a_i \Longrightarrow p_k(i) \le b_{q_k(j)}\)。更进一步地,因为 \(b\) 单调不升,所以 \(p_k(i) \le b_{\max\limits_{j = 1}^{a_i} q_k(j)}\)。
然后可以 dp 计数了。设 \(f_{i, j}\) 为 \(\max\limits_{o = 1}^i q_k(o) = j\) 的方案数。我们 \(p_k, q_k\) 的方案数分别统计。\(f_{i - 1} \to f_i\) 时,计算 \(q_k(i)\) 的方案数,有 \(f_{i, j} \gets (j - i + 1) f_{i - 1, j} + \sum\limits_{k = 1}^{j - 1} f_{i - 1, k}\),分别表示 \(q_k(i)\) 是或不是最大值。然后我们再维护一个指针 \(t\) 从后往前扫,当 \(a_t = j\) 时,计算 \(p_k(t)\) 的方案,相当于 \(p_k(t)\) 有一个 \(b_j\) 的上界,于是有 \(f_{i, j} \gets f_{i, j} \times (b_j - t + 1)\),因为 \(p_k(t)\) 的上界随 \(t\) 减小而减小。
每次的 \(f_{m, m}\) 相乘,然后再除一下上面的 \(\prod\limits_{k = 0}^m (\sum\limits_{i = 1}^n [r_i = k])! \times \prod\limits_{k = 0}^n (\sum\limits_{i = 1}^m [c_i = k])!\) 就是最终答案。
时间复杂度 \(O(Vnm)\)。
code
// Problem: E - RowCol/ColRow Sort
// Contest: AtCoder - AtCoder Grand Contest 057
// URL: https://atcoder.jp/contests/agc057/tasks/agc057_e
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 1510;
const ll mod = 998244353;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, a[maxn][maxn], b[19][maxn], c[19][maxn], fac[maxn], ifac[maxn], f[maxn][maxn], d[maxn];
void solve() {
scanf("%lld%lld", &n, &m);
fac[0] = 1;
for (int i = 1; i <= max(n, m); ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[max(n, m)] = qpow(fac[max(n, m)], mod - 2);
for (int i = max(n, m) - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%lld", &a[i][j]);
++b[a[i][j]][i];
++c[a[i][j]][j];
}
}
for (int k = 1; k <= 9; ++k) {
for (int i = 1; i <= n; ++i) {
b[k][i] += b[k - 1][i];
}
for (int i = 1; i <= m; ++i) {
c[k][i] += c[k - 1][i];
}
}
ll ans = 1;
for (int k = 0; k <= 8; ++k) {
mems(f, 0);
f[0][0] = 1;
int p = n;
while (p && b[k][p] == 0) {
ans = ans * (n - p + 1) % mod;
--p;
}
for (int i = 1; i <= m; ++i) {
ll s = 0;
for (int j = 1; j <= m; ++j) {
s = (s + f[i - 1][j - 1]) % mod;
f[i][j] = (f[i - 1][j] * (j - i + 1) % mod + s) % mod;
}
while (p && b[k][p] == i) {
for (int j = 1; j <= m; ++j) {
f[i][j] = f[i][j] * (c[k + 1][j] - p + 1) % mod;
}
--p;
}
}
ans = ans * f[m][m] % mod;
mems(d, 0);
for (int i = 1; i <= n; ++i) {
++d[b[k][i]];
}
for (int i = 0; i <= m; ++i) {
ans = ans * ifac[d[i]] % mod;
}
mems(d, 0);
for (int i = 1; i <= m; ++i) {
++d[c[k][i]];
}
for (int i = 0; i <= n; ++i) {
ans = ans * ifac[d[i]] % mod;
}
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}