点击查看代码
#include <bits/stdc++.h>
using namespace std;
int sa[500005];
int rk[20][500005],w,p[25],r[500005],h[500005];
stack<int>s1;
stack<int>s2;
long long n;
struct t1
{
char c;
int id;
}t[500005];
bool cmp1(t1 a,t1 b)
{
return a.c<b.c;
}
bool cmp2(int a,int b)
{
if(rk[w-1][a]!=rk[w-1][b])
{
return rk[w-1][a]<rk[w-1][b];
}
return rk[w-1][min(a+p[w-1],(int)n+1)]<rk[w-1][min(b+p[w-1],(int)n+1)];
}
int main()
{
string s;
cin>>s;
n=s.size();
for(int i=1;i<=n;i++)
{
t[i].c=s[i-1];
t[i].id=i;
sa[i]=i;
}
p[0]=1;
for(int i=1;i<=20;i++)
{
p[i]=p[i-1]*2;
}
sort(t+1,t+n+1,cmp1);
int cnt=0;
t[0].c=' ';
for(int i=1;i<=n;i++)
{
if(t[i].c!=t[i-1].c)
{
cnt++;
}
rk[0][t[i].id]=cnt;
}
for(int i=1;i<=19;i++)
{
w=i;
sort(sa+1,sa+n+1,cmp2);
cnt=0;
for(int j=1;j<=n;j++)
{
if(rk[w-1][sa[j]]!=rk[w-1][sa[j-1]]||rk[w-1][min(sa[j]+p[w-1],(int)n+1)]!=rk[w-1][min(sa[j-1]+p[w-1],(int)n+1)])
{
cnt++;
}
rk[w][sa[j]]=cnt;
}
}
for(int i=1;i<=n;i++)
{
r[i]=rk[19][i];
}
int k=0;
for(int i=1;i<=n;i++)
{
if(r[i]==1)
{
h[1]=0;
k=0;
continue;
}
while(i+k-1<n&&sa[r[i]-1]+k-1<n&&s[i+k-1]==s[sa[r[i]-1]+k-1])
{
k++;
}
h[r[i]]=k;
if(k)
{
k--;
}
}
long long ans=(1+n)*n/2*(n-1),cur=0;
for(int i=1;i<=n;i++)
{
while(!s1.empty()&&h[i]<h[s1.top()])
{
cur=cur-h[s1.top()]*s2.top();
s1.pop();
s2.pop();
}
if(s2.empty())
{
s2.push(i-1);
}
else
{
s2.push(i-1-(s1.top()-1));
}
s1.push(i);
cur=cur+h[i]*s2.top();
ans=ans-cur*2;
}
cout<<ans<<endl;
return 0;
}