C - Not So Consecutive
Problem Statement
You are given an integer $N$. An integer sequence $x=(x_1,x_2,\cdots,x_N)$ of length $N$ is called a good sequence if and only if the following conditions are satisfied:
- Each element of $x$ is an integer between $1$ and $N$, inclusive.
- For each integer $i$ ($1 \leq i \leq N$), there is no position in $x$ where $i$ appears $i+1$ or more times in a row.
You are given an integer sequence $A=(A_1,A_2,\cdots,A_N)$ of length $N$. Each element of $A$ is $-1$ or an integer between $1$ and $N$. Find the number, modulo $998244353$, of good sequences that can be obtained by replacing each $-1$ in $A$ with an integer between $1$ and $N$.
Constraints
- $1 \leq N \leq 5000$
- $A_i=-1$ or $1 \leq A_i \leq N$.
- All input values are integers.
Input
The input is given from Standard Input in the following format:
$N$
$A_1$ $A_2$ $\cdots$ $A_N$
Output
Print the answer.
Sample Input 1
2
-1 -1
Sample Output 1
3
You can obtain four sequences by replacing each $-1$ with an integer between $1$ and $2$.
$A=(1,1)$ is not a good sequence because $1$ appears twice in a row.
The other sequences $A=(1,2),(2,1),(2,2)$ are good.
Thus, the answer is $3$.
Sample Input 2
3
2 -1 2
Sample Output 2
2
Sample Input 3
4
-1 1 1 -1
Sample Output 3
0
Sample Input 4
20
9 -1 -1 -1 -1 -1 -1 -1 -1 -1 7 -1 -1 -1 19 4 -1 -1 -1 -1
Sample Output 4
128282166
解题思路
纯动态规划优化题,硬是从 $O(n^4)$ 优化到 $O(n^2)$ 甚至是 $O(n^2 \log{n})$,超有意思的说。
状态还是很容易想到的,定义 $f(i,j)$ 表示由前 $i$ 个数构成且第 $i$ 个数是 $j$ 的所有合法方案的数量。根据序列最后一段有多少个连续 $j$(假设有 $k$ 个),以及第 $i-k$ 个数是哪个数(假设是 $u$,需满足 $u \ne j$)进行状态划分,状态转移方程就是$$f(i,j) = \sum\limits_{k=1}^{\min \{ i, j \}}{\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}}$$
实际上这个状态转移方程是有问题的,因为默认了 $a_{i-1}, a_{i-2}, \ldots, a_{i- \min \{ i, j \}}$ 都是 $-1$ 的情况。考虑 $a_1 \sim a_{i-1}$,如果这些数中存在某些 $a_v \ne -1$ 且 $a_v \ne j$,不妨假设 $v$ 是这些数中的最大下标。如果不存在这样的 $v$,即该范围内的数均是 $-1$ 或 $j$,则令 $v = 0$,同时规定 $a_0 = 0$。分情况讨论,如果 $i-v \leq j$,那么很明显最后一段最多只能有 $i - v$ 个连续的 $j$,且第 $v$ 个数 $a_v$ 是固定的。否则连续一段 $j$ 的最大长度就是 $j$。另外如果存在 $a_{i-k} = j$ 的情况,跳过即可。因此正确的状态转移方程应该是
$$f(i,j) = \begin{cases}
\left( \sum\limits_{\begin{array}{c} k=1 \\ a_{i-k} \ne j \end{array}}^{i-v-1}{\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}} \right) + f(v, a_v), &i-v \leq j \\\\
\sum\limits_{\begin{array}{c} k=1 \\ a_{i-k} \ne j \end{array}}^{j}{\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}}, &\text{others}
\end{cases}$$
同时规定 $f(0,0) = 0$,这样对于序列前 $i$ 个数都是 $j$ 的状态可以从 $f(0,0)$ 转移得到。
容易知道整个 dp 的时间复杂度是 $O(n^4)$,不过一个很明显可以优化的地方是 $\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}$ 这部分。本质是累加所有第一维是 $i-k$ 的状态 $f(i-k, *)$,然后减去 $f(i-k, j)$,而 $f(i-k, *)$ 在之前就已经全部求出来了。所以定义 $s_i = \sum\limits_{k=1}^{n}{f(i, k)}$,那么 $\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}$ 就可以等价成 $s_{i-k} - f(i-k, j)$,而 $s_i$ 只需在计算完 $f(i, *)$ 时进行累加即可,这样时间复杂度就降到了 $O(n^3)$。
对应的状态转移方程如下:
$$f(i,j) = \begin{cases}
\left( \sum\limits_{k=1}^{i-v-1}{s_{i-k} - f(i-k,j)} \right) + f(v, a_v), &i-v \leq j \\\\
\sum\limits_{k=1}^{j}{s_{i-k} - f(i-k,j)}, &\text{others}
\end{cases}$$
对于 $a_{i-k} = j$ 的情况原本是要跳过的,但对于这种情况必然有 $s_{i-k} = f(i-k, j)$,这是因为 $a_{i-k}$ 是定值,$f(i-k, u) = 0, \, u \ne j$,因此 $s_{i-k} - f(i-k, j) = 0$,并没有影响。
先放出 TLE 代码,时间复杂度为 $O(n^3)$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 5010, mod = 998244353;
int a[N];
int f[N][N];
int s[N];
int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
}
f[0][0] = 1;
for (int i = 1; i <= n; i++) {
if (a[i] == -1) {
for (int j = 1; j <= n; j++) {
for (int k = 1; k <= j && k <= i; k++) {
if (a[i - k] != -1 && a[i - k] != j) {
f[i][j] = (f[i][j] + f[i - k][a[i - k]]) % mod;
break;
}
f[i][j] = ((LL)f[i][j] + s[i - k] - f[i - k][j] + mod) % mod;
}
}
}
else {
int j = a[i];
for (int k = 1; k <= j && k <= i; k++) {
if (a[i - k] != -1 && a[i - k] != j) {
f[i][j] = (f[i][j] + f[i - k][a[i - k]]) % mod;
break;
}
f[i][j] = ((LL)f[i][j] + s[i - k] - f[i - k][j] + mod) % mod;
}
}
for (int j = 1; j <= n; j++) {
s[i] = (s[i] + f[i][j]) % mod;
}
}
int ret = 0;
for (int i = 1; i <= n; i++) {
ret = (ret + f[n][i]) % mod;
}
printf("\n%d", ret);
return 0;
}
上述的代码是在枚举 $k$ 的过程中找到 $v$ 的。很明显如果还要优化的话那么就应该继续把求和符号去掉,求和的部分本质也是对第一维某个已求得区间的 $s_{i} - f(i,j)$ 进行累加,因此可以用前缀和进行优化。
定义 $S_i = \sum\limits_{k=1}^{i}{s_i}$,$g(i,j) = \sum\limits_{k=1}^{i}{f(k,j)}$。
那么 $\sum\limits_{k=1}^{i-v-1}{s_{i-k} - f(i-k,j)}$ 就等价于 $S_{i-1} - S_{v} - (g(i-1,j) - g(v,j))$。
同理 $\sum\limits_{k=1}^{j}{s_{i-k} - f(i-k,j)}$ 就等价于 $S_{i-1} - S_{i-j-1} - (g(i-1,j) - g(i-j-1,j))$。
状态转移方程变成了
$$f(i,j) = \begin{cases}
S_{i-1} - S_{v} - (g(i-1,j) - g(v,j)) + f(v, a_v), &i-v \leq j \\\\
S_{i-1} - S_{i-j-1} - (g(i-1,j) - g(i-j-1,j)), &\text{others}
\end{cases}$$
现在关键的问题对于每个状态 $f(i,j)$ 如何快速确定对应的 $v$。本质是在 $a_0 \sim a_{i-1}$ 中找到同时满足 $a_v \ne -1$ 且 $a_v \ne j$ 的最大下标 $v$,所以可以用 std::set<std::pair<int, int>>
来动态维护 $0 \sim n$ 每个值出现的最大下标,其中第一个关键字是下标,第二个关键字是值,按第一个关键字降序排序。另外开一个数组 $p$ 表示每个值对应的最大下标。
当枚举到 $j$ 时,查看 st.begin()->second
,如果不等于 $j$,则对应的 $v$ 就是 st.begin()->first
,否则就是 next(st.begin())->first
。
当枚举到的 $a_i$ 是一个定值,那么只需从 std::set
中删除原本的数对 $(p_{a_i}, a_i)$,并重新插入 $(i, a_i)$,同时更新 $p_{a_i} \gets i$。
AC 代码如下,时间复杂度为 $O(n^2 \log{n})$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 5010, mod = 998244353;
int a[N], p[N];
int f[N][N], g[N][N];
int s[N];
int main() {
int n;
scanf("%d", &n);
set<PII> st({{0, 0}});
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
st.insert({0, i});
}
f[0][0] = 1;
for (int i = 1; i <= n; i++) {
if (a[i] == -1) {
for (int j = 1; j <= n; j++) {
int x = -st.begin()->first, y = st.begin()->second;
if (st.begin()->second == j) x = -next(st.begin())->first, y = next(st.begin())->second;
if (x < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
else f[i][j] = ((LL)s[i - 1] - s[x] - g[i - 1][j] + g[x][j] + f[x][y]) % mod;
}
}
else {
int j = a[i];
int x = -st.begin()->first, y = st.begin()->second;
if (st.begin()->second == j) x = -next(st.begin())->first, y = next(st.begin())->second;
if (x < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
else f[i][j] = ((LL)s[i - 1] - s[x] - g[i - 1][j] + g[x][j] + f[x][y]) % mod;
st.erase({-p[j], j});
st.insert({-i, j});
p[j] = i;
}
s[i] = s[i - 1];
for (int j = 1; j <= n; j++) {
s[i] = (s[i] + f[i][j]) % mod;
g[i][j] = (g[i - 1][j] + f[i][j]) % mod;
}
}
int ret = 0;
for (int i = 1; i <= n; i++) {
ret = (ret + f[n][i]) % mod;
}
ret = (ret + mod) % mod;
printf("%d", ret);
return 0;
}
其实 $O(n^2 \log{n})$ 的复杂度已经可以过了,实际上还可以优化到 $O(n^2)$,如果有兴趣可以继续往下看。
上面的 $p_i$ 表示值 $i$ 的最大下标,可以反过来考虑,变成对于值不为 $i$ 的最大下标。那么对于 $f(i,j)$,对应的 $v$ 就直接等于 $p_j$。另外可以发现只有 $a_i \ne -1$ 的情况才需要更新 $p$ 数组,只需暴力枚举 $k$,令 $p_k = i, \, k \ne j$ 即可。
AC 代码如下,时间复杂度为 $O(n^2)$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 5010, mod = 998244353;
int a[N];
int f[N][N], g[N][N];
int s[N];
int p[N];
int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
}
f[0][0] = 1;
for (int i = 1; i <= n; i++) {
if (a[i] == -1) {
for (int j = 1; j <= n; j++) {
if (p[j] < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
else f[i][j] = ((LL)s[i - 1] - s[p[j]] - g[i - 1][j] + g[p[j]][j] + f[p[j]][a[p[j]]]) % mod;
}
}
else {
int j = a[i];
if (p[j] < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
else f[i][j] = ((LL)s[i - 1] - s[p[j]] - g[i - 1][j] + g[p[j]][j] + f[p[j]][a[p[j]]]) % mod;
for (int k = 1; k <= n; k++) {
if (k != j) p[k] = i;
}
}
s[i] = s[i - 1];
for (int j = 1; j <= n; j++) {
s[i] = (s[i] + f[i][j]) % mod;
g[i][j] = (g[i - 1][j] + f[i][j]) % mod;
}
}
int ret = 0;
for (int i = 1; i <= n; i++) {
ret = (ret + f[n][i]) % mod;
}
ret = (ret + mod) % mod;
printf("%d", ret);
return 0;
}
参考资料
Editorial - estie Programming Contest 2023 (AtCoder Regular Contest 169):https://atcoder.jp/contests/arc169/editorial/7911
AtCoder Regular Contest 169(A~D):https://zhuanlan.zhihu.com/p/671467218
标签:begin,limits,int,sum,ne,st,So,Consecutive From: https://www.cnblogs.com/onlyblues/p/17900946.html