manacher
- 用途 : 找字符串中的最长的回文子串。
考虑该问题【模板】manacher 求最长回文串长度。该如何做?
-
暴力 \(O(n^2)\)
就是枚举回文中心,向外拓展。代码太简单了,就不挂了。
其实是懒得打 -
二分+hash \(O(n\log n)\)
将字符串正向hash,反向hash,枚举回文中心,二分答案即可。
足以通过本题了这代码跑的比一些实现较劣的正解快点此查看代码
#include<bits/stdc++.h> #include<bits/extc++.h> // using namespace __gnu_pbds; // using namespace __gnu_cxx; using namespace std; #define infile(x) freopen(x,"r",stdin) #define outfile(x) freopen(x,"w",stdout) #define errfile(x) freopen(x,"w",stderr) using ll=long long;using ull=unsigned long long; #ifdef LOCAL FILE *InFile = infile("in.in"),*OutFile = outfile("out.out"); // FILE *ErrFile=errfile("err.err"); #else FILE *Infile = stdin,*OutFile = stdout; //FILE *ErrFile = stderr; #endif const int N = 1.1e7 + 10,base = 131; char s[N]; ull pw[N],has1[N],has2[N]; int n,ans = 0; inline ull get_hash1(int l,int r){return has1[r] - has1[l-1]*pw[r-l+1];} inline ull get_hash2(int l,int r){return has2[l] - has2[r+1]*pw[r-l+1];} inline bool check1(int mid,int pos){ return get_hash1(pos-mid,pos-1) == get_hash2(pos,pos+mid-1); } inline bool check2(int mid,int pos){ return get_hash1(pos-mid,pos-1) == get_hash2(pos+1,pos+mid); } inline void binary1(int i){ int r = min(i-1,n-i),l = max(ans/2,1),res = 0; while(l <= r){ int mid = (l+r)>>1; if(get_hash1(i-mid,i-1) == get_hash2(i+1,i+mid)) res = mid,l = mid+1; else r = mid-1; } ans = max(ans,res * 2 + 1); } inline void binary2(int i){ int r = min(i-1,n-i) + 1,l = max(ans/2,1),res = 0; while(l <= r){ int mid = (l+r)>>1; if(get_hash1(i-mid,i-1) == get_hash2(i,i+mid-1)) res = mid,l = mid+1; else r = mid-1; } ans = max(ans,res*2); } signed main(){ cin.tie(nullptr)->sync_with_stdio(false); cout.tie(nullptr)->sync_with_stdio(false); cin >> (s+1); n = strlen(s+1); pw[0] = 1; ans = 1; for(int i = 1;i <= n; ++i) pw[i] = pw[i-1]*base; for(int i = 1;i <= n; ++i) has1[i] = has1[i-1]*base+s[i]; for(int i = n;i >= 1; --i) has2[i] = has2[i+1]*base+s[i]; for(int i = 1;i <= n; ++i){ if(s[i-1] == s[i+1]) binary1(i); if(s[i] == s[i-1]) binary2(i); } cout<<ans; }
-
hash \(O(n)\)
我不会。挂个oi-wiki的讲解
我们发现,以上这些方法,要么时间复杂度超标,要么代码难度和细节上超标。
那么有没有一种方法可以既不会T,也非常好写的算法呢?
-
manacher \(O(n)\)
优点 : 常数小,代码简洁。
-
实现方法 : 考虑优化最上面的内个暴力。
枚举回文中心有两种情况:
- 形如\(aba\)的长度为奇数的回文串,回文中心是它最中间的字符。
- 形如\(abba\)的长度为偶数的回文串,回文中心是两个字符最中间的字符中间的空隙
这样就要分类讨论,非常不符合manacher简洁的特点。
所以manacher对其的优化就是在相邻两个字符中间和字符串首尾(为了方便处理)插入一个不会被用到的字符(通常是$,#,@等,这里我选择使用#)
那么形如\(aba\)的回文串就会变成 #a#b#a#,没有影响。
形如\(abba\)的回文串就会变成 #a#b#b#a#,我们枚举回文中心至第二个#时就会将\(abba\)计入
注意我们要在字符串最前面和字符串最后面插入两个不同的字符,使其变成形如 $#a#b#n#j#k#@的形式,防止越界
我们发现,暴力较慢的原因是因为它重复遍历了子串。
如何将这一部分优化掉
我们记录一个变量r为已经触及到的最右边的字符的位置。
再记录一个变量mid表示回文串中包含r的回文串的回文中心的位置。
用一个数组\(len_i\)表示第i个点能够扩展出的回文串长度。
当\(mid<i<r\)时,设\(i\)关于\(mid\)的对称点为\(j\),我们可以证明一定有\([j-len_j,j+len_j] = [i-len_i,i+len_i]\)
注意当\(i+len_i>r\)时,无法保证上述情况相等。
对于该情况,我们直接令\(len_i=\min(len_j,r-i+1)\),暴力扩展。
反之,若\(i>r\),此时只能暴力扩展。
点此查看代码
#include<bits/stdc++.h> #include<bits/extc++.h> // using namespace __gnu_pbds; // using namespace __gnu_cxx; using namespace std; using ll=long long;using ull=unsigned long long; using db = double;using ldb = long double; #define infile(x) freopen(x,"r",stdin) #define outfile(x) freopen(x,"w",stdout) #define errfile(x) freopen(x,"w",stderr) #ifdef LOCAL FILE *InFile = infile("in.in"),*OutFile = outfile("out.out"); // FILE *ErrFile=errfile("err.err"); #else FILE *Infile = stdin,*OutFile = stdout; //FILE *ErrFile = stderr; #endif const int N = 1.1e7 + 100; int n,tot,len[N<<1],mid,mxr,ans = 0;//注意len数组要开二倍空间 char s1[N],s[N<<1]; signed main(){ cin.tie(nullptr)->sync_with_stdio(false); cout.tie(nullptr)->sync_with_stdio(false); cin>>(s1+1); n = strlen(s1+1); s[++tot] = '$';s[++tot] = '#'; for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#'; s[++tot] = '@';s[++tot] = '\0'; n = strlen(s+1); for(int i = 2;i < n; ++i){ if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1); else len[i] = 1; while(s[i - len[i]] == s[i + len[i]]) ++len[i]; if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i; ans = max(ans,len[i]); } cout<<ans-1;//千万注意这个 }
复杂度的话,感性理解一下吧,就是mid和r都单调不降,所以复杂度为线性。
-
例题 :
-
用桶存一下即可,其它板子。
点此查看代码
#include<bits/stdc++.h> #include<bits/extc++.h> // using namespace __gnu_pbds; // using namespace __gnu_cxx; using namespace std; using ll=long long;using ull=unsigned long long; using db = double;using ldb = long double; #define infile(x) freopen(x,"r",stdin) #define outfile(x) freopen(x,"w",stdout) #define errfile(x) freopen(x,"w",stderr) #ifdef LOCAL FILE *InFile = infile("in.in"),*OutFile = outfile("out.out"); FILE *ErrFile=errfile("err.err"); #else FILE *Infile = stdin,*OutFile = stdout; //FILE *ErrFile = stderr; #endif const int N = 2e6 + 10,mod = 19930726; char s1[N],s[N<<1]; int len[N<<1],n,mid,mxr,tot,num[N<<1]; ll k; inline int power(int a,int b,int mod){ int res = 1; for(; b;b >>= 1,a = 1ll * a * a % mod) if(b&1)res = 1ll * res * a % mod; return res; } signed main(){ cin.tie(nullptr)->sync_with_stdio(false); cout.tie(nullptr)->sync_with_stdio(false); cin>>n>>k>>(s1+1); s[++tot] = '$',s[++tot] = '#'; for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#'; s[++tot] = '@',s[++tot] = '\0'; n = strlen(s+1); for(int i = 2;i < n; ++i){ if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1); else len[i] = 1; while(s[i - len[i]] == s[i + len[i]]){ ++len[i]; } if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i; if((len[i] - 1) & 1) num[len[i]-1]++; } ll res = 1,sum = 0; for(int i = n;i >= 1; --i){ if(!(i&1)) continue; sum += num[i]; if(sum > k) return cout<<(res = 1ll * res * power(i,k,mod) % mod)<<'\n',0; else res = 1ll * res * power(i,sum,mod) % mod,k -= sum; } cout<<(k?-1:res); }
-
模板题,就是练一下手感。
点此查看代码
#include<bits/stdc++.h> #include<bits/extc++.h> // using namespace __gnu_pbds; // using namespace __gnu_cxx; using namespace std; #define infile(x) freopen(x,"r",stdin) #define outfile(x) freopen(x,"w",stdout) #define errfile(x) freopen(x,"w",stderr) #ifdef LOCAL FILE *InFile = infile("in.in"),*OutFile = outfile("out.out"); // FILE *ErrFile=errfile("err.err"); #else FILE *Infile = stdin,*OutFile = stdout; //FILE *ErrFile = stderr; #endif bool StdIn = cin.tie(nullptr)->sync_with_stdio(false); bool StdOut = cout.tie(nullptr)->sync_with_stdio(false); using ll=long long;using ull=unsigned long long; using db = double;using ldb = long double; const int N = 1e7 + 10; int n,tot,len[N<<1],mid,mxr,ans = 0; char s1[N],s[N<<1]; inline void manacher(){ cin>>(s1+1);n = strlen(s1+1); tot = mid = mxr = ans = 0; s[++tot] = '$';s[++tot] = '#'; for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#'; s[++tot] = '@';s[++tot] = '\0'; n = strlen(s+1); for(int i = 2;i < n; ++i){ if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1); else len[i] = 1; while(s[i - len[i]] == s[i + len[i]]) ++len[i]; if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i; ans = max(ans,len[i] - 1); } cout<<ans<<'\n'; } signed main(){ int T;cin>>T; while(T--) manacher(); }
-
枚举分割点应该是非常容易想到的。
我们记录一下以该点为左端点的最长回文串长度,以该点为右端点的最长回文串长度。
但我们发现manacher无法求出每个点这个值。
0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17
$|#|a|#|b|#|a|#|b|#|a |# |c |# |c |# |d |@
比如这个,我们发现\(l_3\)求出来是0,但很明显,\(l_3\)应该是3,因为bab为一个回文串。
因为我们在\(i=6\)时,\(len_i=6\),所以\(l_3\)没有更新,只更新到了\(l_1\),所以我们要将\(l_3\)通过\(l_1\)更新出来
所以我们再一次递推,将\(l_i\)与\(l_{i-2}-2\)取max
同理,\(right_i\)与\(right_{i+2}-2\)取max
点此查看代码
#include<bits/stdc++.h> #include<bits/extc++.h> // using namespace __gnu_pbds; // using namespace __gnu_cxx; using namespace std; #define infile(x) freopen(x,"r",stdin) #define outfile(x) freopen(x,"w",stdout) #define errfile(x) freopen(x,"w",stderr) #ifdef LOCAL FILE *InFile = infile("in.in"),*OutFile = outfile("out.out"); // FILE *ErrFile=errfile("err.err"); #else FILE *Infile = stdin,*OutFile = stdout; //FILE *ErrFile = stderr; #endif bool StdIn = cin.tie(nullptr)->sync_with_stdio(false); bool StdOut = cout.tie(nullptr)->sync_with_stdio(false); using ll=long long;using ull=unsigned long long; using db = double;using ldb = long double; const int N = 2e5 + 10; int lft[N],rht[N],len[N],n,tot,mid,mxr; char s1[N],s[N]; signed main(){ cin>>(s1+1); n = strlen(s1+1); s[++tot] = '$',s[++tot] = '#'; for(int i = 1;i <= n; ++i) s[++tot] = s1[i],s[++tot] = '#'; s[++tot] = '@';s[++tot] = '\0'; n = strlen(s+1); for(int i = 2;i < n; ++i){ if(i < mxr) len[i] = min(len[mid*2-i],mxr-i+1); else len[i] = 1; while(s[i - len[i]] == s[i + len[i]]) len[i]++; if(i + len[i] - 1 > mxr) mxr = i + len[i] - 1,mid = i; lft[i + len[i] - 1] = max(len[i] - 1,lft[i + len[i] - 1]), rht[i - len[i] + 1] = max(len[i] - 1,rht[i - len[i] + 1]); } for(int i = 2;i <= n - 1;i += 2) rht[i] = max(rht[i],rht[i + 2] - 2); for(int i = n - 1;i >= 2;i -= 2) lft[i] = max(lft[i],lft[i - 2] - 2); int ans = 0; for(int i = 2;i < n;i += 2) if(lft[i] && rht[i]) ans = max(ans,lft[i] + rht[i]); cout<<ans; }