先考虑怎么安排崛起的先后顺序最优。
但是发现好像没有一个很好的顺序去进行崛起,并且由于 \(a_i\) 的值域会很大,所以即使知道顺序应该也会难以进行维护。
转换一下方向,正难则反。考虑每个点的贡献,但是颜色不同时只会算一次,所以要钦定是哪一个点造成的贡献。令当前考虑的点为 \(u\),发现可以在不影响 \(u\) 的祖先的的贡献的情况下对 \(u\) 子树内的点的相对操作顺序进行改变。所以 \(u\) 点所产生的贡献是很容易计算的:若 \(u\) 以及 \(u\) 的子树内的所有点的 \(a\) 值都没有超过 \(sza_u\) 的一半,则贡献为 \(sza_u-1\)。其中 \(sza_u\) 是 \(u\) 及 \(u\) 的子树的 \(a\) 值的和,减一是因为第一次不会产生贡献。否则就是 \(2(sza_u-mxa_u)\),因为最大的 \(a\) 无法都产生贡献,注意这里的 \(mxa_u\) 需要和 \(a_u\) 取 \(\max\)。
这时候就可以有 \(30\) 分了。
放一份暴力代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 4e5 + 10;
int n, m;
int sz[N];
ll ans, a[N], sza[N];
vi e[N];
void dfs(int u, int ff) {
sz[u] = 1, sza[u] = a[u];
ll mx = a[u];
for(int v : e[u]) {
if(v == ff) continue;
dfs(v, u);
sz[u] += sz[v];
sza[u] += sza[v];
if(sza[v] > mx) mx = sza[v];
}
if(2 * mx > sza[u]) ans += 2 * (sza[u] - mx);
else ans += sza[u] - 1;
}
void work() {
ans = 0;
dfs(1, 0);
cout << ans << "\n";
}
bool Med;
int main() {
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
// freopen("history4.in", "r", stdin);
// freopen("history.out", "w", stdout);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> m;
for(int i = 1; i <= n; ++i) cin >> a[i];
for(int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
e[u].eb(v);
e[v].eb(u);
}
work();
for(int i = 1; i <= m; ++i) {
int x, w;
cin >> x >> w;
a[x] += w;
work();
}
cerr << TIME << "ms\n";
return 0;
}
然后考虑对这个暴力进行优化。
对于 \(u\to son\) 的边,若 \(sza_{son}>\dfrac{sza_u}{2}\),则称这条边为实边,否则为虚边。
考虑对 \(v\in\{\text{path}(1,u)\}\) 进行区间加后虚实边会有什么变化。发现一个很美妙的地方,就是整条路径上至多有 \(\log\sum a_i\) 条虚边。然后因为对于一条 \(fa\to x\) 的实边,\(sza_x\) 和 \(sza_{fa}\) 同时增加,\(fa\) 的带权的重儿子显然还是 \(x\)。所以可能发生变化的只有虚边。
于是每次操作至多会修改 \(\mathcal{O}(\log\sum a_i)\) 条边,即需要支持单点修改、查询区间中为 \(1\) 的数,线段树即可。
然后可能会略有卡常,可以将维护 \(sza\) 的数组换为 BIT。
还有一个实现的小技巧,就是因为带权重儿子可能是 \(u\) 本身,所以连一个 \(i\to i+n\) 的边即可,就能把 \(i\) 的点权转到 \(i+n\) 上了,这里借鉴了 _Jxsts 的代码。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 8e5 + 10;
int n, m, q;
int dfc, sz[N], hv[N], dfn[N], top[N], fa[N], dep[N], rdfn[N], leaf[N], hva[N];
ll ans, a[N], sza[N], val[N];
vi e[N];
void dfs1(int u, int ff) {
sz[u] = 1, sza[u] = a[u], dep[u] = dep[ff] + 1, fa[u] = ff;
for(int v : e[u]) {
if(v == ff) continue;
dfs1(v, u);
sz[u] += sz[v];
sza[u] += sza[v];
if(sz[v] > sz[hv[u]]) hv[u] = v;
if(sza[v] > sza[hva[u]]) hva[u] = v;
++leaf[u];
}
if(leaf[u] <= 1) leaf[u] = 1;
else leaf[u] = 0;
if(!leaf[u]) {
if(2 * sza[hva[u]] > sza[u]) val[u] = 2 * (sza[u] - sza[hva[u]]);
else val[u] = sza[u] - 1;
ans += val[u];
// ans += val[u] = min(2 * (sza[u] - sza[hva[u]]), sza[u] - 1);
}
}
void dfs2(int u, int f) {
rdfn[dfn[u] = ++dfc] = u, top[u] = f;
if(!hv[u]) return;
dfs2(hv[u], f);
for(int v : e[u]) {
if(v == hv[u] || v == fa[u]) continue;
dfs2(v, v);
}
}
ll v[N];
void add(int x, ll y) {
for(; x <= m; x += x & -x) v[x] += y;
}
ll ask(int x) {
ll res = 0;
for(; x; x -= x & -x) res += v[x];
return res;
}
ll ask(int l, int r) {
return ask(r) - ask(l - 1);
}
int sum[N << 2];
void build(int x, int L, int R) {
if(L == R) {
sum[x] = 2 * sza[rdfn[L]] <= sza[fa[rdfn[L]]];
return;
}
int m = (L + R) >> 1;
build(x << 1, L, m);
build(x << 1 | 1, m + 1, R);
sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
void modify(int x, int L, int R, int k, int v) {
if(L == R) {
sum[x] = v;
return;
}
int m = (L + R) >> 1;
if(k <= m) modify(x << 1, L, m, k, v);
else modify(x << 1 | 1, m + 1, R, k, v);
sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
int query(int x, int L, int R, int l, int r) {
if(!sum[x]) return -1;
if(L == R) return rdfn[L];
int m = (L + R) >> 1;
if(r > m) {
int v = query(x << 1 | 1, m + 1, R, l, r);
if(~v || (m > r)) return v;
}
return l <= m ? query(x << 1, L, m, l, r) : -1;
}
void upd(int x, ll v) {
int tmp = x;
while(x) {
add(dfn[top[x]], v), add(dfn[x] + 1, -v);
x = fa[top[x]];
}
x = tmp;
int r = dfn[x];
while(x) {
if(r < dfn[top[x]]) {
x = fa[top[x]];
r = dfn[x];
continue;
}
int u = query(1, 1, m, dfn[top[x]], r);
if(u == -1) {
x = fa[top[x]];
r = dfn[x];
continue;
}
ans -= val[fa[u]];
ll saf = ask(dfn[fa[u]]), sau = ask(dfn[u]), sahv = ask(dfn[hva[fa[u]]]);
if(u == hva[fa[u]]) {
if(2 * sahv > saf && 2 * (sahv - v) <= saf - v) modify(1, 1, m, dfn[u], 0);
if(2 * sahv > saf) val[fa[u]] = 2 * (saf - sahv);
else val[fa[u]] = saf - 1;
ans += val[fa[u]];
} else {
if(2 * sahv > saf - v && 2 * sahv <= saf) modify(1, 1, m, dfn[hva[fa[u]]], 1);
if(2 * sau > saf) modify(1, 1, m, dfn[u], 0);
if(sau > sahv) {
hva[fa[u]] = u;
if(sau * 2 > saf) val[fa[u]] = 2 * (saf - sau);
else val[fa[u]] = saf - 1;
} else {
if(sahv * 2 > saf) val[fa[u]] = 2 * (saf - sahv);
else val[fa[u]] = saf - 1;
}
ans += val[fa[u]];
}
r = dfn[fa[u]];
}
}
bool Med;
int main() {
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
// freopen("history.in", "r", stdin);
// freopen("history.out", "w", stdout);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> q;
m = 2 * n;
for(int i = 1; i <= n; ++i) {
cin >> a[i + n];
e[i].eb(i + n);
e[i + n].eb(i);
}
for(int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
e[u].eb(v);
e[v].eb(u);
}
dfs1(1, 0);
cout << ans << "\n";
dfs2(1, 1);
build(1, 1, m);
for(int i = 1; i <= m; ++i) add(i, sza[rdfn[i]]), add(i + 1, -sza[rdfn[i]]);
// q = 1; // ...65
for(int i = 1; i <= q; ++i) {
int x, w;
cin >> x >> w;
upd(x + n, w);
cout << ans << "\n";
}
cerr << TIME << "ms\n";
return 0;
}
标签:val,int,题解,sza,saf,fa,P4338,ZJOI2018,define
From: https://www.cnblogs.com/Pengzt/p/17929879.html