目录
以洛谷模板题为例介绍动态 dp 的一般方法。
P4719 【模板】"动态 DP"&动态树分治 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
P4751 【模板】"动态DP"&动态树分治(加强版) - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
题目描述
给定一个 \(n\) 个点的带点权树,进行 \(m\) 次修改点权的操作。
你需要在每次修改之后输出树上最大带权独立集的权值之和。
\(n,m\leq 10^6\)。
朴素算法
设 \(f_{u, 0/1}\) 表示考虑 \(u\) 子树内时,\(u\) 选或不选。可以转移:
\[f_{u, 0}=\sum_{v}\max(f_{v, 0}, f_{v, 1}) \]\[f_{u, 1}=a_u+\sum_{v}f_{v, 0} \]矩阵刻画
\[\begin{bmatrix} g_{u, 0}& g_{u, 0}\\ g_{u, 1}& -\infty \end{bmatrix}\times\begin{bmatrix} f_{v, 0}\\ f_{v, 1} \end{bmatrix}=\begin{bmatrix} f_{u, 0}\\ f_{u, 1} \end{bmatrix} \]这样可以轻易写为树剖形式:令 \(v=son(u)\),\(g_{u, 0}=\sum_{v\neq son[u]}\max(f_{v, 0}, f_{v, 1})\),\(g_{u, 1}=a_u+\sum_{v\neq son[u]}f_{v, 0}\)。基本的思路就是在重链上考虑,如何用一个点的重儿子的 dp 值和它的轻子树的 dp 值推出这个点的 dp 值。这个点自身的贡献可以在轻儿子上考虑。
实现
在每个点上维护 \(g_u\) 表示其轻儿子的信息(定义在上文给出),然后尝试维护每一条链的 \(g\) 的乘积(按照某种顺序,如从浅至深)。修改点权会影响 \(O(\log n)\) 个 \(g\) 以及 \(O(\log n)\) 条重链的乘积,从深往浅,每次修改一个 \(g\) 后重新计算所在重链的 \(g\) 的乘积,从而计算对链头父亲的 \(g\) 的影响,修改后继续往上跳重链。
使用
【模板】轻重链剖分 treecut - caijianhong - 博客园 (cnblogs.com)
题解 LGP4211【LNOI2014]LCA】/【模板】全局平衡二叉树 - caijianhong - 博客园 (cnblogs.com)
实现即可。
code
注意:本题的答案矩阵 \(A\) 可以直接归纳证明 \(A_{0, 0}\geq A_{0, 1}, A_{1, 0}\geq A_{1, 1}\)。因此输出答案为 \(\max(A_{0, 0}, A_{1, 0})\)。
代码中 val
是 \(g\),sum
是重链上的 \(g\) 的乘积。
重链剖分 2log 实现
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long LL;
template <int N, int M, class T = int>
struct graph {
int head[N + 10], nxt[M * 2 + 10], cnt;
struct edge {
int u, v;
T w;
edge(int u = 0, int v = 0, T w = 0) : u(u), v(v), w(w) {}
} e[M * 2 + 10];
graph() { memset(head, cnt = 0, sizeof head); }
edge operator[](int i) { return e[i]; }
void add(int u, int v, T w = 0) {
e[++cnt] = edge(u, v, w), nxt[cnt] = head[u], head[u] = cnt;
}
void link(int u, int v, T w = 0) { add(u, v, w), add(v, u, w); }
};
template <int N, int M, class T = int>
struct matrix {
T a[N + 1][M + 1];
matrix(bool f = 0) {
memset(a, ~0x3f, sizeof a);
for (int i = 1; f && i <= N; i++) a[i][i] = 0;
}
T *operator[](int i) { return a[i]; }
};
template <int N, int M, int R, class T = int>
matrix<N, R> operator*(matrix<N, M, T> a, matrix<M, R, T> b) {
matrix<N, R, T> c = 0;
for (int i = 1; i <= N; i++) {
for (int j = 1; j <= R; j++) {
for (int k = 1; k <= M; k++) {
c[i][j] = max(c[i][j], a[i][k] + b[k][j]);
}
}
}
return c;
}
int fa[100010], siz[100010], son[100010], dfn[100010], rnk[100010], top[100010],
cnt;
template <int N, class T = matrix<2, 2>>
struct segtree {
// 鍔ㄦ€佸紑鐐癸紝閾鹃《涓烘牴锛屼贡鎼烇紝璇㈤棶 O(1)
T ans[N + 10];
int ch[N + 10][2], tot;
segtree() : tot(-1) { ans[0] = 1, newnode(); }
int newnode() { return ++tot, ch[tot][0] = ch[tot][1] = 0, tot; }
void maintain(int p) { ans[p] = ans[ch[p][0]] * ans[ch[p][1]]; }
void build(T a[], int &p, int l, int r) {
if (!p) p = newnode();
if (l == r) return ans[p] = a[rnk[l]], void();
int mid = (l + r) >> 1;
build(a, ch[p][0], l, mid);
build(a, ch[p][1], mid + 1, r);
maintain(p);
}
void modify(T &v, int x, int p, int l, int r) {
if (l == r) return ans[p] = v, void();
int mid = (l + r) >> 1;
if (x <= mid)
modify(v, x, ch[p][0], l, mid);
else
modify(v, x, ch[p][1], mid + 1, r);
maintain(p);
}
T query(int L, int R, int p, int l, int r) {
if (L <= l && r <= R) return ans[p];
int mid = (l + r) >> 1;
if (R <= mid)
return query(L, R, ch[p][0], l, mid);
else if (mid + 1 <= L)
return query(L, R, ch[p][1], mid + 1, r);
else
return query(L, R, ch[p][0], l, mid) * query(L, R, ch[p][1], mid + 1, r);
}
};
int n, m, f[100010][2], a[100010], root, eds[100010];
matrix<2, 2> val[100010];
graph<100010, 100010, bool> g;
segtree<200010> t;
int dfs(int u, int f = 0) {
fa[u] = f, siz[u] = 1, son[u] = 0;
for (int i = g.head[u]; i; i = g.nxt[i]) {
int v = g[i].v;
if (v == fa[u]) continue;
siz[u] += dfs(v, u);
if (siz[v] > siz[son[u]]) son[u] = v;
}
return siz[u];
}
void add(int u, int v) { f[u][0] += max(f[v][0], f[v][1]), f[u][1] += f[v][0]; }
void make(int u, int v) {
val[u][1][1] += max(f[v][0], f[v][1]), val[u][1][2] = val[u][1][1],
val[u][2][1] += f[v][0];
}
void cut(int u, int topf) {
rnk[dfn[u] = ++cnt] = u, top[u] = topf, eds[topf] = max(eds[topf], dfn[u]);
f[u][0] = 0, f[u][1] = a[u], val[u][1][1] = val[u][1][2] = 0,
val[u][2][1] = a[u];
if (son[u]) cut(son[u], topf), add(u, son[u]);
for (int i = g.head[u]; i; i = g.nxt[i]) {
int v = g[i].v;
if (v == fa[u] || v == son[u]) continue;
cut(v, v);
add(u, v), make(u, v);
}
}
void update(int u, int k) {
val[u][2][1] += k - a[u], a[u] = k;
matrix<2, 2> bef, aft;
while (u) {
bef = t.query(dfn[top[u]], eds[top[u]], root, 1, n);
t.modify(val[u], dfn[u], root, 1, n);
aft = t.query(dfn[top[u]], eds[top[u]], root, 1, n);
u = fa[top[u]];
val[u][1][1] += max(aft[1][1], aft[2][1]) - max(bef[1][1], bef[2][1]);
val[u][1][2] = val[u][1][1];
val[u][2][1] += aft[1][1] - bef[1][1];
}
}
int main() {
// #ifdef LOCAL
// freopen("input.in","r",stdin);
// #endif
scanf("%d%d", &n, &m);
for (int u = 1; u <= n; u++) scanf("%d", &a[u]);
for (int i = 1, u, v; i < n; i++) scanf("%d%d", &u, &v), g.link(u, v);
dfs(1), cut(1, 1), t.build(val, root, 1, n);
for (int i = 1, u, k; i <= m; i++) {
scanf("%d%d", &u, &k);
update(u, k);
matrix<2, 2> res = t.query(dfn[top[1]], eds[top[1]], root, 1, n);
printf("%d\n", max(res[1][1], res[2][1]));
}
return 0;
}
全局平衡二叉树 1log 实现
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, __VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
template <int n, int m>
struct matrix {
int mat[n][m];
auto operator[](int i) { return mat[i]; }
auto operator[](int i) const { return mat[i]; }
void fill(int v) { fill_n(&mat[0][0], n * m, v); }
};
template <int n>
matrix<n, n> I() { /*{{{*/
matrix<n, n> res;
res.fill(-1e9);
for (int i = 0; i < n; i++) res[i][i] = 0;
return res;
} /*}}}*/
template <int n, int m, int r>
matrix<n, r> operator*(const matrix<n, m>& lhs,
const matrix<m, r>& rhs) { /*{{{*/
matrix<n, r> res;
res.fill(-1e9);
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
for (int k = 0; k < r; k++)
res[i][k] = max(res[i][k], lhs[i][j] + rhs[j][k]);
return res;
} /*}}}*/
template <>
matrix<2, 2> operator*(const matrix<2, 2>& lhs, const matrix<2, 2>& rhs) {
return {max(lhs[0][0] + rhs[0][0], lhs[0][1] + rhs[1][0]),
max(lhs[0][0] + rhs[0][1], lhs[0][1] + rhs[1][1]),
max(lhs[1][0] + rhs[0][0], lhs[1][1] + rhs[1][0]),
max(lhs[1][0] + rhs[0][1], lhs[1][1] + rhs[1][1])};
}
constexpr int N = 1e6 + 10;
int f[N][2], n, m, a[N];
matrix<2, 2> val[N], sum[N];
basic_string<int> g[N];
int fa[N], siz[N], son[N], dep[N], cnt, dfn[N], rnk[N], top[N], tf[N], ch[N][2];
void maintain(int p) {
// sum[p] = sum[ch[p][0]] * val[p] * sum[ch[p][1]];
sum[p] = val[p];
if (ch[p][0]) sum[p] = sum[ch[p][0]] * sum[p];
if (ch[p][1]) sum[p] = sum[p] * sum[ch[p][1]];
}
int build(int l, int r) {
if (l > r) return 0;
int pos = l + 1, T = siz[rnk[l]] - siz[son[rnk[r]]];
while (pos <= r && siz[rnk[l]] - siz[son[rnk[pos]]] <= T / 2) ++pos;
int p = rnk[--pos];
if ((ch[p][0] = build(l, pos - 1))) tf[ch[p][0]] = p;
if ((ch[p][1] = build(pos + 1, r))) tf[ch[p][1]] = p;
maintain(p);
return p;
}
void update(int p, int k) {
debug("update(%d, %d)\n", p, k);
val[p][1][0] += k - a[p], a[p] = k;
while (p) {
if (ch[tf[p]][0] != p && ch[tf[p]][1] != p) {
val[tf[p]][0][0] -= max(sum[p][0][0], sum[p][1][0]);
val[tf[p]][1][0] -= sum[p][0][0];
}
maintain(p);
if (ch[tf[p]][0] != p && ch[tf[p]][1] != p) {
val[tf[p]][0][0] += max(sum[p][0][0], sum[p][1][0]);
val[tf[p]][0][1] = val[tf[p]][0][0];
val[tf[p]][1][0] += sum[p][0][0];
}
p = tf[p];
}
}
void dfs(int u, int _fa) {
fa[u] = _fa, dep[u] = dep[_fa] + 1, siz[u] = 1;
for (int v : g[u])
if (v != _fa) {
dfs(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
void trans(int u, int v) {
f[u][0] += max(f[v][0], f[v][1]);
f[u][1] += f[v][0];
}
void cut(int u, int topf) {
dfn[u] = ++cnt, rnk[cnt] = u, top[u] = topf;
f[u][0] = 0, f[u][1] = a[u];
val[u] = {0, 0, a[u], (int)-1e9};
if (son[u]) cut(son[u], topf), trans(u, son[u]);
// else tf[build(dfn[topf], dfn[u])] = fa[topf];
for (int v : g[u])
if (v != fa[u] && v != son[u]) {
cut(v, v), trans(u, v);
val[u][0][0] += max(f[v][0], f[v][1]);
val[u][1][0] += f[v][0];
}
val[u][0][1] = val[u][0][0];
debug("son[%d] = %d\n", u, son[u]);
debug("f[%d] = {%d, %d}\n", u, f[u][0], f[u][1]);
}
int main() {
#ifndef LOCAL
cin.tie(nullptr)->sync_with_stdio(false);
#endif
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1, u, v; i < n; i++) cin >> u >> v, g[u] += v, g[v] += u;
sum[0] = I<2>();
dfs(1, 0), cut(1, 1);
for (int i = 1; i <= n; i++)
if (!son[i]) tf[build(dfn[top[i]], dfn[i])] = fa[top[i]];
auto rt = find(tf + 1, tf + n + 1, 0) - tf;
debug("ans0 = %d\n", max(f[1][0], f[1][1]));
debug("ans0 = %d\n", max(sum[rt][0][0], sum[rt][1][0]));
#ifdef LOCAL
for (int i = 1; i <= n; i++) debug("tf[%d] = %d\n", i, tf[i]);
#endif
int lst = 0;
while (m--) {
int u, k;
cin >> u >> k;
update(u ^ lst, k);
auto&& res = sum[rt];
assert(res[0][1] <= res[0][0]);
cout << (lst = max(res[0][0], res[1][0])) << endl;
}
return 0;
}