其实并不是什么很高大上的东西,就是把内层 dp 的结果压到外层 dp 的状态里。
通常解决的是“限制某种值为 \(x\) 的方案数”之类的问题,而限制的值通常是一个经典的 dp 问题。
没有啥好直接介绍的,就写三道做过的题。
BZOJ3864 Hero meet devil
算是一道入门题目。
我们先回忆一下一个经典问题:给定两个串 \(s\) 和 \(t\),求它们的 \(\text{LCS}\)。
考虑 dp:设 \(g_{i,j}\) 表示串 \(t\) 的前 \(i\) 位和串 \(s\) 的前 \(j\) 位的 \(\text{LCS}\),转移是比较简单的:
\[g_{i,j}=\begin{cases}g_{i-1,j-1}+1,&t_i=s_j\\\max(g_{i-1,j},g_{i,j-1}),&t_i\not=s_j\end{cases} \]现在我们就是要对每个 \(0\le i\le|s|\) 统计有多少个串 \(t\) 满足 \(g_{|t|,|s|}=i\)。
要统计 dp 值为某个值的串有几个,不妨直接将 \(g\) 作为内层 dp 记录到状态里:设 \(f_{i,S}\) 表示已经填了 \(i\) 个字符,\(g\) 数组的第 \(i\) 行结果为 \(S\) 的串有多少个。
直接这么记录状态数肯定是爆炸了,毕竟你要用一个数表示整个数组。但是我们冷静分析下发现,固定 \(g\) 的第 \(i\) 行时,相邻的 \(g_{i,j}\) 之间的差值不会超过 \(1\)(这是显然的,多加一个字符 \(\text{LCS}\) 的长度至多增加 \(1\)),换句话说就是 \(g\) 的每行的差分数组只有 \(0/1\) 两种数。所以我们可以直接把 \(S\) 记录成 \(g\) 第 \(i\) 行的差分数组,这样我们就把第二维的状态数压到了 \(2^{|s|}\) 种。
因此我们提前预处理出 \(nxt_{S,c}\) 表示状态为 \(S\) 时接上一个字符 \(c\) 会转移到哪种状态,转移就是直接枚举新加入的字符 \(c\),\(f_{i,S}\to f_{i+1,nxt_{S,c}}\) 即可。
这次直接把代码放上来。
#include<bits/stdc++.h>
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define deb(x) cerr << #x"=" << x << '\n';
using namespace std;
const int mod = 1e9 + 7;
int n, m, nxt[1 << 15][5];
int f[1005][1 << 15], g[20], h[20], a[20], ans[20];
string s;
int get_nxt(int S, int c){
For(i, 1, n) g[i] = g[i - 1] + ((S >> i - 1) & 1);
For(i, 1, n){
if(a[i] == c) h[i] = g[i - 1] + 1;
else h[i] = max(h[i - 1], g[i]);
}
int T = 0;
For(i, 1, n) T |= (h[i] - h[i - 1]) << i - 1;
return T;
}
void Solve(){
cin >> s;
n = s.size(); s = ' ' + s;
For(i, 1, n){
if(s[i] == 'A') a[i] = 1;
else if(s[i] == 'G') a[i] = 2;
else if(s[i] == 'T') a[i] = 3;
else a[i] = 4;
}
For(i, 0, (1 << n) - 1){
nxt[i][1] = get_nxt(i, 1);
nxt[i][2] = get_nxt(i, 2);
nxt[i][3] = get_nxt(i, 3);
nxt[i][4] = get_nxt(i, 4);
}
cin >> m; f[0][0] = 1;
For(i, 0, m - 1) For(S, 0, (1 << n) - 1){
if(!f[i][S]) continue;
(f[i + 1][nxt[S][1]] += f[i][S]) %= mod;
(f[i + 1][nxt[S][2]] += f[i][S]) %= mod;
(f[i + 1][nxt[S][3]] += f[i][S]) %= mod;
(f[i + 1][nxt[S][4]] += f[i][S]) %= mod;
}
For(S, 0, (1 << n) - 1) (ans[__builtin_popcount(S)] += f[m][S]) %= mod;
For(i, 0, n) cout << ans[i] << '\n', ans[i] = 0;
For(i, 0, m) For(S, 0, (1 << n) - 1) f[i][S] = 0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int T = 1; cin >> T;
while(T--) Solve();
return 0;
}
SDOI2022 小 N 的独立集
稍微加了点难度。
给定点权求最大独立集是一个经典的 dp 问题:设 \(f_{u,0/1}\) 表示 \(u\) 的子树内 \(u\) 不选/选的最大独立集。
我们注意到本题拥有着极小的值域(\(k\le 5\)),所以启发我们直接把 dp 结果扔到状态里。设 \(dp_{u,x,y}\) 表示 \(u\) 子树内 \(f_{u,0/1}\) 的值分别为 \(x\) 和 \(y\) 的方案数,转移考虑类似树上背包,将 \(u\) 的儿子 \(v\) 合并过来,具体地:
\[dp_{u,x,y}\times dp_{v,p,q}\to dp'_{u,x+\max(p,q),v+p} \]下标的更新方式就是我们内层 dp 原先的转移方式。
但是我们发现这个做法的状态数达到了 \(O(n^3k^2)\),难以通过。
我们尝试减少内层 dp 的状态数。我们可以发现强制 \(u\) 选的答案比强制 \(u\) 不选的答案不会优太多,稍加分析就可以观察到这么一个性质:\(0\le\max(f_{u,0},f_{u,1})-f_{u,0}\le k\)。这是因为对于强制选了 \(u\) 的方案来说,把 \(u\) 去掉最多只会减少 \(k\) 的权值且会变成一种不选的方案。这启发我们更改下内层 dp 的定义:设 \(f_{u,0/1}\) 表示 \(u\) 子树内不强制/强制 \(u\) 不选的方案数,这样以来就有 \(0\le f_{u,0}-f_{u,1}\le k\)。
那么我们把外层的 dp 状态也相应地更改为:\(dp_{u,x,y}\) 表示 \(u\) 子树内 \(f_{u,1}=x\) 且 \(f_{u,0}=x+y\) 的方案数,转移也不难:
\[dp_{u,x,y}\times dp_{v,p,q}\to dp_{u,x+p+q,\max(x+y+p,x+p+q)-(x+p+q)} \]此时的状态数是 \(O(n^2k^2)\),套用树上背包可以分析出复杂度的上界为 \(O(n^2k^4)\),这个上界极其宽松所以可以通过。
#include<bits/stdc++.h>
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define deb(x) cerr << #x"=" << x << '\n';
using namespace std;
const int mod = 1e9 + 7;
int n, k, siz[1005], f[1005][5005][6], g[5005][6], ans[5005];
vector<int> e[1005];
void Add(int &x, int y){if((x = x + y) >= mod) x -= mod;}
void dfs(int now, int fa){
siz[now] = 1;
For(i, 1, k) f[now][0][i] = 1;
for(int to : e[now]){
if(to == fa) continue;
dfs(to, now);
memset(g, 0, sizeof g);
For(x, 0, k * siz[now]) For(y, 0, k) if(f[now][x][y])
For(p, 0, k * siz[to]) For(q, 0, k) if(f[to][p][q])
Add(g[x + p + q][max(x + y + p, x + p + q) - (x + p + q)], 1ll * f[now][x][y] * f[to][p][q] % mod);
memcpy(f[now], g, sizeof g);
siz[now] += siz[to];
}
}
void Solve(){
cin >> n >> k;
For(i, 1, n - 1){
int u, v; cin >> u >> v;
e[u].emplace_back(v);
e[v].emplace_back(u);
}
dfs(1, 0);
For(i, 1, n * k){
int ans = 0;
For(j, 0, min(i, k)) Add(ans, f[1][i - j][j]);
cout << ans << '\n';
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int T = 1; //cin >> T;
while(T--) Solve();
return 0;
}
CF924F Minimal Subset Difference
比较困难的一题。
考虑一个数的答案如何计算,这个显然等价于子集和,只能做类似背包的东西。我们设 \(f_{i,j}\) 表示用了 \(i\) 位数当前的差是否可以为 \(j\),转移时如果加入了一个 \(c\) 可以转移到 \(f_{i+1,j+c}\) 或 \(f_{i+1,|j-c|}\)。事实上第一维可以扔去,只保留 \(f_i\) 表示差值为 \(i\) 的可行性。
因为单个数计算的复杂度是不可能再低下去了,我们就只能计数有多少个数的 dp 状态是合法的,那就只能是 dp 套 dp。看上去直接把 \(f\) 设在状态里状态数直接升天了,但是先别急,我们慢慢降。
首先一个观察是最后的答案一定不超过 \(9\),因为考虑一个贪心:每次往和小的集合里扔数,这样能保证最后差不超过 \(9\)。而我们又注意到,如果某个时刻两个集合的差大于了 \(72\),那么即使剩下位全都是 \(9\) 扔过去答案也不会小于 \(9\),所以说我们的 \(f\) 其实只需要保留 \(f_0\) 到 \(f_{72}\) 的状态就好了,这样我们可以直接用一个 int128
表示 \(f\)。
现在的状态数是 \(2^{73}\) 左右,还是很爆炸。因为填的数最多只有 \(19\) 位,我们考虑直接爆搜,搜出所有的合法状态,发现只有一万多种!
所以我们可以直接把这一万多种状态拉出来做 dp 套 dp 了。预处理 \(g_{lim,len,S}\) 表示限制差值不超过 \(lim\),还剩下 \(len\) 位数需要填,当前的状态为 \(S\) 时还有多少种填法,转移是容易的。对于询问,显然要先差分掉,然后我们枚举 \(\text{LCP}\) 直接在 \(g\) 这个自动机上走就能统计答案。
跑得相当快,CF 上只跑了 234ms。
#include<bits/stdc++.h>
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define deb(x) cerr << #x"=" << x << '\n';
using namespace std;
using LL = __int128;
const int D = 10, L = 19, MS = 20005, W = 72;
const LL U = ((LL)1 << W + 1) - 1;
map<LL, int> ID;
int tot, ans[MS], nxt[MS][10], f[D][L][MS];
LL val[MS];
vector<int> vec[D];
LL get_nxt(LL S, int c){
LL T = (S >> c) | (S << c);
For(i, 0, c) if((S >> i) & 1) T |= (1 << c - i);
return T & U;
}
int get_ans(LL S){
For(i, 0, D - 1) if((S >> i) & 1) return i;
assert(0); return 114514;
}
void bfs(){
queue<pair<int, int>> q;
tot++; ans[1] = 0; ID[1] = 1; val[1] = 1; q.push({1, 0});
while(!q.empty()){
auto [cur, len] = q.front(); q.pop();
if(len == L - 1) continue;
For(c, 1, 9){
LL to = get_nxt(val[cur], c);
auto it = ID.find(to);
if(it != ID.end()) {nxt[cur][c] = it -> second; continue;}
ID[to] = ++tot; nxt[cur][c] = tot; ans[tot] = get_ans(to); val[tot] = to;
q.push({tot, len + 1});
}
}
For(i, 1, tot) vec[ans[i]].push_back(i);
For(i, 0, D - 1){
For(j, 0, i) for(int k : vec[j]) f[i][0][k] = 1;
For(j, 1, L - 1) For(k, 1, tot){
f[i][j][k] += f[i][j - 1][k];
For(c, 1, 9) f[i][j][k] += f[i][j - 1][nxt[k][c]];
}
}
}
int lim;
int query(int x){
if(lim >= 10) return x + 1;
int ans = 0, now = 1; x++; vector<int> st;
while(x) st.push_back(x % 10), x /= 10;
reverse(st.begin(), st.end());
int len = st.size(); ans += f[lim][len - 1][1];
For(i, 0, len - 1){
int x = st[i];
For(c, (i == 0), x - 1){
int to = (c == 0) ? now : nxt[now][c];
ans += f[lim][len - i - 1][to];
}
now = (x == 0) ? now : nxt[now][x];
}
return ans;
}
void Solve(){
int l, r; cin >> l >> r >> lim;
cout << query(r) - query(l - 1) << '\n';
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int T; cin >> T; bfs();
while(T--) Solve();
return 0;
}
标签:nxt,int,len,ans,now,dp,小记
From: https://www.cnblogs.com/los114514/p/18349510