(矩乘代码在最后)
看一眼题,没什么思路,那大概就是个 dp 了。
先求已知序列的方案数。
考虑怎么设计状态。
因为 本质不同 ,如果设 \(f[i]\) 表示,前 \(i\) 个元素构成的序列的本质不同子序列个数似乎不太好转移。
但是 \(f[i]\) 表示以 \(i\) 结尾的子序列方案数似乎要好一些。让每个子序列只在第一次出现时被统计。
转移:
- 当 \(a[i]\) 与前面的字符都不同时,那么前面的所有方案后面再加上这个字符一定是不同的,并且这个字符本身也算一个新的子序列。于是有了
- 当 \(a[i]\) 在前面出现过时,记前一个位置为 \(lst[i]\),那么 \(lst[i]\) 以前的方案末尾加上当前字符已经在 \(lst[i]\) 被统计过了,所以只能在以 \(lst[i]\) 到 \(i\) 中间这一段字符结尾的方案后面加当前字符。于是有
显而易见,上面的求和可以用前缀和,再记录一个 \(lst\) 数组就行 能求出已知序列的方案数了。
for(int i=1;i<=n;++i){
if(!lst[a[i]]) f[i]=sum[i-1]-sum[0]+1;
else f[i]=sum[i-1]-sum[lst[a[i]]-1];
f[i]=(f[i]%p+p)%p;
lst[a[i]]=i,sum[i]=sum[i-1]+f[i];
sum[i]%=p;
}
从上面的转移得知,每次用前缀和更新时,右端点一定是 \(i-1\),为了让值尽量大,那就一定要让左端点尽量靠前,也就是 \(lst[i]\) 尽量小。
这样就有了一个贪心策略,每次填字符的时候填当前 \(lst\) 最小的字符。稍稍模拟一下就会发现这其实是循环填进去的。
所以只需要把上面的 \(lst\) 数组拿来排一下序,按照排序结果不断地填字母,并且按照上面同样的方法计算答案,直到填的个数到达 \(n\)。
因为有空序列的存在,最后答案加一就好。
排序复杂度很小忽略不计,所以总复杂度应该是 \(O(n+m)\) 的。
Code
#include<bits/stdc++.h>
using namespace std;
#define int long long
int read(){
int sum=0,f=1;char a=getchar();
while(a<'0' || a>'9'){if(a=='-') f=-1;a=getchar();}
while(a>='0' && a<='9') sum=sum*10+a-'0',a=getchar();
return sum*f;
}
const int N=2e6+5,p=1e9+7;
int lst[105],f[N],sum[N],a[N],q[N];
int n,m,k;
bool cmp(int _,int __){return lst[_]<lst[__];}
signed main(){
m=read(),k=read();
string s;cin>>s;n=s.size();
for(int i=0;i<n;++i) a[i+1]=s[i]-'a'+1;
for(int i=1;i<=n;++i){
if(!lst[a[i]]) f[i]=sum[i-1]-sum[0]+1;
else f[i]=sum[i-1]-sum[lst[a[i]]-1];
f[i]=(f[i]%p+p)%p;
lst[a[i]]=i,sum[i]=sum[i-1]+f[i];
sum[i]%=p;
}
for(int i=1;i<=k;++i) q[i]=i;
sort(q+1,q+k+1,cmp);
for(int i=n+1,j=1;i<=n+m;++i){
a[i]=q[j];
if(!lst[a[i]]) f[i]=sum[i-1]-sum[0]+1;
else f[i]=sum[i-1]-sum[lst[a[i]]-1];
f[i]=(f[i]%p+p)%p;
lst[a[i]]=i,sum[i]=sum[i-1]+f[i];
sum[i]%=p;
++j;if(j==k+1) j=1;
}
cout<<sum[n+m]+1;
return 0;
}
最后提一嘴,见过一个加强版,\(n\) 直接到了 \(10^{18}\),这个时候就不能直接枚举了,观察发现在填了几个字符后 \(sum[i]\) 的计算方式就比较固定了,可以用矩阵乘法优化。
放一个矩阵乘法的代码:
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
int read(){
int sum=0,f=1;char a=getchar();
while(a<'0' || a>'9'){if(a=='-') f=-1;a=getchar();}
while(a>='0' && a<='9') sum=sum*10+a-'0',a=getchar();
return sum*f;
}
const int N=2e6+5,p=1e9+7;
int lst[105],f[N],sum[N],a[N],q[N];
int n,m,k;
bool cmp(int _,int __){
return lst[_]<lst[__];
}
struct Mat{
int n,m,a[105][105];
Mat operator*(const Mat &x){
Mat y;y.n=n,y.m=x.m;
for(int i=1;i<=n;++i)
for(int j=1;j<=y.m;++j){
y.a[i][j]=0;
for(int k=1;k<=m;++k) y.a[i][j]+=a[i][k]*x.a[k][j]%p,y.a[i][j]=(y.a[i][j]%p+p)%p;
}
return y;
}
void print(){
cout<<n<<" "<<m<<"\n";
for(int i=1;i<=n;++i,cout<<"\n")
for(int j=1;j<=m;++j) cout<<a[i][j]<<" ";
cout<<"\n";
}
};
Mat ksm(Mat x,int y){
Mat res=x;--y;
while(y>0){
if(y&1) res=res*x;
x=x*x,y>>=1;
}
return res;
}
signed main(){
m=read(),k=read();
string s;cin>>s;n=s.size();
for(int i=0;i<n;++i) a[i+1]=s[i]-'a'+1;
for(int i=1;i<=n;++i){
if(!lst[a[i]]) f[i]=sum[i-1]-sum[0]+1;
else f[i]=sum[i-1]-sum[lst[a[i]]-1];
f[i]=(f[i]%p+p)%p;
lst[a[i]]=i,sum[i]=sum[i-1]+f[i];
sum[i]%=p;
}
for(int i=1;i<=k;++i) q[i]=i;
sort(q+1,q+k+1,cmp);
int l=1,pos=n;
while(l<=k && (pos+1-lst[q[l]]!=k || !lst[q[l]]) && m){
++pos,--m;
a[pos]=q[l];++l;
if(!lst[a[pos]]) f[pos]=sum[pos-1]-sum[0]+1;
else f[pos]=sum[pos-1]-sum[lst[a[pos]]-1];
lst[a[pos]]=pos;
f[pos]=(f[pos]%p+p)%p;
sum[pos]=sum[pos-1]+f[pos],sum[pos]%=p;
}
if(!m){
cout<<sum[pos]+1;
return 0;
}
for(int i=1;i<=k;++i) q[i]=i;
sort(q+1,a+k+1,cmp);
Mat ans,A;
ans.n=1,ans.m=k+1;
for(int i=1;i<=k+1;++i) ans.a[1][i]=sum[pos-k+i-1];
A.n=k+1,A.m=k+1;
memset(A.a,0,sizeof A.a);
for(int i=2;i<=k+1;++i) A.a[i][i-1]=1;
A.a[1][k+1]=-1,A.a[k+1][k+1]=2;
ans=ans*ksm(A,m);
cout<<ans.a[1][k+1]+1;
return 0;
}