原题链接
\(首先注意到用点维护dp值非常地难做\)
\(我们无法通过点直接维护树上的每个节点的染色\)
\(因为这样做的复杂度为 O(2^n)\)
\(我们考虑到通过枚举边来处理\)
\(对于每条边 枚举它两边的黑色和白色节点数\)
\(那么对该条边被经过的数量为两边的黑色节点数和白色节点数的乘积\)
\(该算法理论最坏复杂度为O(n(m^2))\)
\(但是有个非常大的除数所以能够通过\)
\(code:\)
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
#define pb push_back
#define pii pair<int,int>
const int N = 2005;
int sz[N];
int dp[N][N];
void solve() {
int n, m;
cin >> n >> m;
if (n - m < m)m = n - m;
vector<vector<pii>> G(n + 1, vector<pii >());
memset(dp, -1, sizeof(dp));
for (int i = 1; i < n; i++) {
int u, v, w;
cin >> u >> v >> w;
G[u].pb({v, w});
G[v].pb({u, w});
}
function<void(int, int)> dfs = [&](int x, int fa) {
sz[x] = 1;
dp[x][0] = dp[x][1] = 0;
for (auto &[son, val]: G[x]) {
if (son == fa)continue;
dfs(son, x);
sz[x] += sz[son];
for (int j = min(m, sz[x]); j >= 0; --j) {
if (dp[x][j] != -1)
dp[x][j] += dp[son][0] + sz[son] * (n - m - sz[son]) * val;
for (int k = min(j, sz[son]); k; k--) {
if (dp[x][j - k] == -1)continue;
int ans = (k * (m - k) + (sz[son] - k) * (n - m - sz[son] + k)) * val;
dp[x][j] = max(dp[x][j], dp[x][j - k] + dp[son][k] + ans);
}
}
}
};
dfs(1, -1);
cout << dp[1][m];
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
solve();
}