思路
SAM + 树剖。
好仙的题啊,做了一天。
令 \(\operatorname{lcs}(i, j)\) 表示长度为 \(i, j\) 的前缀的最长公共后缀长度,则题目中的 border 可以等价转化成:求最大且满足:
-
\(l \leq p \leq r\)
-
\(\operatorname{lcs}(p, r) \geq p - l\)
其中 \(\operatorname{lcs}(p, r)\) 正是它们对应结点在后缀树上 LCA 的 \(len\) 值。
再加上字符串区间查询,可以想象到有一个高妙的 SAM 做法。
首先考虑一下暴力做法。
假设我们把问题直接暴力挂在从根到 SAM 中结点的路径上。
从其中一个结点一直向上跳枚举 LCA,然后找到与该结点 LCA 为当前点的所有结点,尝试更新答案。
当我们跳到一个结点时,相当于确定了 \(\operatorname{lcs}\) 的长度,所以对于当前点,我们只需要查询最大的 \(p \leq l + len - 1\) 的合法的 \(p\).
这个可以线段树合并直接维护。
然后接下来考虑优化上面那个做法。
首先确定问题转化成:
每次假设已经确定树上的一个结点 \(u\) 和常数 \(x\),其中 \(x = l + len - 1\)。现在要求出所有满足 \(\operatorname{lca}(u, v) = k\) 并且 \(v \leq l + len_k - 1\) 的编号最大的 \(v\),其中 \(k\) 是当前枚举的 LCA.
这里有一个比较仙的想法:把询问划分到 \(O(\log n)\) 条重链上。
换言之对于每条重链上的每一个点,考虑它的子树对它的贡献。
假设从结点 \(u\) 进入当前重链,重链从浅到深依次为 \(p_1, ..., p_m\).
对于比 \(u\) 浅的部分,它们和 \(u\) 的 LCA 为其自身;对于比 \(u\) 深的部分,它们和 \(u\) 的 LCA 为 \(u\).
可以预处理出 \(dis(i)\) 表示点 \(i\) 与当前跳到的结点在树上 LCA 处的 \(len\) 值。
假设现在要考虑结点 \(u\) 处的询问,\(u\) 在 \(p\) 中的下标为 \(k\). 令 \(subt(u)\) 为结点 \(u\) 的子树,则答案可以划分成两部分贡献:
-
\(\max_{i = 1}^{k - 1} \max\limits_{v \in subt(p_i)} [l \leq v \leq r] [v \leq l + dis(v) - 1]\)
-
\(\max_{i = k}^{m} \max\limits_{v \in subt(p_i)} [l \leq v \leq r] [v \leq l + dis(u) - 1]\)
上面的划分实际上等价于划分成:
-
重链链首的子树中除点 \(u\) 的重子树外的部分
-
结点 \(u\) 的重子树
可以考虑对整个重链上下跑两次扫描线再线段树合并维护。
对于第一个限制可以直接线段树区间查询。
对于第二个限制可以考虑把所有 \(v\) 减去 \(dis(u)\) 再扔进线段树。
从 \(O(\log n)\) 条重链的低端进入,每条重链暴力遍历的复杂度是 \(O(n \log n)\),所以时间复杂度是 \(O(n \log^2 n)\)
注意到询问只需要挂在重链的最底端即可,所以空间复杂度是 \(O(n \log n)\)
代码
#include <cstdio>
#include <cmath>
#include <cstring>
// #include <vector>
#include <algorithm>
// using namespace std;
namespace IO
{
//by cyffff
int len = 0;
char ibuf[(1 << 20) + 1], *iS, *iT, out[(1 << 26) + 1];
#define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
#define reg register
inline int read()
{
reg char ch = gh();
reg int x = 0;
reg char t = 0;
while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
while (ch >= '0' && ch <= '9') x = x * 10 + (ch ^ 48), ch = gh();
return t ? -x : x;
}
inline void putc(char ch) { out[len++] = ch; }
template<class T>
inline void write(T x)
{
if (x < 0) putc('-'), x = -x;
if (x > 9) write(x / 10);
out[len++] = x % 10 + 48;
}
inline void flush()
{
fwrite(out, 1, len, stdout);
len = 0;
}
}
using IO::read;
using IO::write;
using IO::flush;
using IO::putc;
#define reg
typedef long long ll;
const int maxn = 3e5 + 1;
const int maxm = 3e5 + 1;
const int sz = 2e3 + 1;
struct Modify
{
int t, pos, val;
} q1[sz];
struct Query
{
int t, l, r, x, idx;
} q2[sz];
struct node
{
int idx, res, t[4];
bool flag;
inline void clear() { t[0] = t[1] = t[2] = t[3] = flag = 0; }
} stk[sz], tmp;
struct Edge
{
int to, nxt;
} edge[maxn];
int n, m, sqn, sqm, tot, c1, c2, top, ecnt;
int st[sz], ed[sz];
int head[maxn], a[maxn], bel[maxn], pos[maxn];
bool vis[maxn], chg[maxn], used[maxn], seq[maxn];
ll gt[maxn];
ll ans[sz], sum[sz];
// vector<int> idx[maxn];
inline int min(reg const int &a, reg const int &b) { return (a <= b ? a : b); }
inline bool cmp(const Query& x, const Query& y) { return (x.x < y.x); }
inline void modify(reg const int& p, reg bool flag)
{
pos[p] = p, seq[p] = true, tmp.idx = p;
reg const int lst = p - 1, nxt = p + 1, tmp_p = pos[lst];
reg const bool flag1 = (seq[lst] && (st[bel[p]] != p)), flag2 = (seq[nxt] && (ed[bel[p]] != p));
if ((!flag1) && (!flag2)) tmp.clear(), tmp.res = 1;
else
{
tmp.flag = true;
if (flag1 && flag2)
{
tmp.res = (nxt - pos[lst]) * (pos[nxt] - lst);
tmp.t[0] = pos[lst], tmp.t[1] = pos[pos[lst]], pos[pos[lst]] = pos[nxt];
tmp.t[2] = pos[nxt], tmp.t[3] = pos[pos[nxt]], pos[pos[nxt]] = tmp_p;
}
else if(flag1)
{
tmp.res = nxt - pos[lst];
tmp.t[0] = p, tmp.t[1] = pos[p], pos[p] = pos[lst];
tmp.t[2] = pos[lst], tmp.t[3] = pos[pos[lst]], pos[pos[lst]] = p;
}
else
{
tmp.res = pos[nxt] - lst;
tmp.t[0] = p, tmp.t[1] = pos[p], pos[p] = pos[nxt];
tmp.t[2] = pos[nxt], tmp.t[3] = pos[pos[nxt]], pos[pos[nxt]] = p;
}
}
sum[bel[p]] += tmp.res;
if (flag) stk[++top] = tmp;
}
inline ll query(const int& l, const int& r)
{
reg ll ans = 0;
if (bel[l] == bel[r])
{
reg int cnt = 0;
for (reg int i = l; i <= r; i++)
if (seq[i]) cnt++;
else ans += gt[cnt], cnt = 0;
return ans + gt[cnt];
}
reg int cnt1 = 0, cnt2 = 0;
for (reg int i = l; i <= ed[bel[l]]; i++)
if (seq[i]) cnt1++;
else ans += gt[cnt1], cnt1 = 0;
for (reg int i = r; i >= st[bel[r]]; i--)
if (seq[i]) cnt2++;
else ans += gt[cnt2], cnt2 = 0;
reg int res = cnt1;
for (reg int i = bel[l] + 1; i <= bel[r] - 1; i++)
{
if (pos[st[i]] == ed[i]) res += ed[i] - st[i] + 1;
else
{
if (seq[st[i]]) res += pos[st[i]] - st[i] + 1, ans -= gt[pos[st[i]] - st[i] + 1];
ans += gt[res] + sum[i], res = 0;
if (seq[ed[i]]) res += ed[i] - pos[ed[i]] + 1, ans -= gt[ed[i] - pos[ed[i]] + 1];
}
}
return ans + gt[res + cnt2];
}
inline void add_edge(const int& u, const int& v)
{
edge[++ecnt] = (Edge){v, head[u]};
head[u] = ecnt;
}
inline void solve()
{
memset(seq, false, sizeof(seq));
memset(pos, 0, sizeof(pos));
memset(sum, 0, sizeof(sum));
for (reg int i = 1; i <= c1; i++) chg[q1[i].pos] = true;
for (reg int i = 1; i <= n; i++)
if (!chg[i]) add_edge(a[i], i);
std::sort(q2 + 1, q2 + c2 + 1, cmp);
reg int lim = 1;
for (reg int i = 1; i <= c2; i++)
{
while (lim <= q2[i].x)
{
for (reg int& j = head[lim]; j; j = edge[j].nxt) modify(edge[j].to, 0);
lim++;
}
for (reg int j = c1; j >= 1; j--)
if ((q1[j].t < q2[i].t) && (!used[q1[j].pos]))
{
used[q1[j].pos] = true;
if (q1[j].val <= q2[i].x) modify(q1[j].pos, 1);
}
for (reg int j = 1; j <= c1; j++)
if (!used[q1[j].pos])
{
used[q1[j].pos] = true;
if (a[q1[j].pos] <= q2[i].x) modify(q1[j].pos, 1);
}
ans[q2[i].idx] = query(q2[i].l, q2[i].r);
while (top)
{
tmp = stk[top--], sum[bel[tmp.idx]] -= tmp.res, seq[tmp.idx] = false;
if (tmp.flag) pos[tmp.t[2]] = tmp.t[3], pos[tmp.t[0]] = tmp.t[1];
}
for (int j = 1; j <= c1; j++) used[q1[j].pos] = false;
}
ecnt = 0;
memset(head, 0, sizeof(head));
// for (int i = 0; i <= n; i++) idx[i].clear();
for (int i = 1; i <= c1; i++) chg[q1[i].pos] = false;
}
int main()
{
n = read(), m = read(), sqn = 516, sqm = 1821;
tot = (n + sqn - 1) / sqn;
for (reg int i = 1; i <= tot; i++) st[i] = ed[i - 1] + 1, ed[i] = (i == tot ? n : i * sqn);
for (reg int i = 1; i <= n; i++) a[i] = read(), bel[i] = (i - 1) / sqn + 1, gt[i] = 1ll * i * (i + 1) / 2;
for (reg int i = 1, j; i <= m; i = j + 1)
{
j = min(m, i + sqm), c1 = c2 = 0;
for (reg int k = i; k <= j; k++)
if (read() == 1) q1[++c1] = (Modify){k, read(), read()};
else q2[++c2] = (Query){k, read(), read(), read(), c2};
solve();
for (reg int k = 1; k <= c2; k++) write(ans[k]), putc('\n');
for (reg int k = 1; k <= c1; k++) a[q1[k].pos] = q1[k].val;
}
flush();
return 0;
}
标签:sz,结点,P4482,int,题解,BJWC2018,len,leq,maxn
From: https://www.cnblogs.com/lingspace/p/p4482.html