【模板】后缀平衡树
题目描述
给你一个字符串 init
,要求你支持三个操作:
-
在当前字符串的后面插入若干个字符。
-
在当前字符串的后面删除若干个字符。
-
询问字符串 \(s\) 在当前字符串中出现了几次(作为连续子串)?
你必须在线支持这些操作。
Solution
此处写一种非常规做法,复杂度比后缀平衡树稍劣,但是比后缀平衡树好写很多。洛谷上的模板题卡不过去,最大用时会跑到 \(2.6\) 秒左右。
考虑一些暴力,最简单的暴力显然是使用 KMP 或者 Hash 来判定当前字符串出现的次数。复杂度显然是 \(\mathcal O(qn)\) 的。
如果没有插入和删除操作,那么查询一个模式串在文本串中出现的次数显然可以建后缀机来计算,复杂度 \(\mathcal O(\sum L)\)。
考虑将这两种暴力揉到一起。对于一次询问,我们可以将答案看成两部分的贡献组成:一部分是原串,一部分是新加入的串。对于原串建立后缀机来查询,新串跑暴力即可。但是有可能新串很长,原串很短,这样复杂度就退化了,所以使用根号重构的技巧,当新串的长度超过了一个阈值 \(S\) 就将整个后缀机重构。重构次数显然是 \(\mathcal O(\dfrac{\sum L}{S})\) 的,每次暴力的复杂度是 \(\mathcal O(S)\) 的,因此复杂度为 \(\mathcal O(\dfrac{\sum L}{S}+qS)\),\(\sum L\) 和 \(q\) 同阶时 \(S\) 取 \(\sqrt q\) 即可。
代码实现的时候会发现,当进行删除操作的时候,可能会使得后缀机上的一些节点不能再对答案产生贡献。对于这种情况,容易发现被删除的长度不会超过 \(\mathcal O(S)\),因此在后缀机上跑出答案后用暴力计算出被删除的部分的贡献然后减去即可。
代码不算很难写。
Code
// Cirno is not baka!
#include <bits/stdc++.h>
using namespace std;
#ifdef CIRNO
#include <debug.hpp>
#else
#define Debug(...)
#define Debug_r(...)
#endif
#define For(i, a, b) for (int i = (a); i <= (int)(b); ++i)
#define Rof(i, a, b) for (int i = (a); i >= (int)(b); --i)
#define All(x) x.begin(), x.end()
#define pii pair<int, int>
#define fi first
#define se second
#define i64 long long
#define u64 unsigned long long
#define mkp make_pair
#define epb emplace_back
const int _N = 1.6e6 + 5, mod = 1e9 + 7, inf = 1e9, P = 13331;
template<typename T> void Max(T &x, T y) {x = max(x, y);}
template<typename T> void Min(T &x, T y) {x = min(x, y);}
namespace BakaCirno {
struct State {
int ch[26], len, link;
} s[_N];
int tot, lst, f[_N];
vector<int> e[_N];
void init() {
For(i, 1, tot) {
memset(s[i].ch, 0, sizeof s[i].ch);
s[i].len = s[i].link = f[i] = 0;
e[i].clear();
}
tot = lst = 1;
}
void insert(int c) {
int cur = ++tot; f[cur] = 1;
s[cur].len = s[lst].len + 1;
int p = lst;
while (p && !s[p].ch[c])
s[p].ch[c] = cur, p = s[p].link;
if (!p) s[cur].link = 1;
else {
int q = s[p].ch[c];
if (s[q].len == s[p].len + 1) s[cur].link = q;
else {
int clone = ++tot;
s[clone] = s[q];
s[clone].len = s[p].len + 1;
s[q].link = s[cur].link = clone;
while (p && s[p].ch[c] == q)
s[p].ch[c] = clone, p = s[p].link;
}
}
lst = cur;
}
void build(char *str, int N) {
init();
For(i, 0, N - 1) insert(str[i] - 'A');
For(i, 2, tot) e[s[i].link].epb(i);
auto dfs = [&](const auto &dfs, int x)->void {
for (int v: e[x]) dfs(dfs, v), f[x] += f[v];
};
dfs(dfs, 1);
}
int Q, lstPos, ptr, Lim, mask, delLen, N;
char base[_N], opt[10], tmpStr[_N], samStr[_N];
u64 pw[_N];
void initHash(int n) {
pw[0] = 1;
For(i, 1, n) pw[i] = pw[i-1] * P;
}
void decode(char *str, int N, int mask) {
For(i, 0, N - 1) {
mask = (mask * 131 + i) % N;
swap(str[i], str[mask]);
}
}
u64 getHash(char *str, int N) {
u64 res = 0;
For(i, 0, N - 1) res = res * P + str[i] - 'A' + 1;
return res;
}
int getMatch(char *str, int N, u64 hsh, int M) {
u64 res = 0;
int ans = 0;
For(i, 0, N - 1) {
res = res * P + str[i] - 'A' + 1;
if (i - M >= 0) res -= (str[i-M] - 'A' + 1) * pw[M];
ans += hsh == res;
}
return ans;
}
void _() {
cin >> Q >> base;
initHash(2e6);
ptr = lstPos = strlen(base) - 1;
build(base, N = ptr + 1);
memcpy(samStr, base, sizeof(char) * N);
Lim = 2e4;
while (Q--) {
cin >> opt;
if (opt[0] == 'A') {
cin >> tmpStr;
int len = strlen(tmpStr);
decode(tmpStr, len, mask);
memcpy(base + ptr + 1, tmpStr, sizeof(char) * len);
ptr += len;
}
if (opt[0] == 'D') {
cin >> delLen;
ptr -= delLen;
Min(lstPos, ptr);
}
if (opt[0] == 'Q') {
cin >> tmpStr;
int len = strlen(tmpStr);
decode(tmpStr, len, mask);
int cur = 1;
For(i, 0, len - 1) cur = s[cur].ch[tmpStr[i]-'A'];
int ans = 0;
if (cur) ans += f[cur];
int pl = max(0, lstPos - len + 2);
u64 hsh = getHash(tmpStr, len);
ans -= getMatch(samStr + pl, N - pl, hsh, len);
ans += getMatch(base + pl, ptr - pl + 1, hsh, len);
cout << ans << '\n';
mask ^= ans;
}
if (ptr - lstPos > Lim) {
build(base, N = ptr + 1);
memcpy(samStr, base, sizeof(char) * N);
lstPos = ptr;
}
}
}
}
void File(const string file) {
freopen((file + ".in").c_str(), "r", stdin);
freopen((file + ".out").c_str(), "w", stdout);
}
signed main() {
double ST = clock();
// File("P6164_1");
cin.tie(0)->sync_with_stdio(0);
int T = 1;
// cin >> T;
while (T--) BakaCirno::_();
Debug((clock() - ST) / CLOCKS_PER_SEC, "s\n");
}
这份代码可以通过 P5212(没有删除操作)。
Bonus
询问变为询问当前串的 \([L,R]\) 区间内有多少模式串,同样强制在线。
会发现暴力的部分仍然不会发生改变,但是后缀机上的部分就不是很好计算贡献了,因为需要统计一个区间内的点数。
会发现这个问题变成了二维数点,即在线询问后缀 link 树上某个节点的子树内有多少满足 \([L,R]\) 限制的点。这可以使用线段树合并提前预处理出每个节点的点集,然后在线段树上直接求区间和即可。复杂度多个 \(\log\),所以需要调一下 \(S\) 的大小。
Code
// Cirno is not baka!
#include <bits/stdc++.h>
using namespace std;
#ifdef CIRNO
#include <debug.hpp>
#else
#define Debug(...)
#define Debug_r(...)
#endif
#define For(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define Rof(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define All(x) x.begin(), x.end()
#define pii pair<int, int>
#define fi first
#define se second
#define i64 long long
#define u64 unsigned long long
#define mkp make_pair
// #define int long long
#define epb emplace_back
const int _N = 4e5 + 5, mod = 1e9 + 7, inf = 1e9;
template<typename T> void Max(T &x, T y) {x = max(x, y);}
template<typename T> void Min(T &x, T y) {x = min(x, y);}
namespace BakaCirno {
namespace Map {
char id[127];
int di[127];
void init() {
For(i, 0, 25) id[i] = 'a' + i;
For(i, 0, 25) id[i + 26] = 'A' + i;
For(i, 0, 51) di[(int)id[i]] = i;
}
inline char gChar(int x) {return id[x];}
inline int gInt(char x) {return di[(int)x];}
}
using Map::gChar, Map::gInt;
struct State {
int ch[52], len, link;
} s[_N];
int lst, tot, pos[_N];
vector<int> e[_N];
int rt[_N], lc[_N*30], rc[_N*30], val[_N*30], nt;
void pushup(int k) {val[k] = val[lc[k]] + val[rc[k]];}
#define mid ((l + r) >> 1)
void update(int &k, int l, int r, int p, int v) {
k = ++nt;
if (l == r) return val[k] = v, void();
if (p <= mid) update(lc[k], l, mid, p, v);
else update(rc[k], mid + 1, r, p, v);
pushup(k);
}
int merge(int x, int y, int l, int r) {
if (!x || !y) return x + y;
int cur = ++nt;
val[cur] = val[x], lc[cur] = lc[x], rc[cur] = rc[x];
if (l == r) return val[cur] += val[y], cur;
lc[cur] = merge(lc[x], lc[y], l, mid);
rc[cur] = merge(rc[x], rc[y], mid + 1, r);
return pushup(cur), cur;
}
int query(int k, int l, int r, int a, int b) {
if (a > b || !k) return 0;
if (l >= a && r <= b) return val[k];
int res = 0;
if (a <= mid) res += query(lc[k], l, mid, a, b);
if (b > mid) res += query(rc[k], mid + 1, r, a, b);
return res;
}
void init() {
For(i, 1, tot) {
memset(s[i].ch, 0, sizeof s[i].ch);
s[i].len = s[i].link = rt[i] = 0;
pos[i] = -1;
e[i].clear();
}
lst = 1, tot = 1, nt = 0;
}
void insert(int c) {
int cur = ++tot;
s[cur].len = s[lst].len + 1;
int p = lst;
while (p && !s[p].ch[c])
s[p].ch[c] = cur, p = s[p].link;
if (!p) s[cur].link = 1;
else {
int q = s[p].ch[c];
if (s[q].len == s[p].len + 1) s[cur].link = q;
else {
int clone = ++tot;
s[clone] = s[q];
s[clone].len = s[p].len + 1;
s[q].link = s[cur].link = clone;
while (p && s[p].ch[c] == q)
s[p].ch[c] = clone, p = s[p].link;
}
}
lst = cur;
}
void build(const char *str, int n) {
init();
For(i, 0, n - 1) {
insert(gInt(str[i]));
pos[lst] = i;
}
For(i, 2, tot) e[s[i].link].epb(i);
auto dfs = [&](const auto &dfs, int x)->void {
// if (~pos[x]) Debug("(", x, pos[x], ' ');
if (~pos[x]) update(rt[x], 0, n - 1, pos[x], 1);
for (int v: e[x]) {
dfs(dfs, v);
rt[x] = merge(rt[x], rt[v], 0, n - 1);
}
// if (~pos[x]) Debug(")", ' ');
};
dfs(dfs, 1);
}
int Q, N, lstAns, ptr, lstPos, Lim;
char base[_N];
inline char gReal(char c) {return gChar((gInt(c) + lstAns) % 52);}
void gReal(string &str) {for (char &c: str) c = gReal(c);}
const int P = 13331;
u64 pw[_N];
void initHash(int n) {
memset(pos, -1, sizeof pos);
pw[0] = 1;
For(i, 1, n) pw[i] = pw[i - 1] * P;
}
u64 getHash(const string &str) {
u64 res = 0;
for (char c: str) res = res * P + c;
return res;
}
void _() {
Map::init();
initHash(2e5);
string tmp; cin >> tmp;
N = tmp.length();
Debug(N, '\n');
ptr = lstPos = N - 1;
For(i, 0, tmp.length() - 1) base[i] = tmp[i];
build(base, N);
Lim = N + Q;
cin >> Q;
int cnt = 0;
while (Q--) {
int opt, L, R; char c;
cin >> opt;
if (opt == 1) {
cin >> c; c = gReal(c);
base[++ptr] = c;
}
if (opt == 2) {
--ptr;
Min(lstPos, ptr);
}
if (opt == 3) {
++cnt;
string mod;
cin >> L >> R >> mod;
L ^= lstAns, R ^= lstAns, gReal(mod);
// Debug(L, R, mod, '\n');
if ((int)mod.length() > R - L + 1) {
cout << (lstAns = 0) << '\n';
continue;
}
--L, --R;
int cur = 1;
for (char c: mod) cur = s[cur].ch[gInt(c)];
int ans = 0;
if (cur)
ans += query(rt[cur], 0, N - 1, L + mod.length() - 1, min(R, lstPos));
u64 modHash = getHash(mod);
int pl = max(L, lstPos - (int)mod.length() + 2);
u64 res = 0;
For(i, pl, min(ptr, R)) {
res = res * P + base[i];
if (i - (int)mod.length() >= pl)
res -= base[i - mod.length()] * pw[mod.length()];
if (res == modHash) ++ans;
}
cout << (lstAns = ans) << '\n';
}
if (ptr - lstPos > Lim) {
build(base, N = ptr + 1);
lstPos = ptr;
Debug("Rebuild!\n");
// assert(0);
}
}
}
}
void File(const string file) {
freopen((file + ".in").c_str(), "r", stdin);
freopen((file + ".out").c_str(), "w", stdout);
}
signed main() {
File("fantasy");
cin.tie(0)->sync_with_stdio(0); int T = 1;
// cin >> T;
while (T--) BakaCirno::_();
}