这种题出出来有什么必要吗,不就是难写的暴力弱智题。
题意
给定一棵树和一个文本串 \(T\),每个结点上有一个字符,问树上任意路径构成的字符串在 \(T\) 中的出现次数之和。
\(n, m \leq 5 \times 10^4\)
思路
点分治 + 后缀自动机 + 根号分治。
首先可以发现期望是假的。
然后考虑点分治做树上路径计数。这里对于当前的分治重心 \(u\),统计经过 \(u\) 的所有路径构成的字符串在 \(S\) 中的出现次数之和。
看到子串出现次数之和考虑用 SAM 做。
听起来很复杂对不对?暴力怎么做?搜索!搜索!搜索!
考虑把一条经过点 \(u\) 的树上路径 \(p \rightarrow q\) 分成 \(p \rightarrow u\) 和 \(u \rightarrow q\) 两部分,后面那一部分可以在子树里面直接搜,暴力在 parent tree 上跳。
第一部分是在反串上做的镜像问题,因为要在整串的前方加入一个字符,所以需要在 parent tree 上先处理出在整串的前方加入字符所导向的子结点,这个很好做。
在 parent tree 上跳到某个结点终止后,显然这个结点所代表的串的任意一个后缀都在 \(T\) 中出现过,所以出现次数是所有祖先的出现次数之和,考虑做一个树上前缀和。
考虑统计答案:考虑 \(T\) 中每一个位置的贡献,那么答案就是 \(\sum\limits\) 前缀 \([1, i]\) 的任意后缀在 \(T\) 中的出现次数之和 \(\times\) 后缀 \([i + 1, m]\) 的任意前缀在 \(T\) 中的出现次数之和。
这里的时间复杂度是 \(O(size + m)\),约等于白做。
但是我们还有另一种暴力:枚举子树中的初始结点,然后暴搜接下来的树上路径,复杂度 \(O(size^2)\)
那直接套一个根号分治就做完了。
代码
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
using namespace std;
typedef long long ll;
const int maxn = 5e4 + 5;
const int sam_sz = 1e5 + 5;
const int inf = 0x3f3f3f3f;
int n, m;
int rt, sz_max, sz_sum, block;
int qlen, q[maxn], up[maxn], rc[maxn];
ll ans;
int sz[maxn];
bool vis[maxn];
char s[maxn], t[maxn];
vector<int> g[maxn];
struct SAM
{
int cur = 1, lst = 1;
int nd[maxn];
int len[sam_sz], fa[sam_sz], son[sam_sz][26], to[sam_sz][26];
int cnt[sam_sz], pos[sam_sz], seq[sam_sz], tot[sam_sz], f[sam_sz];
char str[sam_sz];
void clear() { memset(f, 0, (cur + 1) * sizeof(int)); }
void copy_node(int to, int from)
{
len[to] = len[from], fa[to] = fa[from], pos[to] = pos[from];
memcpy(son[to], son[from], sizeof(son[to]));
}
void insert(int c, int lstp)
{
int p = lst, np = lst = ++cur;
len[np] = len[p] + 1, pos[np] = lstp, cnt[np] = 1, nd[lstp] = cur;
for ( ; p && (!son[p][c]); p = fa[p]) son[p][c] = np;
if (!p) fa[np] = 1;
else
{
int q = son[p][c];
if (len[q] == len[p] + 1) fa[np] = q;
else
{
int nq = ++cur;
copy_node(nq, q);
len[nq] = len[p] + 1;
fa[q] = fa[np] = nq;
for ( ; p && (son[p][c] == q); p = fa[p]) son[p][c] = nq;
}
}
}
void sort()
{
for (int i = 1; i <= cur; i++) tot[len[i]]++;
for (int i = 1; i <= m; i++) tot[i] += tot[i - 1];
for (int i = 1; i <= cur; i++) seq[tot[len[i]]--] = i;
}
void build()
{
for (int i = 1; i <= m; i++) insert(str[i] - 'a', i);
memset(tot, 0, (m + 1) * sizeof(int));
sort();
for (int i = cur; i >= 2; i--)
{
int u = seq[i];
cnt[fa[u]] += cnt[u];
if (fa[u] != 1) to[fa[u]][str[pos[u] - len[fa[u]]] - 'a'] = u;
}
}
int nxt(int u, int l, char c)
{
if (len[u] >= l) return (str[pos[u] - l + 1] == c ? u : 0);
return to[u][c - 'a'];
}
void dfs(int u, int p, int len, int cur)
{
if (!cur) return;
f[cur]++;
for (int v : g[u])
{
if ((v == p) || vis[v]) continue;
dfs(v, u, len + 1, nxt(cur, len + 1, s[v]));
}
}
void calc()
{
for (int i = 2; i <= cur; i++)
{
int u = seq[i];
f[u] += f[fa[u]];
}
}
} s1, s2;
void get_sz(int u, int f, int sum)
{
int mx = -1;
sz[u] = 1;
for (int v : g[u])
{
if ((v == f) || vis[v]) continue;
get_sz(v, u, sum);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, sum - sz[u]);
if (mx < sz_max) rt = u, sz_max = mx;
}
void get_rt(int u, int sum)
{
sz_max = inf;
get_sz(u, 0, sum);
}
void dfs1(int u, int fa)
{
q[++qlen] = u;
for (int v : g[u])
{
if ((v == fa) || vis[v]) continue;
up[v] = u;
dfs1(v, u);
}
}
void dfs2(int u, int fa, int cur, int nega)
{
if (!cur) return;
// printf("%d %d\n", nega, s1.cnt[cur]);
ans += nega * s1.cnt[cur];
for (int v : g[u])
{
if ((v == fa) || vis[v]) continue;
dfs2(v, u, s1.son[cur][rc[v]], nega);
}
}
void calc1(int u, int c, int nega)
{
s1.clear(), s2.clear();
if (nega == 1)
{
s1.dfs(u, 0, 1, s1.son[1][rc[u]]);
s2.dfs(u, 0, 1, s2.son[1][rc[u]]);
}
else
{
s1.dfs(u, 0, 2, s1.nxt(s1.son[1][c], 2, s[u]));
s2.dfs(u, 0, 2, s2.nxt(s2.son[1][c], 2, s[u]));
}
s1.calc(), s2.calc();
for (int i = 1; i <= m; i++) ans += 1ll * nega * s1.f[s1.nd[i]] * s2.f[s2.nd[m - i + 1]];
// for (int i = 1; i <= m; i++)
// {
// printf("cur : %d %d %d %d %d\n", nega, s1.f[s1.nd[i]], s2.f[s2.nd[m - i + 1]], s1.nd[i], s2.nd[m - i + 1]);
// ans += 1ll * nega * s1.f[s1.nd[i]] * s2.f[s2.nd[m - i + 1]];
// }
}
void calc2(int u, int rt)
{
qlen = 0, up[u] = rt;
dfs1(u, 0);
for (int i = 1; i <= qlen; i++)
{
int cur = q[i], nd = 1;
while (cur != rt) nd = s1.son[nd][rc[cur]], cur = up[cur];
nd = s1.son[nd][rc[cur]], nd = s1.son[nd][rc[u]];
dfs2(u, 0, nd, -1);
}
}
void solve(int u)
{
if (sz_sum <= block)
{
qlen = 0, dfs1(u, 0);
for (int i = 1; i <= qlen; i++) dfs2(q[i], 0, s1.son[1][rc[q[i]]], 1);
return;
}
int cur_sz = sz_sum;
vis[u] = true;
calc1(u, 1, 1);
for (int v : g[u])
{
if (vis[v]) continue;
int to_sz = (sz[u] < sz[v] ? cur_sz - sz[u] : sz[v]);
if (to_sz > block) calc1(v, rc[u], -1);
else calc2(v, u);
}
for (int v : g[u])
{
if (vis[v]) continue;
sz_sum = (sz[u] < sz[v] ? cur_sz - sz[u] : sz[v]);
get_rt(v, sz_sum);
solve(rt);
}
}
int main()
{
scanf("%d%d", &n, &m);
block = sqrt(n);
for (int i = 1, u, v; i <= n - 1; i++)
{
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
scanf("%s%s", s + 1, t + 1);
for (int i = 1; i <= n; i++) rc[i] = s[i] - 'a';
for (int i = 1; i <= m; i++) s1.str[i] = t[i], s2.str[i] = t[m - i + 1];
s1.build(), s2.build();
sz_sum = n, get_rt(1, n);
solve(rt);
printf("%lld\n", ans);
return 0;
}
/*
3 5
1 2
1 3
aab
abaab
15
*/
标签:sz,cur,sam,int,题解,len,fa,CTSC2010,P4218
From: https://www.cnblogs.com/lingspace/p/p4218.html