被 QOJ1193 Ambiguous Encoding 撞了。
考虑直接 dp,设 \(f_{i, j}\) 为较长的串未被较短的串覆盖的部分是第 \(i\) 个字符串的长为 \(j\) 的后缀。转移考虑枚举接在较短的串后面是第 \(k\) 个串,然后讨论一下 \(j\) 和第 \(k\) 个字符串的大小关系就可以确定转移到哪。
发现转移成环,考虑建图用最短路转移。这里 \(|E| = n^2m, |V| = nm\),复杂度看似是 \(O(|E| \log |V|) = O(n^2m \log nm)\) 的。但是 Dijkstra 最短路复杂度中的 \(|E|\) 实际上是松弛次数,但是这题考虑最大边权 \(\le m\) 所以任意时刻队列距离的极差 \(\le m\),所以每个点松弛次数 \(\le m\),所以复杂度其实是 \(O(n^2m + nm^2 \log nm)\)。这可以解释为什么 QOJ1193 Ambiguous Encoding 跑得那么快。
code
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;
const int maxn = 55;
const int maxm = 2510;
int n, id[maxn][maxn], nt, f[maxm];
bool vis[maxm];
vector<pii> G[maxm];
string s[maxn];
struct node {
int u, d;
node(int a = 0, int b = 0) : u(a), d(b) {}
};
inline bool operator < (const node &a, const node &b) {
return a.d > b.d;
}
priority_queue<node> pq;
void solve() {
scanf("%d%*d", &n);
for (int i = 1; i <= n; ++i) {
cin >> s[i];
for (int j = 0; j < (int)s[i].size(); ++j) {
id[i][j] = ++nt;
}
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j < (int)s[i].size(); ++j) {
for (int k = 1; k <= n; ++k) {
int t = min(j, (int)s[k].size());
if (s[k].substr(0, t) == s[i].substr((int)s[i].size() - j, t)) {
if (t == (int)s[k].size()) {
G[id[i][j]].pb(id[i][j - t], t);
} else {
G[id[i][j]].pb(id[k][(int)s[k].size() - j], t);
}
}
}
}
}
mems(f, 0x3f);
set<string> S;
for (int i = 1; i <= n; ++i) {
S.insert(s[i]);
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j < (int)s[i].size(); ++j) {
string t = s[i].substr(0, j);
if (S.find(t) != S.end()) {
f[id[i][(int)s[i].size() - j]] = j;
pq.emplace(id[i][(int)s[i].size() - j], j);
}
}
}
while (pq.size()) {
int u = pq.top().u;
pq.pop();
if (vis[u]) {
continue;
}
vis[u] = 1;
for (pii p : G[u]) {
int v = p.fst, d = p.scd;
if (f[v] > f[u] + d) {
f[v] = f[u] + d;
if (!vis[v]) {
pq.emplace(v, f[v]);
}
}
}
}
int ans = 2e9;
for (int i = 1; i <= n; ++i) {
ans = min(ans, f[id[i][0]]);
}
printf("%d\n", ans > 1e9 ? -1 : ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}