比较喜欢的题目。
考虑合法的条件,从点亮的灯的角度难以维护。反过来看,从未点亮的灯角度考虑,条件相当于这些灯形成了一个包含 \(1\) 号灯的连通块。
如何判定这些灯形成一个连通块?点减边!设 \(c_i\) 为操作前 \(i\) 次后,未点亮的灯的 \(|V| - |E|\) 的值,那么 \(c_i = 1\) 即合法。
对于固定的 \(i\),\(|V|\) 是固定的,等于 \(n - i\),可以先初始化 \(c_i = n - i\)。对于 \(|E|\),我们考虑每一条边的贡献。设 \(b_u\) 表示 \(u\) 号灯被点亮的时间,那么对于边 \((u,v)\),在时刻 \(1\sim \min(b_u, b_v) - 1\) 时有贡献,令 \(c_{1\sim \min(b_u, b_v) - 1}\) 减一。
对于 \(c_i = 1\),我们考虑如何计算答案。仍然考虑每条边的贡献,一条边有贡献的条件是两端的灯一个点亮,一个未点亮。设 \(w_i\) 为 \(i\) 时刻这样的边的数量,对于边 \((u,v)\),令 \(w_{\min(b_u, b_v) \sim \max(b_u, b_v) - 1}\) 加一。我们相当于求 \(\sum_{i = 1} ^ {n - 1} [c_i = 1] w_i\)。
多次询问时,由于我们并不关心树的形态,而只关心每条边的贡献,所以可以轻松使用线段树维护。具体的,维护 \(\min c,\ \sum_i [c_i = \min c], \ \sum_i [c_i = \min c] w_i\) 三个信息。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define pir pair <ll, ll>
#define mkp make_pair
#define fi first
#define se second
#define pb push_back
using namespace std;
const ll maxn = 5e5 + 10, inf = 1e18;
ll n, m, x[maxn], y[maxn], a[maxn], b[maxn];
struct SGT {
ll mn[maxn << 2], cnt[maxn << 2], sum[maxn << 2];
ll tagmn[maxn << 2], tagw[maxn << 2];
void addtag(ll p, ll v1, ll v2) {
mn[p] += v1, tagmn[p] += v1;
sum[p] += v2 * cnt[p], tagw[p] += v2;
}
void pushdown(ll p) {
addtag(p << 1, tagmn[p], tagw[p]);
addtag(p << 1|1, tagmn[p], tagw[p]);
tagmn[p] = tagw[p] = 0;
}
void modify(ll p, ll l, ll r, ll ql, ll qr, ll v1, ll v2) {
if(ql <= l && r <= qr) { addtag(p, v1, v2); return; }
pushdown(p); ll mid = l + r >> 1;
if(ql <= mid) modify(p << 1, l, mid, ql, qr, v1, v2);
if(mid < qr) modify(p << 1|1, mid + 1, r, ql, qr, v1, v2);
mn[p] = min(mn[p << 1], mn[p << 1|1]), cnt[p] = sum[p] = 0;
if(mn[p << 1] == mn[p])
cnt[p] += cnt[p << 1], sum[p] += sum[p << 1];
if(mn[p << 1|1] == mn[p])
cnt[p] += cnt[p << 1|1], sum[p] += sum[p << 1|1];
}
void build(ll p, ll l, ll r) {
mn[p] = n - r, cnt[p] = 1;
if(l == r) return;
ll mid = l + r >> 1;
build(p << 1, l, mid), build(p << 1|1, mid + 1, r);
}
} tr;
void add(ll x, ll y, ll v) {
if(min(b[x], b[y]) > 1)
tr.modify(1, 1, n - 1, 1, min(b[x], b[y]) - 1, -v, 0);
tr.modify(1, 1, n - 1, min(b[x], b[y]), max(b[x], b[y]) - 1, 0, v);
}
int main() {
scanf("%lld%lld", &n, &m);
for(ll i = 1; i < n; i++)
scanf("%lld%lld", x + i, y + i);
for(ll i = 1; i < n; i++) {
scanf("%lld", a + i);
b[a[i]] = i;
} b[a[n] = 1] = n; tr.build(1, 1, n - 1);
for(ll i = 1; i < n; i++)
add(x[i], y[i], 1);
printf("%lld\n", tr.sum[1]);
while(m--) {
ll u, v, p, q; scanf("%lld%lld%lld%lld", &u, &v, &p, &q);
add(u, v, -1), add(p, q, 1);
printf("%lld\n", tr.sum[1]);
}
return 0;
}