给定一个 \(n \times m\),值域 \([0,9]\) 的矩阵 \(B\),计数有多少个大小相同的矩阵 \(A\) 满足下列条件:
- 分别对 \(A\) 的每一列中元素从小到大排序,再分别对 \(A\) 的每一行中元素从小到大排序能够得到 \(B\)。
- 分别对 \(A\) 的每一行中元素从小到大排序,再分别对 \(A\) 的每一列中元素从小到大排序能够得到 \(B\)。
\(1\le n,m\le 1500\),答案对 \(998244353\) 取模。
先考虑只有 \(01\) 怎么做。手玩一下,容易发现判定 \(A\) 是否合法只需要判定行和列 \(1\) 出现次数的可重集和 \(B\) 是否相等。答案就是可重集排列,对应的方案是唯一的。值域更大的时候考虑枚举一下,\(A\) 合法的充要条件是对于每个 \(0 \le k \le 9\),\(≤ k\) 的部分合法。容易知道这样算也是对的。
改写一下可重集的限制,它等价于:对于每个 \(k\),存在排列 \(p(k,i)\) 和 排列 \(q(k,i)\) 满足 \(A_{i,j}≤k \Leftrightarrow B_{p(k,i),q(k,j)}≤k\)。并且显然这样的一组排列和矩阵 \(A\) 是一一对应的。
\(k\) 固定的情况就相当于只有 \(01\)。现在问题是怎么把这一些排列拼在一起,相当于我们要数有多少个排列对的序列 \((p(1),q(1))...(p(k),q(k))\) 满足对任意 \(k\),都有:\(B_{p(k,i),q(k,j)}≤k \Leftrightarrow B_{p(k+1,i),q(k+1,j)}≤k+1\)。
考虑现在已经确定了前 \(k\) 对排列,现在要确定第 \(k+1\) 对。我们注意到,这里的方案数事实上和前 \(k\) 对排列是什么没有关系,因为同时做置换之后是一一对应,因此我们不妨设 \(p(k-1,i)=q(k-1,i) = i\),这样每个 \(k\) 就独立了。对于每个 \(k\),我们要求有是多少个排列 \(p,q\) 满足:\(B_{i,j}≤k \Rightarrow B_{p(i),q(j)}≤k+1\)。
画画图,我们可以进一步转化条件,令 \(a_i\) 表示有多少 \(j\) 满足 \(B_{i,j}≤k\),\(b_j\) 表示有多少 \(i\) 满足 \(B_{i,j}≤k+1\),则问题转化为给定两个不升的序列 \(a,b\),有多少个排列 \(p,q\) 满足:\(j≤a_i \Rightarrow p(i) ≤ b_{q(j)}\)。更进一步地,由 \(b\) 不升,式子还可以改写为 \(p(i)≤b_{\max(q(a_1)⋯q(a_i))}\)。
但现在想要计数似乎还是不太容易,式子里既有 \(p(i)\) 又有 \(q(i)\),很烦。考虑拆一下。
我们按 \(x = \max(a_1⋯a_i)\) 从后往前 DP,设 \(f(i,j)\) 表示 \(\max(q(1)...q(i))=j\) 的方案数,\(q_i\) 的方案数可以在转移时维护:\(f(i,j) \gets (j-i+1) \times f(i-1,j) + \sum_{k < j} f(i-1,k)\),后半部分可以前缀和算。
再维护一个指针 \(t\),容易在 \(a_t=i\) 时统计 \(p(i)\) 的方案数,有转移:\(f(i,j) \gets (b_j - i + 1) \times f(i,j)\)。
最后注意一下可重集会多算,除掉即可。总时间复杂度 \(\mathcal{O}(Vm(n+m))\)。
code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
constexpr int N = 2e3 + 5, V = 10, mod = 998244353;
bool Mbe;
int ksm(int a, int b) {
int ret = 1;
for (; b; b >>= 1, a = 1LL * a * a % mod) if (b & 1) ret = 1LL * ret * a % mod;
return ret;
}
int n, m, a[N][N], r[V][N], c[V][N];
int fc[N], ifc[N], f[N][N];
bool Med;
int main() {
// fprintf(stderr, "%.3lf MB\n", (&Med - &Mbe) / 1048576.0);
ios :: sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m;
int M = max(n, m);
fc[0] = 1;
for (int i = 1; i <= M; i++) fc[i] = 1LL * fc[i - 1] * i % mod;
ifc[M] = ksm(fc[M], mod - 2);
for (int i = M; i >= 1; i--) ifc[i - 1] = 1LL * ifc[i] * i % mod;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
cin >> a[i][j];
r[a[i][j]][i] += 1;
c[a[i][j]][j] += 1;
}
}
int ans = 1;
for (int k = 0; k + 1 < V; k++) {
for (int i = 1; i <= n; i++) r[k + 1][i] += r[k][i];
for (int j = 1; j <= m; j++) c[k + 1][j] += c[k][j];
memset(f, 0, sizeof(f));
f[0][0] = 1;
int t = n;
for (; t && r[k][t] == 0; t--);
for (int i = 1; i <= m; i++) {
for (int j = 1; j <= m; j++) f[i][j] = (f[i][j - 1] + f[i - 1][j - 1]) % mod;
for (int j = 1; j <= m; j++) f[i][j] = (f[i][j] + 1LL * (j - i + 1) * f[i - 1][j] % mod) % mod;
for (; t && r[k][t] == i; t--) {
for (int j = 1; j <= m; j++) {
f[i][j] = 1LL * f[i][j] * (c[k + 1][j] - t + 1) % mod;
}
}
}
ans = 1LL * ans * f[m][m] % mod;
static int cnt[N];
memset(cnt, 0, sizeof(cnt));
for (int i = 1; i <= n; i++) cnt[r[k][i]] += 1;
for (int j = 1; j <= m; j++) ans = 1LL * ans * ifc[cnt[j]] % mod;
memset(cnt, 0, sizeof(cnt));
for (int j = 1; j <= m; j++) cnt[c[k][j]] += 1;
for (int i = 0; i <= n; i++) ans = 1LL * ans * ifc[cnt[i]] % mod;
}
cout << ans << "\n";
return 0;
}