- 要统计差值为k的数对(i,j)的数量,这种感觉类似于卷积,我们把和差放到幂次中体现,就可以用NTT做到O(ailogai)
- 其中,对于差值为0的特殊情况,不仅需要减去数自匹配的n种情况,还要除以2
- NTT要预处理step+倍增法优化,否则会TLE
- 将游戏每轮的操作抽象为数学函数
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
int a[1000005];
int v[10000005],prime[10000005],m;
long long cnt[1000005],inv[10000005];
long long f[500005];
int rev[5000005],p[5000005][2];
int read1()
{
char cc=getchar();
while(!(cc>=48&&cc<=57))
{
if(cc=='-')
{
break;
}
cc=getchar();
}
bool f=false;
int s=0;
if(cc=='-')
{
f=true;
}
else
{
s=cc-48;
}
while(1)
{
cc=getchar();
if(cc>=48&&cc<=57)
{
s=s*10+cc-48;
}
else
{
break;
}
}
if(f==true)
{
s=-s;
}
return s;
}
int power(int n,int p)
{
if(p==0)
{
return 1;
}
long long tmp=power(n,p/2);
if(p%2==1)
{
return tmp*tmp%mod*n%mod;
}
return tmp*tmp%mod;
}
void NTT(vector<long long>&f,int opt)
{
int n=f.size();
for(int i=1;i<n;i++)
{
if(i<rev[i])
{
swap(f[i],f[rev[i]]);
}
}
for(int m=2;m<=n;m*=2)
{
int k=m/2;
for(int i=0;i<n;i+=m)
{
long long cur=1,step;
if(opt==1)
{
step=p[m][0];
}
else
{
step=p[m][1];
}
for(int j=0;j<k;j++)
{
long long tmp=cur*f[i+j+k]%mod;
f[i+j+k]=(f[i+j]-tmp)%mod;
f[i+j]=(f[i+j]+tmp)%mod;
cur=cur*step%mod;
}
}
}
}
vector<long long> operator*(vector<long long>a,vector<long long>b)
{
vector<long long>c(a.size()+b.size()-1);
while(c.size()<(1<<22))
{
c.push_back(0);
}
while(a.size()<c.size())
{
a.push_back(0);
}
while(b.size()<c.size())
{
b.push_back(0);
}
NTT(a,1),NTT(b,1);
for(int i=0;i<c.size();i++)
{
c[i]=a[i]*b[i]%mod;
}
NTT(c,-1);
int p=power(c.size(),998244351);
for(int i=0;i<c.size();i++)
{
c[i]=c[i]*p%mod;
}
return c;
}
vector<long long>c,c1(2000001),c2(2000001);
int main()
{
for(int i=1;i<(1<<22);i++)
{
rev[i]=(rev[i>>1]>>1);
if(i&1)
{
rev[i]+=(1<<21);
}
}
for(int i=1;i<=22;i++)
{
p[1<<i][0]=power(3,998244352/(1<<i));
p[1<<i][1]=power(3,998244352-998244352/(1<<i));
}
inv[1]=1;
for(int i=2;i<=10000000;i++)
{
if(v[i]==0)
{
v[i]=i;
prime[++m]=i;
inv[i]=power(i,998244351);
}
for(int j=1;j<=m;j++)
{
if(i*prime[j]>10000000||prime[j]>v[i])
{
break;
}
v[i*prime[j]]=prime[j];
inv[i*prime[j]]=inv[i]*inv[prime[j]]%mod;
}
}
long long n,t;
cin>>n>>t;
for(int i=1;i<=n;i++)
{
a[i]=read1();
c1[a[i]+1000000]++;
c2[-a[i]+1000000]++;
}
c=c1*c2;
for(int i=0;i<1000000;i++)
{
cnt[i]=c[-i+2000000];
}
cnt[0]-=n;
cnt[0]=cnt[0]*power(2,998244351)%mod;
long long p1=(n-2)*power((n*(n-1)/2)%mod,998244351)%mod,p1inv=power(p1,998244351);
vector<long long>a(t+1);
a[0]=power(p1,t);
a[1]=power(p1,t-1)*(1-2*p1)%mod*t%mod;
for(int i=1;i<t;i++)
{
a[i+1]=((1-2*p1)%mod*a[i]%mod*t%mod+2*p1*t%mod*a[i-1]%mod-i*a[i]%mod*(1-2*p1)%mod-a[i-1]*(i-1)%mod*p1%mod)%mod*inv[i+1]%mod*p1inv%mod;
}
long long ans=0;
for(int i=max(0ll,t-1000000);i<=t;i++)
{
ans=(ans+cnt[t-i]*a[i]%mod)%mod;
}
cout<<(ans+mod)%mod<<endl;
return 0;
}