很厉害的题!
考虑所有车站已确定,如何求 \(0\) 到 \(n+1\) 的最短路。设 \(g_{i,0}\) 为只考虑 \(0 \sim i\) 的点,到 \(i\) 和它左边第一个 \(\text{A}\) 的最短路,\(g_{i,1}\) 同理。有转移:
- 若 \(s_{i-1} = \text{A}, s_i = \text{A}, g_{i,0} \gets g_{i-1,0} + 1\)
- 若 \(s_{i-1} = \text{A}, s_i = \text{B}, g_{i,0} \gets \min(g_{i-1,0}, g_{i-1,1} + 2)\)
- 若 \(s_{i-1} = \text{B}, s_i = \text{A}, g_{i,0} \gets \min(g_{i-1,0} + 1, g_{i-1,1} + 1)\)
- 若 \(s_{i-1} = \text{B}, s_i = \text{A}, g_{i,0} \gets g_{i-1,0}\)
\(g_{i,1}\) 的转移是对称的。
设 \(f_{i,x,y,0/1}\) 表示当前考虑了 \(0 \sim i\) 的车站,\(g_{i,0} = x, g_{i,1} = y\),\(s_i\) 为 \(\text{A}\) 或 \(\text{B}\) 的方案数。这是 \(O(n^3)\) 的。
考虑压状态。显然遇到 \(\text{ABB...B}\) 时 \(x,y\) 相差就会很大。但是要到达最后一个 \(\text{B}\),可以先跳一步 \(\text{A}\) 再往回走。这是我们的最终目标,我们不关心途中的最短路数值究竟是什么。因此可以做出如下优化:当 \(x \ge y + 2\) 时,强制让 \(x \gets y + 2\),\(y\) 同理。这样不会影响最终答案。这样是 \(O(n^2)\) 的。
实现时用 unordered_map
记录所有状态有效。
code
// Problem: F - AtCoder Express 3
// Contest: AtCoder - AtCoder Regular Contest 119
// URL: https://atcoder.jp/contests/arc119/tasks/arc119_f
// Memory Limit: 1024 MB
// Time Limit: 4000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;
const int maxn = 4010;
const int mod = 1000000007;
int n, m;
char s[maxn];
unordered_map<int, int> f[2][maxn][2];
inline void upd(int o, int x, int y, int k, int val) {
x = min(x, y + 2);
y = min(y, x + 2);
int &p = f[o][x][k][y];
p += val;
(p >= mod) && (p -= mod);
}
void solve() {
scanf("%d%d%s", &n, &m, s + 1);
--n;
if (s[1] != 'B') {
f[1][1][0][0] = 1;
}
if (s[1] != 'A') {
f[1][0][1][1] = 1;
}
for (int i = 2, o = 0; i <= n; ++i, o ^= 1) {
for (int x = 0; x <= n + 2; ++x) {
for (int k = 0; k <= 1; ++k) {
f[o][x][k].clear();
}
}
for (int x = 0; x <= n + 2; ++x) {
for (pii p : f[o ^ 1][x][0]) {
int y = p.fst, val = p.scd;
if (s[i] != 'B') {
upd(o, x + 1, y, 0, val);
}
if (s[i] != 'A') {
upd(o, min(x, y + 2), min(x + 1, y + 1), 1, val);
}
}
for (pii p : f[o ^ 1][x][1]) {
int y = p.fst, val = p.scd;
if (s[i] != 'B') {
upd(o, min(x + 1, y + 1), min(x + 2, y), 0, val);
}
if (s[i] != 'A') {
upd(o, x, y + 1, 1, val);
}
}
}
}
int ans = 0;
for (int x = 0; x <= n; ++x) {
for (int k = 0; k <= 1; ++k) {
for (pii p : f[n & 1][x][k]) {
int y = p.fst, val = p.scd;
if (min(x, y) + 1 <= m) {
ans = (ans + val) % mod;
}
}
}
}
printf("%d\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}