【多项式】[LGP4173] 残缺的字符串
题意
给定两个有通配符的字符串,跑字符串匹配。
思路
肯定不能用 kmp
(不要问为什么)。
设 \(A_{1:m}\) 为模式串,\(B_{1:n}\) 为文本串。
定义一个函数 \(d(x,y)\geq 0\),且 \(d(x,y)=0\) 当且仅当 \(A_x=B_y\),即 \(A\) 的第 \(x\) 位和 \(B\) 的第 \(y\) 位匹配。
定义关于 \(x\) 的函数 \(P(x)\geq 0\) ,且 \(P(x)=0\) 当且仅当 \(A_{1:m}=B_{x:x+m-1}\),即 \(A\) 和 \(B\) 在 \(B\) 的第 \(x\) 位匹配。
那么应该 \(\forall i \in[1,m]\cap\Z, d(i,x+i-1)=0\) ,这样看上去没有优化前途。
既然他的值域都是正数,那么上面条件可以等价于 \(\displaystyle \sum_{i=1}^m\bigg[d(i,x+i-1)\bigg]=0\)。
可以定义 \(d(x,y)=A_x\times B_y\times(A_x-B_y)^2\) ,当字符为通配符时其值为 \(0\) ,否则为其 ASCII
码的值。
那么 函数
\[\begin{aligned} P(x)&=\sum_{i=1}^mA_iB_{x+i-1}(A_i-B_{x+i-1})^2\\ &=\sum_{i=1}^m\left(A_i^3B_{x+i-1}+A_iB_{x+i-1}^3-2A_i^2B_{x+i-1}^2\right)\\ &=\sum_{i=1}^mA_i^3B_{x+i-1}+\sum_{i=1}^mA_iB_{x+i-1}^3-2\left(\sum_{i=1}^mA_i^2B_{x+i-1}^2\right) \end{aligned} \]还是不好弄,考虑翻转字符串 \(A\) 为 \(S\),那么 \(A_i=S_{m-i+1}\)。
代入 函数
\[P(x)=\sum_{i=1}^mS_{m-i+1}^3B_{x+i-1}+\sum_{i=1}^mS_{m-i+1}B_{x+i-1}^3-2\left(\sum_{i=1}^mS_{m-i+1}^2B_{x+i-1}^2\right) \]可以发现每一个加项都是关于 \(S_{m-i+1}\) 和 \(B_{x+i-1}\) 的单项式,而他们下标加起来是一个常数 \(m+x\)。
因此,关于 \(x\) 的 函数 \(P(x)\) 就可以化成一个多项式计算。\(P(x)=\displaystyle\sum_{i+j=m+x}\bigg(S_i^3B_j+S_iB_j^3-2S_i^2B_j^2\bigg)\)
求和可以拆开,三个式子加起来。三个式子处理方式相同。比如第一个式子 \(\displaystyle\sum_{i+j=m+x}S_i^3B_j\) ,其实就是取多项式 \(S^{(3)}B\) 的第 \(m+x\) 项系数,其中 \(S^{(3)}\) 表示 \(S\) 中各项系数都变成 \(3\) 次方后的多项式。
所以暴力做一些多项式乘法就可以得出关于 \(x\) 的 函数 \(P(x)\) 的值。
看上去要做 \(3\) 次多项式乘法,常数爆炸。
专业来说,这道题的思路就是构造 \(P(i)\) 的生成函数 \(F(x)=\displaystyle\sum_{i=1}^nP(i)x^i=S^{(3)}B+SB^{(3)}-2S^{(2)}B^{(2)}\)。算出后面多项式的系数就可以得到所有的 \(P(i)\)。
点击查看代码
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>
#include <cstring>
const int N=1e5+1;
int n,m;
char s[N],t[N];
int rplcment[N];
int bty[N<<1],mx; // bty[i]: the max beauty in the front i characters.
int ans[N<<1],mn; // ans[i]:the minimum of replacement when catch the beauty bty[i].
const int mod=998244353,PHI=998244352,G=3,invG=332748118;
inline int fastpow(long long a,int k) {
int res=1;
while(k) {
if(k&1) res=a*res%mod;
a=a*a%mod; k>>=1;
}
return res;
}
inline int plus(int x,int y) { x+=y-mod; return x+((x>>31)&mod); }
inline int minus(int x, int y) { x-=y; return x+((x>>31)&mod); }
typedef std::vector<int> poly;
inline void NTT(poly& a,const int limit,const int B[],const int I=1) {
for(int i=1;i<limit;++i) if(i<B[i]) a[i]^=a[B[i]]^=a[i]^=a[B[i]];
for(int slen=1;slen<limit;slen<<=1) {
const int g = fastpow(I==1? G:invG,PHI/(slen<<1));
for(int j=0;j<limit;j+=slen<<1) {
long long rt=1;
for(int opt=0;opt<slen;++opt) {
const int x=a[j+opt],y=a[j+opt+slen]*rt%mod;
a[j+opt]=plus(x,y); a[j+opt+slen]=minus(x,y);
rt=rt*g%mod;
}
}
}
}
inline poly operator*(poly a,poly b) {
const int deg=a.size()+b.size()-2;
int n,k=0;
while((1<<++k)<=deg);
static int *B=(int*)malloc(sizeof(int));
if((n=1<<k)!=B[0]) {
B=(int*)realloc(B,sizeof(int)*n);
B[0]=0;
for(int i=1;i<n;++i) B[i]=(B[i>>1]>>1)|((i&1)<<(k-1));
B[0]=n;
}
a.resize(n+1); b.resize(n+1);
NTT(a,n,B); NTT(b,n,B);
for(int i=0;i<=n;++i) a[i]=1ll*a[i]*b[i]%mod;
NTT(a,n,B,-1);
a.resize(deg+1);
const long long inv=fastpow(n,mod-2);
for(int i=0;i<=deg;++i) a[i]=inv*a[i]%mod;
return a;
}
inline int hsh(char s) {
if(s=='?') return 0;
return s;
}
inline void work() {
poly a,b,c,d;
a.resize(n); b.resize(m);
for(int i=0;i<n;++i) a[i]=hsh(s[i]);
for(int j=0;j<m;++j) b[j]=hsh(t[j]);
std::reverse(b.begin(),b.end());
c=a; d=b;
for(int i=0;i<n;++i) c[i]=c[i]*c[i]*c[i];
for(int j=0;j<m;++j) d[j]=d[j]*d[j]*d[j];
c=b*c; d=a*d;
for(int i=0;i<n;++i) a[i]=a[i]*a[i];
for(int j=0;j<m;++j) b[j]=b[j]*b[j];
a=a*b;
memset(rplcment,0x80,sizeof(rplcment));
for(int i=0;i<=n-m;++i) if(c[i+m-1]+d[i+m-1]-(a[i+m-1]<<1)==0) rplcment[i+m-1]=0;
int tmp=0; for(int i=0;i<m;++i) tmp+=s[i]=='?';
for(int i=m-1;i<n;++i) {
if(rplcment[i]==0) rplcment[i]=tmp;
tmp+=(s[i+1]=='?')-(s[i-m+1]=='?');
}
return;
}
int main() {
scanf("%d%s%d",&n,s,&m);
for(int i=0;i<m;++i) t[i]=(i&1? 'b':'a');
work();
int i=m-1;
while(i<n+m) {
if(rplcment[i]>=0) bty[i]=mx+1, ans[i]=mn+rplcment[i];
else bty[i]=mx, ans[i]=mn;
if(bty[++i-m]>mx) {
mx=bty[i-m];
mn=ans[i-m];
}else if(bty[i-m]==mx) mn=std::min(mn,ans[i-m]);
}
printf("%d\n",mn);
return 0;
}