生日蛋糕(birth)
伤心题。。。
题意
\(n\) 个点的树,第 \(i\) 个点有点权 \(1\le a_i\le m\)。
对于每个 \(i\) 满足 \(1\le i\le m\),求出连通块内点权最大值为 \(i\) 的个数,对 \(998244353\) 取模。
题解
现场得分:10/100(调了 3h 没调出来,标记合并挂了。。。)
首先这道题与 Minimax 类似。但那道题只用考虑包含根节点的连通块,而这道题则需要考虑不包含根的连通块。
考虑对于一个点 \(i\),记 \(f_{i,x}\) 表示 \(i\) 子树内,包含 \(i\) 且点权最大值为 \(x\) 的连通块个数,同时考虑多记一个值表示不包含 \(i\) 且最大值为 \(x\) 的连通块个数 \(g_{i,x}\)。
那么一样考虑线段树合并,考虑将子树根的线段树 \(rt_i\) 和儿子 \(rt_son\) 合并。
- 若 \(rt_i=0\):
\(f_{rt_i,x}=f_{rt_son,x} \times \sum\limits_{j<x} f_{rt_i,j}\),\(\sum\limits_{j<x} f_{rt_i,j}\) 合并的时候维护下就可以了;
\(g_{rt_i,x}=g_{rt_son,x}\)。 - 若 \(rt_son=0\):
\(f_{rt_i,x}=f_{rt_i,x} \times (\sum\limits_{j<x} f_{rt_son,j} + 1)\),\(\sum\limits_{j<x} f_{rt_son,j}\) 合并的时候维护下就可以了。
\(g_{rt_i,x}\) 不变。 - 若为叶子:
\(f_{rt_i,x}=f_{rt_i,x} \times (\sum\limits_{j<=x} f_{rt_son,j} + 1) + f_{rt_son, x} \times \sum\limits_{j<x} f_{rt_son,j}\)。
\(g_{rt_i,x} = g_{rt_i, x} + f_{rt_son,x} + g_{rt_son,x}\)。
对于 1. 2. 考虑相当于实现一个线段树支持如下操作:
- 维护两个值 \(f,g\);
- 支持对 \(f\) 区间乘;
- 支持对 \(g\) 区间对位加,即 \(g_i+=f_i\);
- 单点查询。
考虑维护标记 \(tag_1\) 和 \(tag_2\),表示先 \(g_i += tag_2f_i\),再 \(f_i = tag_1f_i\)。
考虑如何合并标记:
初始时标记为 \(tag_1,tag_2\) 新加入标记为 \(tag'_1,tag'_2\),那么新的标记为 \(tag_1\times tag'_1,tag_2+tag_1\times tag'_2\)。
PS:比赛是一直没有意识到要 \(+tag_1\times tag'_2\),一直没有冷静分析对,实际上是很明显的。
代码
#include <bits/stdc++.h>
#define SZ(x) (int) x.size() - 1
#define all(x) x.begin(), x.end()
#define ms(x, y) memset(x, y, sizeof x)
#define F(i, x, y) for (int i = (x); i <= (y); i++)
#define DF(i, x, y) for (int i = (x); i >= (y); i--)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
template <typename T> void read(T &x) {
x = 0; int f = 1; char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
x *= f;
}
template <typename T> void write(T x) {
T l = 0;
ull y = 0;
if (!x) { putchar(48); return; }
if (x < 0) { x = -x; putchar('-'); }
while (x) { y = y * 10 + x % 10; x /= 10; ++l; }
while (l) { putchar(y % 10 + 48); y /= 10; --l; }
}
template <typename T> void writeln(T x) { write(x); puts(""); }
template <typename T> void writes(T x) { write(x); putchar(32); }
const int N = 2e5 + 10, MOD = 998244353;
int n, m, rt[N], tot, tmp[N], seg[N * 40], seg2[N * 40], ls[N * 40], rs[N * 40], ans, tag[N * 40], a[N], tag2[N * 40];//, tag3[N * 40];
vector <int> v[N];
void down2(int num, int x) {
tag2[num] = (tag2[num] + (ll) x * tag[num]) % MOD;
seg2[num] = (seg2[num] + (ll) x * seg[num]) % MOD;
}
void down(int num, int x) {
tag[num] = (ll) tag[num] * x % MOD;
seg[num] = (ll) seg[num] * x % MOD;
}
void pushdown(int num) {
if (tag2[num]) {
down2(ls[num], tag2[num]); down2(rs[num], tag2[num]);
tag2[num] = 0;
}
if (tag[num] != 1) {
down(ls[num], tag[num]); down(rs[num], tag[num]);
tag[num] = 1;
}
}
void pushup(int num) {
seg[num] = (seg[ls[num]] + seg[rs[num]]) % MOD;
seg2[num] = (seg2[ls[num]] + seg2[rs[num]]) % MOD;
}
void insert(int &num, int l, int r, int x, int y) {
num = ++tot, tag[num] = 1;
if (l == r) return seg[num] = y, void();
int mid = (l + r) >> 1;
if (mid >= x) insert(ls[num], l, mid, x, y);
else insert(rs[num], mid + 1, r, x, y);
pushup(num);
}
void mer(int &x, int y, int a, int b, int l, int r) {
if (!x && !y) return;
if (!y) return down(x, a + 1), void();
if (!x) return x = y, down2(x, 1), down(x, b), void();
if (l == r) {
seg2[x] = ((ll) seg2[x] + seg2[y] + seg[y]) % MOD;
seg[x] = ((ll) seg[x] * (a + seg[y] + 1) % MOD + (ll) seg[y] * b % MOD) % MOD;
return;
}
pushdown(x); pushdown(y);
int mid = (l + r) >> 1;
int lx = seg[ls[x]], rx = seg[rs[x]], ly = seg[ls[y]], ry = seg[rs[y]];
mer(ls[x], ls[y], a, b, l, mid);
mer(rs[x], rs[y], (a + ly) % MOD, (b + lx) % MOD, mid + 1, r);
pushup(x);
}
void query(int num, int l, int r) {
if (l == r) {
writes((seg[num] + seg2[num]) % MOD);
return;
} pushdown(num);
int mid = (l + r) >> 1;
query(ls[num], l, mid); query(rs[num], mid + 1, r);
}
void dfs(int x, int fa) {
insert(rt[x], 1, m, a[x], 1);
for (int i: v[x])
if (i != fa) {
dfs(i, x);
mer(rt[x], rt[i], 0, 0, 1, m);
}
}
signed main() {
freopen("birth.in", "r", stdin);
freopen("birth.out", "w", stdout);
read(n); read(m);
F(i, 1, n) read(a[i]);
F(i, 2, n) {
int x, y; read(x); read(y);
v[x].push_back(y);
v[y].push_back(x);
} dfs(1, 0);
query(rt[1], 1, m);
return 0;
}
标签:rt,int,void,tag,T1,seg,num,20230129,birth
From: https://www.cnblogs.com/zhaohaikun/p/17079956.html