1.算法简介
整体 DP 就是用线段树合并维护 DP。
有一些问题,通常见于二维的DP,有一维记录当前x的信息,但是这一维过大无法开下,O(nm) 也无法通过。
但是如果发现,对于 x,在第二维的一些区间内,取值都是相同的,并且这样的区间是有限个,就可以批量处理。
所以我们就可以用线段树来维护 DP。
对于序列的问题,可以直接扫过去,修改某些位置的点,或者线段树合并。
对于树上的问题,线段树合并。
2.例题
Ⅰ. P4577 [FJOI2018] 领导集团问题
设 \(f_{i,j}\) 表示以 \(i\) 为根的子树 \(min_w≥j\) 的答案,分两种情况讨论:
-
\(f_{i,j} += f{son_i,j}\)
-
\(f_{i,j} = max(f_{i,j},f_{i,w_i} + 1) (j \le w_i)\) 在合并完儿子后进行此操作。
对于转移 1,显然直接线段树合并即可,对于转移 2 ,发现其实是区间 +,且\(f_{i,j}\) 是单调递减的,所以我们只要维护区间最小值 \(mn\),然后找到那段区间即可,需要标记永久化(不用也行)。
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 2e5 + 67;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
bool _u;
int n, cnt;
int a[N], b[N], rt[N];
int ls[N << 5], rs[N << 5], lz[N << 5], mn[N << 5];
vector<int> e[N];
void pushup(int x){
mn[x] = min(mn[ls[x]], mn[rs[x]]) + lz[x];
}
void modify(int &x, int l, int r, int L, int R){
if(!x) x = ++cnt;
if(L <= l && r <= R) return ++lz[x], ++mn[x], void();
int mid = (l + r) >> 1;
if(L <= mid) modify(ls[x], l, mid, L, R);
if(R > mid) modify(rs[x], mid + 1, r, L, R);
pushup(x);
}
int query1(int x, int l, int r, int p){
if(l == r && !x) return lz[x];
int mid = (l + r) >> 1;
if(p <= mid) return lz[x] + query1(ls[x], l, mid, p);
else return lz[x] + query1(rs[x], mid + 1, r, p);
}
int query2(int x, int l, int r, int val){
if(l == r) return l;
int mid = (l + r) >> 1; val -= lz[x];
if(mn[ls[x]] <= val) return query2(ls[x], l, mid, val);
else return query2(rs[x], mid + 1, r, val);
}
int merge(int x, int y){
if(!x || !y) return x + y;
ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]);
lz[x] += lz[y], pushup(x);
return x;
}
void dfs(int x){
for(auto y : e[x]) dfs(y), rt[x] = merge(rt[x], rt[y]);
int tmp = query1(rt[x], 1, n, a[x]);
modify(rt[x], 1, n, query2(rt[x], 1, n, tmp), a[x]);
}
bool _v;
int main(){
cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
n = read();
for(int i = 1; i <= n; ++i) b[i] = a[i] = read();
sort(b + 1, b + 1 + n);
for(int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b;
for(int i = 2, f; i <= n; ++i) f = read(), e[f].pb(i);
dfs(1); printf("%d\n", query1(rt[1], 1, n, 1));
return 0;
}
Ⅱ.CF490F Treeland Tour
\(f/g_{i,j}\) 表示以 \(j\) 为结尾的LIS 和 LDS。线段树维护即可。
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 6e3 + 67, LIM = 1e6;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
bool _u;
int n, ans, cnt;
int a[N], rt[N];
int ls[N << 5], rs[N << 5], pre[N << 5], suf[N << 5];
vector<int> e[N];
void modify(int &x, int l, int r, int p, int v, int *val){
if(!x) x = ++cnt; val[x] = max(val[x], v);
if(l == r) return ;
int mid = (l + r) >> 1;
if(p <= mid) modify(ls[x], l, mid, p, v, val);
else modify(rs[x], mid + 1, r, p, v, val);
}
int merge(int x, int y){
if(!x || !y) return x + y;
pre[x] = max(pre[x], pre[y]), suf[x] = max(suf[x], suf[y]);
ans = max(ans, max(pre[ls[x]] + suf[rs[y]], suf[rs[x]] + pre[ls[y]]));
ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]);
return x;
}
int query(int x, int l, int r, int L, int R, int *val){
if(!x || L > R) return 0;
if(L <= l && r <= R) return val[x];
int mid = (l + r) >> 1, ans = 0;
if(L <= mid) ans = max(ans, query(ls[x], l, mid, L, R, val));
if(R > mid) ans = max(ans, query(rs[x], mid + 1, r, L, R, val));
return ans;
}
void dfs(int x, int fa){
int ns = 0, np = 0;
for(auto y : e[x]){
if(y == fa) continue;
dfs(y, x);
int tp = query(rt[y], 1, LIM, 1, a[x] - 1, pre);
int ts = query(rt[y], 1, LIM, a[x] + 1, LIM, suf);
ans = max(ans, max(np + ts, ns + tp) + 1);
ns = max(ns, ts), np = max(np, tp);
rt[x] = merge(rt[x], rt[y]);
}
modify(rt[x], 1, LIM, a[x], np + 1, pre);
modify(rt[x], 1, LIM, a[x], ns + 1, suf);
}
bool _v;
int main(){
cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
n = read();
for(int i = 1; i <= n; ++i) a[i] = read();
for(int i = 1; i < n; ++i){
int u = read(), v = read();
e[u].pb(v), e[v].pb(u);
}
dfs(1, 0);
printf("%d\n", ans);
return 0;
}
Ⅲ. P6773 [NOI2020] 命运
发现对于 若干点对 \((u,v)\),如果对于当前 \(x\) 节点满足 \(v\) 在 \(x\) 的子树内,且 \(u\) 为 \(x\) 的祖先,那么我们只要满足 \(u\) 节点最深的点对即可。
所以我们可以设 \(f[u][i]\) 表示 表示以 \(u\) 为根的子树内,下端点在子树内并且没有被满足的限制中上端点的最深深度为 \(i\) 的方案数。进行分类讨论。
-
\((u,v)\) 的权值 为 \(1\) ,\(f[u][i] = \sum\limits_{j = 0}^{dep[u]} f[u][i] \times f[v][j]\)
-
\((u,v)\) 的权值 为 \(0\) ,\(f[u][i] = \sum\limits_{j = 0}^{i} f[u][i] \times f[v][j] + \sum\limits_{j = 0}^{i - 1} f[u][j] \times f[v][i]\)
设 \(g[u][i] = \sum\limits_{j = 0}^{i} f[u][j]\)。
那么转移式就可以写成
\[f[u][i] = f[u][i] \times (g[v][dep[u]] + g[v][i]) + g[u][i - 1] \times f[v][i] \]线段树合并即可。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 5e5 + 67, mod = 998244353;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
bool _u;
int n, m, nod;
int dep[N], ls[N << 5], rs[N << 5], rt[N];
ll laz[N << 5], sum[N << 5];
void add(ll &x, ll y){x += y; if(x >= mod) x -= mod;}
void mul(ll &x, ll y){x = x * y % mod;}
struct Edge{
int tot, to[N << 1], nxt[N << 1], hd[N];
void add(int u, int v){to[++tot] = v, nxt[tot] = hd[u], hd[u] = tot;}
}e, g;
void build(int &u, int l, int r, int p){
u = ++nod; sum[u] = laz[u] = 1;
if(l == r) return ;
int mid = (l + r) >> 1;
if(p <= mid) build(ls[u], l, mid, p);
else build(rs[u], mid + 1, r, p);
}
void pushup(int u){sum[u] = (sum[ls[u]] + sum[rs[u]]) % mod;}
void pushdown(int u){
if(laz[u] != 1){
mul(sum[ls[u]], laz[u]), mul(sum[rs[u]], laz[u]);
mul(laz[ls[u]], laz[u]), mul(laz[rs[u]], laz[u]);
} laz[u] = 1;
}
ll query(int u, int l, int r, int p){
if(!u || r <= p) return sum[u];
int mid = (l + r) >> 1; ll ans = 0; pushdown(u);
if(p > mid) ans = query(rs[u], mid + 1, r, p);
return add(ans, query(ls[u], l, mid, p)), ans;
}
int merge(int x, int y, int l, int r, ll &s1, ll &s2){ //s1 表示 g[v][dep[u]] + g[v][i], s2 表示 g[u][i - 1]
if(!x && !y) return 0;
if(!y) return add(s2, sum[x]), mul(laz[x], s1), mul(sum[x], s1), x;
if(!x) return add(s1, sum[y]), mul(laz[y], s2), mul(sum[y], s2), y;
if(l == r){
ll tmp = sum[x];
add(s1, sum[y]), mul(sum[x], s1);
add(sum[x], sum[y] * s2 % mod), add(s2, tmp);
//注意先后顺序
return x;
}
pushdown(x), pushdown(y);
int mid = (l + r) >> 1;
ls[x] = merge(ls[x], ls[y], l, mid, s1, s2);
rs[x] = merge(rs[x], rs[y], mid + 1, r, s1, s2);
return pushup(x), x;
}
void dfs(int x, int fa){
int mxd = 0; dep[x] = dep[fa] + 1;
for(int i = g.hd[x]; i; i = g.nxt[i]) mxd = max(mxd, dep[g.to[i]]); //找到最深的点
build(rt[x], 0, n, mxd);
for(int i = e.hd[x]; i; i = e.nxt[i]){
int y = e.to[i]; if(y == fa) continue;
dfs(y, x);
ll zx = query(rt[y], 0, n, dep[x]), zxq = 0;
rt[x] = merge(rt[x], rt[y], 0, n, zx, zxq);
}
}
bool _v;
int main(){
cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
// freopen("destiny4.in", "r", stdin);
n = read();
for(int i = 1; i < n; ++i){
int u = read(), v = read();
e.add(u, v), e.add(v, u);
}
m = read();
for(int i = 1; i <= m; ++i){
int u = read(), v = read();
g.add(v, u);
}
dfs(1, 0);
printf("%lld\n", query(rt[1], 0, n, 0));
return 0;
}
标签:rt,ch,return,int,sum,mid,笔记,27,DP
From: https://www.cnblogs.com/jiangchen4122/p/17713280.html