题意:有一个小根堆和 \(1\) ~ \(n\) 个数,以及一个操作序列,+ 表示 \(push\), - 表示 \(pop\),\(pop\) 有 \(m\) 次,问你有多少种插入顺序使得最后的 pop 集合与给出的的数字集合 \(Y\) 相同。
首先有个浅显的发现:对于不在 \(Y\) 集合中的数,可选范围形如一个阶梯,换句话说,就是可选范围为 \([l_i,n]\), \(\forall_{i < m},l_i > l_{i+1}\)。
设集合 \(Y\) 从小到大第 \(i\) 个元素为 \(Y_i\)
所以有 \(dp_{i,j,k,t \in0,1}\) 表示现在在操作序列的第 \(i\) 个符号处,现在元素的可选范围为 \([Y_j,n]\),现在堆里有 \(k\) 个元素是在 \(Y\) 集合中的,\(0/1\) 表示 \(Y_j\) 是否被加入进了堆中。
首先我们考虑 + 操作的 dp 转移。
-
若加入进去的数不在 \(Y\) 集合中 :\(val(i,j,k) \times dp_{i,j,k,t} \to dp_{i+1,j,k,t}\) ( \(val(i,j,k)\) 表示的是系数)
-
若加入进去的数在 \(Y\) 集合中
- 加入的数为 \(Y_j\) :\(dp_{i,j,k,0} \to dp_{i+1,j,k+1,1}\)
- 加入的数不为 \(Y_j\) :\(val^{'}(i,j,k) \times dp_{i,j,k,t} \to dp_{i+1,j,k+1,t}\)
对于 \(val(i,j,k)\),我们发现这是好算的,若前面已经 \(push\) 了 \(x\) 次,\(pop\) 了 \(y\) 次(下面的 \(x\) 和 \(y\) 是相同意义),那么 \(val(i,j,k) = (n - V_j) - (m - j) - (x - y - k)\)。
但是我们发现 \(val^{'}(i,j,k)\) 就并不是那么好求了,这时候就有一种很妙的方法:就是在插入的时候不去考虑插入了哪个数,在 \(pop\) 时再去考虑顺序吗,也就是说 \(val^{'}(i,j,k)\) 不需要在 push 时不考虑。
然后是 - 操作的 dp 转移。
-
若 \(k>1\) 或 \(k=1,t=0\),这时候 \(j\) 不会改变 :\(dp_{i,j,k,t} \to dp{i+1,j,k-1,t}\)
-
若 \(k=1,t=1\),这是关键转移,我们枚举在把 \(V_j\) 踢掉之后,下一个限制范围为 \(p\),则有转移:
\(val^{''}(i,j,1) \times dp_{i,j,1,1} \to dp_{i+1,p,0,0}\),\(val^{''}(i,j,1) = \binom{y-(m-j)}{j-p-1} \times (j-p-1)!\)。
解释一下 \(val^{''}(i,j,1)\) 的含义:因为如果要到 \(p\) 这一个位置上的话,就要保证 \([p+1,j-1]\) 的元素都被删除了,所以是这个形式。
点击查看代码
#include<bits/stdc++.h>
#define fir first
#define sec second
#define int long long
#define mkp(a,b) make_pair(a,b)
using namespace std;
typedef pair<int,int> pir;
inline int read(){
int x=0,f=1; char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1; c=getchar();}
while(isdigit(c)){x=x*10+(c^48); c=getchar();}
return x*f;
}
const int mod=998244353,inf=1e18,N=305;
int n,m;
char s[N*2];
int a[N],dp[N][N][2],tmp[N][N][2],C[N][N],jie[N];
inline void init(){
C[0][0]=1;
for(int i=1;i<=m;i++){
C[i][0]=1;
for(int j=1;j<=i;j++)
C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
}
jie[0]=1; for(int i=1;i<=m;i++) jie[i]=jie[i-1]*i%mod;
}
signed main(){
freopen("heap.in","r",stdin);
freopen("heap.out","w",stdout);
n=read(),m=read();
scanf("%s",s);
for(int i=1;i<=m;i++) a[i]=read();
init(); sort(a+1,a+m+1);
dp[m][0][0]=1;
int pu=0,po=0;
for(int i=0;i<n+m;i++){
memcpy(tmp,dp,sizeof(dp));
memset(dp,0,sizeof(dp));
if(s[i]=='+'){
for(int j=0;j<=m;j++) for(int k=0;k<=j;k++){
(dp[j][k+1][0]+=tmp[j][k][0])%=mod;
(dp[j][k+1][1]+=tmp[j][k][0])%=mod;
(dp[j][k+1][1]+=tmp[j][k][1])%=mod;
int val=n-a[j]-(m-j)-(pu-po-k);
(dp[j][k][0]+=val*tmp[j][k][0])%=mod;
(dp[j][k][1]+=val*tmp[j][k][1])%=mod;
}
pu++;
}
else{
for(int j=1;j<=m;j++){
for(int k=1;k<=j;k++){
if(k!=1) (dp[j][k-1][1]+=tmp[j][k][1])%=mod;
(dp[j][k-1][0]+=tmp[j][k][0])%=mod;
}
int t=po-(m-j);
for(int p=j-t-1;p<j;p++)
(dp[p][0][0]+=C[t][j-p-1]*jie[j-p-1]%mod*tmp[j][1][1])%=mod;
}
po++;
}
// for(int j=0;j<=m;j++) for(int k=0;k<=pu;k++) cout<<j<<' '<<k<<' '<<dp[j][k][0]<<' '<<dp[j][k][1]<<'\n';
// puts("\n");
}
cout<<dp[0][0][0]<<'\n';
}