题意简述
有长度为 \(n\) 的 \(s_i=0/1\),求满足下列条件的长度为 \(n\) 的序列 \(a\) 的个数,对 \(998244353\) 取模:
- \(\forall i,0\le a_i\le m\)
- 当 \(s_i=0\) 时,\(a_i\not=\operatorname{mex}(a_1,a_2,\cdots,a_{i-1})\)
- 当 \(s_i=1\) 时,\(a_i=\operatorname{mex}(a_1,a_2,\cdots,a_{i-1})\)
\(n\le 5000,m\le 10^9\)。
分析
首先简单分析下发现 \(ans=m^k\),\(k\) 为 \(s_i=0\) 的 \(i\) 的个数(当 \(s_i=1\) 时只有一种填法,当 \(s_i=0\) 时只有一种填法不能选)。
然后我们发现在 \(m\) 较大时该结论正确(实际上是 \(m\ge n-1,n\ge 2\) 时正确,代码中为 \(m\ge n\),实际上去掉这个特判也不影响正确性)。
当 \(s_i=1\) 时,能填进去的数唯一。问题主要在于如何合理的设计 dp 状态处理掉 \(s_i=0\) 的情况。
发现一个 \(s_i=0\) 的连续段中不能填的数一定。因为在填前面的数时,由于不能填 mex,故 mex 一直没有发生改变。
而且我们发现值和值之间没有太大的本质区别,我们也不需要知道 mex 具体是几,仅仅知道那个数不能填。由此设 \(f_{i,j}\) 表示前 \(i\) 个数中出现了 \(j\) 种不同的数字的方案数。
转移:
- \(s_i=0:f_{i,j}\leftarrow f_{i-1,j-1}\),因为此时只能强制填 mex,而 mex 在之前必定没出现过。
- \(s_i=1:f_{i,j}\leftarrow f_{i-1,j}\times j+f_{i-1,j-1}\times(m-j)\),因为此时可以填 \(j\) 种出现过的数字,也可以填 \(m-j+1\) 种没有出现过的数字,但不能填 mex,而 mex 在之前必定没出现过,所以可以填 \(m-j\) 种数字。
由于 \(n\) 个值里面至多有 \(n\) 个不同值,故时间复杂度 \(O(n^2)\)。
点击查看代码
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/priority_queue.hpp>
#define x1 xx1
#define y1 yy1
#define IOS ios::sync_with_stdio(false)
#define ITIE cin.tie(0);
#define OTIE cout.tie(0);
#define FlushIn fread(Fread::ibuf,1,1<<21,stdin)
#define FlushOut fwrite(Fwrite::obuf,1,Fwrite::S-Fwrite::obuf,stdout)
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P__ puts("")
#define PU puts("--------------------")
#define popc __builtin_popcount
#define pii pair<int,int>
#define mp make_pair
#define fi first
#define se second
#define gc getchar
#define pc putchar
#define pb emplace_back
#define rep(a,b,c) for(int a=(b);a<=(c);a++)
#define per(a,b,c) for(int a=(b);a>=(c);a--)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=d)
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=d)
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
#define int long long
//#define int __int128
using namespace std;
using namespace __gnu_pbds;
bool greating(int x,int y){return x>y;}
bool greatingll(long long x,long long y){return x>y;}
bool smallingll(long long x,long long y){return x<y;}
namespace Fread {
const int SIZE=1<<21;
char ibuf[SIZE],*S,*T;
inline char getc(){if(S==T){T=(S=ibuf)+fread(ibuf,1,SIZE,stdin);if(S==T)return '\n';}return *S++;}
}
namespace Fwrite{
const int SIZE=1<<21;
char obuf[SIZE],*S=obuf,*T=obuf+SIZE;
inline void flush(){fwrite(obuf,1,S-obuf,stdout);S=obuf;}
inline void putc(char c){*S++=c;if(S==T)flush();}
struct NTR{~NTR(){flush();}}ztr;
}
/*#ifdef ONLINE_JUDGE
#define getchar Fread::getc
#define putchar Fwrite::putc
#endif*/
inline int rd(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}return x*f;
}
inline void write(int x,char ch='\0'){
if(x<0){x=-x;putchar('-');}
int y=0;char z[40];
while(x||!y){z[y++]=x%10+48;x/=10;}
while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=5e3+5,maxm=4e5+5,inf=0x3f3f3f3f,mod=998244353;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,m,a[maxn];
int ksm(int x,int y){
int res=1;
for(;y;y>>=1,x=x*x%mod)if(y&1)res=res*x%mod;
return res;
}
int f[maxn][maxn];
//填完前i个位置,共出现了j种取值的方案数
void add(int &x,int y){x+=y;if(x>=mod)x-=mod;}
void solve_the_problem(){
n=rd(),m=rd();rep(i,1,n)a[i]=rd();
int sum=0;
rep(i,1,n)if(a[i])sum++;
if(m>=n||!sum){
write(ksm(m,n-sum));
return;
}
f[0][0]=1;
rep(i,1,n)rep(j,1,m+1){
if(a[i])add(f[i][j],f[i-1][j-1]);
else{
add(f[i][j],f[i-1][j-1]*(m-j+1)%mod);
add(f[i][j],f[i-1][j]*j%mod);
}
}
int ans=0;
rep(i,1,m+1)add(ans,f[n][i]);
write(ans);
}
bool Med;
signed main(){
// freopen(".in","r",stdin);freopen(".out","w",stdout);
// fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
int _=1;while(_--)solve_the_problem();
}
/*
*/