没做完。
A. Almost Prefix Concatenation
给定字符串 \(S,T\)。称一个串是好的,当且仅当可以通过修改不超过一个字符使其成为 \(T\) 的前缀。
称一个把 \(S\) 划分成 \(n\) 个非空子串 \(S_1,S_2,\cdots,S_n\) 的方案是合法的,当且仅当对于任意 \(1 \le i \le n\),串 \(S_i\) 都是好的。定义合法划分方案的权值为 \(n^2\)。
求所有合法划分方案的权值和对 \(998244353\) 取模后的值。
\(1 \le |S|,|T| \le 10^6\)。
预处理出从每个位置开始最长的合法串长度。这等价于先求一次 LCP,跳过下一位,然后再求一次 LCP,可以直接二分哈希。计数是简单的,考虑 \((x+1)^2 = x^2+2x+1\),我们直接 DP 出每一项的系数即可。转移可以使用前缀和优化,时间复杂度 \(\mathcal{O}(n \log n)\)。
code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
#define fi first
#define se second
constexpr int N = 1e6 + 5, mod = 998244353;
constexpr int B = 233, mod0 = 1e9 + 7, mod1 = 1e9 + 9;
bool Mbe;
void add(int &x, int y) {
x = x + y >= mod ? x + y - mod : x + y;
}
int n, m, h[N][2], sval[N][2], tval[N][2];
char s[N], t[N];
int f[N][3], sum[N][3], R[N];
pi get_hash(int val[N][2], int l, int r) {
pi res;
res.fi = (val[r][0] + mod0 - 1LL * val[l - 1][0] * h[r - l + 1][0] % mod0) % mod0;
res.se = (val[r][1] + mod1 - 1LL * val[l - 1][1] * h[r - l + 1][1] % mod1) % mod1;
return res;
}
int lcp(int i, int j) {
if (i > n || j > m) return 0;
int l = 1, r = min(n - i + 1, m - j + 1);
int res = 0;
while (l <= r) {
int mid = l + r >> 1;
if (get_hash(sval, i, i + mid - 1) == get_hash(tval, j, j + mid - 1)) l = mid + 1, res = mid;
else r = mid - 1;
}
return res;
}
bool Med;
int main() {
// fprintf(stderr, "%.9lf\n", 1.0 * (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> s + 1;
cin >> t + 1;
n = strlen(s + 1), m = strlen(t + 1);
h[0][0] = h[0][1] = 1;
for (int i = 1; i <= n; i++) {
sval[i][0] = (1LL * sval[i - 1][0] * B % mod0 + (s[i] - 'a')) % mod0;
sval[i][1] = (1LL * sval[i - 1][1] * B % mod1 + (s[i] - 'a')) % mod1;
tval[i][0] = (1LL * tval[i - 1][0] * B % mod0 + (t[i] - 'a')) % mod0;
tval[i][1] = (1LL * tval[i - 1][1] * B % mod1 + (t[i] - 'a')) % mod1;
h[i][0] = 1LL * h[i - 1][0] * B % mod0;
h[i][1] = 1LL * h[i - 1][1] * B % mod1;
}
for (int i = 1; i <= n; i++) {
int l1 = 0, l2 = 0;
l1 = lcp(i, 1);
if (i + l1 - 1 == n || l1 == m) R[i] = i + l1 - 1;
else {
l2 = lcp(i + l1 + 1, 1 + l1 + 1);
R[i] = i + l1 + 1 + l2 - 1;
}
}
// for (int i = 1; i <= n; i++) cout << i << " " << R[i] << "\n";
f[n + 1][0] = sum[n + 1][0] = 1;
for (int i = n; i >= 1; i--) {
int s0 = (sum[i + 1][0] + mod - sum[R[i] + 2][0]) % mod;
int s1 = (sum[i + 1][1] + mod - sum[R[i] + 2][1]) % mod;
int s2 = (sum[i + 1][2] + mod - sum[R[i] + 2][2]) % mod;
f[i][0] = s0;
f[i][1] = (s0 + s1) % mod;
f[i][2] = (1LL * s0 + 2LL * s1 % mod + 1LL * s2) % mod;
for (int j = 0; j < 3; j++)
sum[i][j] = (sum[i + 1][j] + f[i][j]) % mod;
}
cout << f[1][2] << "\n";
// cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n";
return 0;
}
/*
ababaab
aba
*/
B. Palindromic Beads
给定一颗 \(n\) 个点的树,点 \(i\) 有元素 \(c_i\)。定义一条路径的权值为路径上元素的最长回文子序列的长度。求最大权值。
\(1 \le n \le 2 \times 10^5\),保证每种元素出现不超过 \(2\) 次。
先不考虑单点。把所有链按照长度排序,这样可以发现一条链只可能被前面的链包含。考虑直接 DP,设 \(f_i\) 为考虑了前 \(i\) 条链的答案,转移找到所有包含 \(i\) 的链 \(j\),令 \(f_i \gets f_j + 2\)。然后可以发现对于一条链 \(i\),包含它的链 \(j\) 在 DFS 序上是一个矩形,所以其实就是单点修改矩形求 \(\max\),树套树维护即可。
最后考虑单点,它只可能在最后出现一次。容易将包含点 \(u\) 的链拆分成 \(\mathcal{O}(deg_u)\) 个 DFS 序上的矩形,直接查就行了。总时间复杂度 \(\mathcal{O}(n \log^2 n)\)。
code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
#define fi first
#define se second
constexpr int N = 2e5 + 5, M = N * 200;
bool Mbe;
void chkmx(int &x, int y) {
x = x > y ? x : y;
}
int n, c[N], st[N], ed[N], tim, dep[N], f[N], g[N], ans;
vector <int> e[N], v[N];
namespace SGT2 {
#define m ((l + r) >> 1)
int tot, root[N << 2], tr[M], ls[M], rs[M];
void mdf(int &x, int l, int r, int p, int v) {
if (!x) x = ++tot;
chkmx(tr[x], v);
if (l == r) return;
if (p <= m) mdf(ls[x], l, m, p, v);
else mdf(rs[x], m + 1, r, p, v);
}
int qry(int x, int l, int r, int ql, int qr) {
if (ql <= l && qr >= r) return tr[x];
if (qr <= m) return qry(ls[x], l, m, ql, qr);
if (ql > m) return qry(rs[x], m + 1, r, ql, qr);
return max(qry(ls[x], l, m, ql, qr), qry(rs[x], m + 1, r, ql, qr));
}
#undef m
}
namespace SGT1 {
#define m ((l + r) >> 1)
void mdf(int x, int l, int r, int p, int y, int z) {
SGT2 :: mdf(SGT2 :: root[x], 1, n, y, z);
if (l == r) return;
if (p <= m) mdf(x << 1, l, m, p, y, z);
else mdf(x << 1 | 1, m + 1, r, p, y, z);
}
int qry(int x, int l, int r, int L1, int R1, int L2, int R2) {
if (L1 <= l && R1 >= r) return SGT2 :: qry(SGT2 :: root[x], 1, n, L2, R2);
if (R1 <= m) return qry(x << 1, l, m, L1, R1, L2, R2);
if (L1 > m) return qry(x << 1 | 1, m + 1, r, L1, R1, L2, R2);
return max(qry(x << 1, l, m, L1, R1, L2, R2), qry(x << 1 | 1, m + 1, r, L1, R1, L2, R2));
}
int qry(int L1, int R1, int L2, int R2) {
if (L1 > R1 || L2 > R2) return 0;
return qry(1, 1, n, L1, R1, L2, R2);
}
#undef m
}
void dfs(int u, int ff) {
for (auto it = e[u].begin(); it != e[u].end(); it++)
if (*it == ff) {
e[u].erase(it);
break;
}
st[u] = ++tim, dep[u] = dep[ff] + 1;
for (auto v : e[u])
if (v != ff) dfs(v, u);
ed[u] = tim;
}
int climb(int x, int y) {
return *prev(upper_bound(e[y].begin(), e[y].end(), x, [&](int x, int y) {
return st[x] < st[y];
}));
}
void upd(int x) {
// cerr << x << "!\n";
chkmx(ans, 1 + SGT1 :: qry(1, st[x], st[x], ed[x]));
for (auto y : e[x]) chkmx(ans, 1 + SGT1 :: qry(st[y], ed[y], ed[y] + 1, n));
}
bool Med;
int main() {
// fprintf(stderr, "%.9lf\n", 1.0 * (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++) cin >> c[i], v[c[i]].emplace_back(i);
for (int i = 1; i < n; i++) {
int x, y;
cin >> x >> y;
e[x].emplace_back(y);
e[y].emplace_back(x);
}
dfs(1, 0);
for (int i = 1; i <= n; i++) {
f[i] = i;
sort(v[i].begin(), v[i].end(), [&](int x, int y) {
return st[x] < st[y];
});
for (auto it : v[i]) g[i] = max(g[i], dep[it]);
}
sort(f + 1, f + n + 1, [&](int x, int y) {
return g[x] > g[y];
});
// for (int i = 1; i <= n; i++) cerr << f[i] << " \n"[i == n];
// for (int i = 1; i <= n; i++) cerr << st[i] << " " << ed[i] << "\n";
for (int i = 1; i <= n; i++) {
int cur = f[i];
if (!g[cur] || v[cur].size() < 2) continue;
int x = v[cur][0], y = v[cur][1];
// cerr << x << " " << y << "\n";
int val = 0;
upd(x), upd(y);
if (st[x] <= st[y] && ed[y] <= ed[x]) {
int r = climb(y, x);
chkmx(val, SGT1 :: qry(1, st[r] - 1, st[y], ed[y]));
chkmx(val, SGT1 :: qry(st[y], ed[y], ed[r] + 1, n));
} else chkmx(val, SGT1 :: qry(st[x], ed[x], st[y], ed[y]));
val += 2;
SGT1 :: mdf(1, 1, n, st[x], st[y], val);
chkmx(ans, val);
}
for (int i = 1; i <= n; i++)
if (v[c[i]].size() == 1) upd(i);
cout << ans << "\n";
// cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n";
return 0;
}
/*
4
1 1 2 2
1 2
2 3
2 4
*/
H. Hurricane
给定一张 \(n\) 个点 \(m\) 条边的无向图,所有边边权均为 \(1\)。
对 \(k = 1 \sim n\) 求出有多少无序二元组 \((u,v)\) 满足在补图中 \(u\) 到 \(v\) 的最短距离恰为 \(k\)。
\(2 \le n \le 10^5\),\(0 \le m \le 2 \times 10^5\)。
注意到,对于两个点 \(u\) 和 \(v\),如果满足 \(deg_u + deg_v < n\),那么一定有 \(dis(u, v) ≤ 2\),这是因为补图中一定存在至少一个点与 \(u\) 和 \(v\) 均相连。
考虑从所有满足 \(2deg_u ≥ n\) 的 \(u\) 出发跑一次单源补图最短路,这样的 \(u\) 只有 \(\mathcal{O}(\frac{m}{n}) \le \mathcal{O}(\sqrt m)\) 个,其余所有点对之间的最短路长度都不超过 \(2\),并且当且仅当两个点之间有边直接相连时,两点之间的最短路长度为 \(2\)。总时间复杂度 \(\mathcal{O}( (n+m)\sqrt {m})\)。
code
include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
#define fi first
#define se second
constexpr int N = 1e5 + 5, mod = 998244353;
bool Mbe;
int n, m, vis[N], dis[N], deg[N]; LL ans[N];
vector <int> e[N];
void bfs(int s) {
vector <int> cur;
queue <int> q;
for (int i = 1; i <= n; i++) {
if (i == s) q.push(i), dis[i] = 0;
else cur.emplace_back(i), dis[i] = n + 1;
vis[i] = 0;
}
while (!q.empty()) {
int u = q.front();
q.pop();
for (auto v : e[u]) vis[v] = u;
for (int i = 0; i < cur.size(); i++) {
int v = cur[i];
if (vis[v] == u) continue;
dis[v] = dis[u] + 1;
q.push(v);
swap(cur[i], cur.back()), cur.pop_back(), i--;
}
}
}
bool Med;
int main() {
// fprintf(stderr, "%.9lf\n", 1.0 * (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= m; i++) {
int u, v;
cin >> u >> v;
e[u].emplace_back(v);
e[v].emplace_back(u);
deg[u]++;
deg[v]++;
}
ans[1] = 1LL * n * (n - 1) / 2 - m;
for (int u = 1; u <= n; u++) {
if (deg[u] >= n / 2) {
bfs(u);
for (auto v : e[u]) {
if (deg[v] >= n / 2 && v < u) continue;
ans[dis[v]]++;
}
} else {
for (auto v : e[u]) {
if (deg[v] >= n / 2 || v < u) continue;
ans[2]++;
}
}
}
for (int i = 1; i < n; i++) cout << ans[i] << " ";
// cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n";
return 0;
}
/*
4 2
1 2
3 4
*/
L. Partially Free Meal
给定长为 \(n\) 的序列 \(a,b\)。对于子序列 \(p_1,p_2,\cdots,p_k\),定义其代价为 \(\sum \limits_{i=1}^k a_{p_i} + \max\limits_{i=1}^k b_{p_i}\)。
对 \(k = 1 \sim n\) 求出所有长为 \(k\) 的子序列的最小代价。
\(1 \le n \le 10^5\),\(1 \le a_i,b_i \le 10^9\)。
首先考虑如何计算固定的 \(k\) 的答案。将所有元素按照 \(b\) 从小到大排序,枚举第 \(x (k ≤ x ≤n)\) 个元素作为选中的 \(b\) 最大的元素,那么剩下的 \(k − 1\) 个显然是贪心选择前 \(x − 1\) 个元素中 \(a\) 最小的 \(k − 1\) 个。\(k\) 固定的情况可以用堆维护,如果多次给定 \(k\) 和 \(x\),可以通过可持久线段树 \(\mathcal{O}(\log n)\) 求出对应方案的值 \(w(k, x)\)。
接下来的想法很牛。令 \(f(k)\) 表示使 \(k\) 取到最优解的 \(x\)。对于两个不同的决策 \(x, y (x < y)\),若 \(w(k, x) ≥ w(k, y)\),那么增大 \(k\) 之后由于 \(y\) 的可选择范围严格包含了 \(x\) 的可选择范围,因此 \(y\) 新选的 \(a\) 值一定不大于 \(x\) 所选的,即 \(w(k', x) ≥ w(k', y)\) 对于 \(k ≤ k' ≤ n\) 恒成立。由此可得 \(f(1) ≤ f(2) \cdots ≤ f(n)\),即最优决策具有单调性,分治即可做到 \(\mathcal{O}(n \log^2 n)\)。
code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
#define fi first
#define se second
constexpr int N = 2e5 + 5, M = N << 5;
constexpr LL inf = 1e18;
bool Mbe;
int n, m; LL val[N], ans[N];
struct dat {
LL a, b;
bool operator < (const dat &p) const {
return b < p.b;
}
} c[N];
#define m ((l + r) >> 1)
int root[N], tot, ls[M], rs[M], cnt[M]; LL sum[M];
int build(int l, int r) {
int x = ++tot;
if (l == r) return x;
ls[x] = build(l, m);
rs[x] = build(m + 1, r);
return x;
}
int mdf(int pre, int l, int r, int p) {
int x = ++tot;
ls[x] = ls[pre], rs[x] = rs[pre], cnt[x] = cnt[pre], sum[x] = sum[pre];
if (l == r) {
cnt[x]++;
sum[x] += val[p];
return x;
}
if (p <= m) ls[x] = mdf(ls[pre], l, m, p);
else rs[x] = mdf(rs[pre], m + 1, r, p);
cnt[x] = cnt[ls[x]] + cnt[rs[x]];
sum[x] = sum[ls[x]] + sum[rs[x]];
return x;
}
LL qry(int x, int l, int r, int k) {
if (!k) return 0;
if (l == r) return 1LL * k * val[l];
if (cnt[ls[x]] > k) return qry(ls[x], l, m, k);
return qry(rs[x], m + 1, r, k - cnt[ls[x]]) + sum[ls[x]];
}
#undef m
LL qry(int x, int k) {
if (x < k) return inf;
return qry(root[x], 1, m, k);
}
void solve(int l, int r, int ql, int qr) {
if (l > r || ql > qr) return;
if (ql == qr) {
for (int i = l; i <= r; i++) ans[i] = c[ql].b + qry(ql, i);
return;
}
LL res = inf, pos = ql;
int mid = l + r >> 1;
for (int i = ql; i <= qr; i++) {
LL val = c[i].b + qry(i, mid);
if (val < res) res = val, pos = i;
}
ans[mid] = res;
solve(l, mid - 1, ql, pos);
solve(mid + 1, r, pos, qr);
}
bool Med;
int main() {
// fprintf(stderr, "%.9lf\n", 1.0 * (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> c[i].a >> c[i].b;
val[i] = c[i].a;
}
sort(c + 1, c + n + 1);
sort(val + 1, val + n + 1);
m = unique(val + 1, val + n + 1) - val - 1;
root[0] = build(1, n);
for (int i = 1; i <= n; i++) {
int pos = lower_bound(val + 1, val + m + 1, c[i].a) - val;
root[i] = mdf(root[i - 1], 1, m, pos);
}
solve(1, n, 1, n);
for (int i = 1; i <= n; i++) cout << ans[i] << "\n";
// cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n";
return 0;
}
/*
3
2 5
4 3
3 7
*/