发现一个性质:能跳不超过 \(j\) 步到达 \(i\) 的所有点形成一段区间。设这这段区间为 \([L_{i, j}, R_{i, j}]\)。
那么答案即为:
\[\sum\limits_{i = 1}^n \sum\limits_{j = 0} n - R_{i, j} + L_{i, j} - 1 \]并且:
\[[L_{i, j}, R_{i, j}] = \bigcup\limits_{k \in [L_{i, j - 1}, R_{i, j - 1}]} [L_{k, 1}, R_{k, 1}] \]\(L_{i, 1}\) 和 \(R_{i, 1}\) 分别是 \(i\) 左侧和右侧第一个字符和 \(s_i\) 相等的位置,\(L_{i, 0} = R_{i, 0} = i\)。
现在我们想优化暴力跳的过程。一个自然的想法是倍增。但是跳的过程涉及到 \(L, R\) 两个量的变化,它们不互相独立,不太好维护。
但是又发现,最右边(或最左边)的一些字符对 \(L\)(或 \(R\))的转移没用。比如 \(\texttt{abcab}\),最右边的 \(\texttt{ab}\) 由于前面出现过所以不会对 \(L_{i, j}\) 有贡献。
也就是说 \(L_{i, j} = \min\limits_{k \in [L_{i, j - 1}, T]}\),其中 \(T\) 是最小的满足 \([L_{i, j - 1}, T]\) 和 \([L_{i, j - 1}, R_{i, j - 1}]\) 的字符集相等的位置。所以如果固定字符集大小和 \(L_{i, j - 1}\),那么 \(L_{i, j}\) 能直接不依赖 \(R_{i, j - 1}\) 计算出来。对于 \(R\) 也类似。
字符集只会变化 \(O(|\Sigma|)\) 次。于是可以枚举字符集大小 \(k\),然后同时维护 \(4\) 个倍增数组:
- \(fl_{i, c}\) 表示对于一个 \(L_{t, j - 1} = i\),它在一开始字符集大小为 \(k\) 的情况下,跳 \(2^c\) 步后的位置。
- \(gl_{i, c}\) 表示 \(L\) 跳 \(2^c\) 步经过的所有点的和。
- \(fr_{i, c}\) 表示对于一个 \(R_{t, j - 1} = i\),它在一开始字符集大小为 \(k\) 的情况下,跳 \(2^c\) 步后的位置。
- \(gr_{i, c}\) 表示 \(R\) 跳 \(2^c\) 步经过的所有点的和。
倍增跳到最大的 \(L\) 和 \(R\) 满足 \([L, R]\) 中的字符集大小为 \(k\) 即可。
时间复杂度 \(O(n |\Sigma| \log n)\)。
code
// Problem: F. Cursor Distance
// Contest: Codeforces - Codeforces Round 596 (Div. 1, based on Technocup 2020 Elimination Round 2)
// URL: https://codeforces.com/problemset/problem/1246/F
// Memory Limit: 512 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 100100;
int n, pre[maxn][26], nxt[maxn][26];
int L[maxn], R[maxn], pr[maxn], nx[maxn], fl[20][maxn], fr[20][maxn];
ll gl[20][maxn], gr[20][maxn];
char s[maxn];
inline bool check(int l, int r, int k) {
return k == 26 || nxt[l][k] > r;
}
void solve() {
scanf("%s", s + 1);
n = strlen(s + 1);
for (int i = 1; i <= n; ++i) {
L[i] = R[i] = i;
}
for (int i = 0; i < 26; ++i) {
pre[0][i] = 1;
nxt[n + 1][i] = n;
}
for (int i = 1; i <= n; ++i) {
for (int j = 0; j < 26; ++j) {
pre[i][j] = pre[i - 1][j];
}
pre[i][s[i] - 'a'] = i;
}
for (int i = n; i; --i) {
for (int j = 0; j < 26; ++j) {
nxt[i][j] = nxt[i + 1][j];
}
nxt[i][s[i] - 'a'] = i;
}
for (int i = 1; i <= n; ++i) {
fl[0][i] = pr[i] = pre[i - 1][s[i] - 'a'];
fr[0][i] = nx[i] = nxt[i + 1][s[i] - 'a'];
L[i] = R[i] = i;
}
for (int i = 1; i <= n; ++i) {
sort(pre[i], pre[i] + 26, greater<int>());
sort(nxt[i], nxt[i] + 26);
}
ll ans = 0;
for (int k = 1; k <= 26; ++k) {
for (int i = 1; i <= n; ++i) {
fl[0][i] = gl[0][i] = min(fl[0][i], pr[nxt[i][k - 1]]);
fr[0][i] = gr[0][i] = max(fr[0][i], nx[pre[i][k - 1]]);
}
for (int j = 1; j <= 18; ++j) {
for (int i = 1; i <= n; ++i) {
fl[j][i] = fl[j - 1][fl[j - 1][i]];
fr[j][i] = fr[j - 1][fr[j - 1][i]];
gl[j][i] = gl[j - 1][i] + gl[j - 1][fl[j - 1][i]];
gr[j][i] = gr[j - 1][i] + gr[j - 1][fr[j - 1][i]];
}
}
for (int i = 1; i <= n; ++i) {
if (check(L[i], R[i], k)) {
ans += n - 1 - R[i] + L[i];
for (int j = 18; ~j; --j) {
int l = fl[j][L[i]], r = fr[j][R[i]];
if (check(l, r, k)) {
ans += (1LL << j) * (n - 1) - gr[j][R[i]] + gl[j][L[i]];
L[i] = l;
R[i] = r;
}
}
L[i] = fl[0][L[i]];
R[i] = fr[0][R[i]];
}
}
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}