P9847 [ICPC2021 Nanjing R] Crystalfly
你说得对,但是刻晴更可爱捏
翻译
给定一个 \(n(1\le n\le10^5)\) 个节点的树,每个节点上有 \(a_i\) 只晶蝶。派蒙最初在 \(1\) 号节点,并获得 \(1\) 号节点的所有晶蝶,接下来每一秒她可以移动到相邻的节点上并获得节点上的所有晶蝶,但是当她每到达一个节点 \(u\) 后,对于每个与 \(u\) 相邻的节点 \(v\),节点 \(v\) 上的的晶蝶会在 \(t_v(1\le t_v\le3)\) 秒内消失,在 \(t_v\) 秒后再到达节点 \(v\) 将无法获得节点上的晶蝶。现在需要你求出最多可以获得的晶蝶数。
分析
观察到特殊数据 \(t_i\le3\),不难想到只有两种情况:
- 走进节点 \(u\) 的一棵子树后放弃其他子树,转化成和原问题相同的子问题。
- 若 \(t_i=3\) 则可以进入一个节点后再退出并进入另一棵子树,转化成子问题。
因此考虑树上 DP。
定义 \(f_{u,0/1}\) 为遍历以 \(u\) 为根的整棵子树且 \(u\) 点的子节点的晶蝶是否消失的情况下所能获得的最大晶蝶数量。记与 \(u\) 相邻的非父亲节点中 \(t_i=3\) 的节点晶蝶数量的最大值和次大值(若存在,不存在特判即可)分别为 \(max1,max2\)。
如果当前节点不存在 \(t_i=3\) 的节点,那么 \(f_{u,0}=\sum\limits_{v\in son_u}f_{v,1},f_{u,1}=(\sum\limits_{v\in son_u}f_{v,1})+\max(a_v)\)。
如果当前节点存在 \(t_i=3\) 的节点,那么通过画图观察不难发现,记所有子节点的 \(f_{v,1}\) 的和 \(\sum\limits_{v\in son_u}f_{v,1}\) 为 \(sum\),\(f_{u,0}\) 结果不变,\(f_{u,1}=\max(f_{v,0}+a_v+sum-f_{v,1}+max1,f_{v,0}+a_v+sum-f_{v,1}+max2)\)。
这样最后的结果就变成了 \(f_{1,1}+a_1\)。
Code
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 1e5 + 5;
int f[maxn][2], a[maxn], TestCase, t[maxn], n;
vector<int> G[maxn], g[maxn];
inline void dfs(int u, int fa) {
int maxx, res;
maxx = res = 0;
f[u][0] = f[u][1] = 0;
for (auto v : G[u]) {
if (v == fa) continue;
dfs(v, u);
res += f[v][1];
maxx = max(maxx, a[v]);
}
f[u][0] = res, f[u][1] = res + maxx;
if (g[u].size()) {
int maxx1, maxx2, maxid1, maxid2;
maxx1 = maxx2 = -2e18;
maxid1 = maxid2 = 0;
for (auto v : g[u]) {
if (v == fa) continue;
if (maxx1 < a[v]) {
maxx2 = maxx1;
maxid2 = maxid1;
maxx1 = a[v];
maxid1 = v;
}
else if (maxx2 < a[v]) {
maxx2 = a[v];
maxid2 = v;
}
}
maxx = -2e18;
for (auto v : G[u]) {
if (v == fa) continue;
if (v == maxid1) {
if (maxid2 == 0) maxx = max(maxx, a[v] + f[v][0] + res - f[v][1]);
else maxx = max(maxx, a[v] + f[v][0] + res - f[v][1] + maxx2);
}
else maxx = max(maxx, a[v] + f[v][0] + res - f[v][1] + maxx1);
}
f[u][1] = max(maxx, f[u][1]);
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr);
cin >> TestCase;
while (TestCase--) {
cin >> n;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i <= n; i++) cin >> t[i];
for (int i = 1; i <= n; i++) G[i].clear(), g[i].clear();
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
G[u].emplace_back(v), G[v].emplace_back(u);
if (t[v] == 3) g[u].emplace_back(v);
if (t[u] == 3) g[v].emplace_back(u);
}
dfs(1, 1);
cout << a[1] + f[1][1] << "\n";
}
return 0;
}
Feedback
不知道是不是造数据的时候数据生成的问题,在判断相邻两边是否是 \(t=3\) 的点时:
if (t[v] == 3) g[u].emplace_back(v);
if (t[u] == 3) g[v].emplace_back(u);
与
if (t[v] == 3) g[u].emplace_back(v);
都可以通过本题。
显然后一种写法只有在 \(u\) 始终为 \(v\) 的根节点的时候才有正确性。
严重怀疑是不是 generator 采用了向前随机挂点的方法。