在还没有理解矩阵做法之前,看着一些讲解搓出来的 \(O(k\log k\log n)\) 做法,估计已经有过了罢。
题意:已知递推式 \(a_n=\sum\limits_{i=1}^kf_ia_{n-i}\) ,求 \(a_n\) 。
假设我们知道 \(a_n=\sum\limits_{i=0}^{k-1}p_ia_i\) ,不难发现 \(\forall j\geqslant 0,a_{n+j}=\sum\limits_{i=0}^{k-1}p_ia_{i+j}\) 。
这样,我们在知道 \(a_n=\sum\limits_{i=0}^{k-1}p_ia_i,a_m=\sum\limits_{i=0}^{k-1}q_ia_i\) 的情况下,发现:
\(a_{n+m}=\sum\limits_{i=0}^{k-1}p_ia_{m+i}=\sum\limits_{i=0}^{k-1}p_i\sum\limits_{j=0}^{k-1}q_j a_{i+j}=\sum\limits_{i=0}^{k-1}\sum\limits_{j=0}^{k-1}p_iq_ja_{i+j}\) 。
我们就可以通过对 \(p*q\) 来通过 \(a_0\dots a_{2(k-1)}\) 来表示 \(a_{n+m}\) 了。
现在就出现了一个问题:我们出现了 \(a_k\dots a_{2(k-1)}\) 这不该出现的 \(k-1\) 项,如果我们不考虑如何将他们处理成 \(a_0\dots a_{k-1}\) ,这个方法的复杂度就是不正确的。
所以我们一个将 \(a_k\dots a_{2(k-1)}\) 变成 \(a_0\dots a_{k-1}\) 的方法。
我们考虑先将 \(a_{2(k-1)}\) 变成 \(a_{2k-3}\dots a_{k-2}\) ,然后一直递归这个过程。
不妨将 \(a_k\dots a_{2(k-1)}\) 顺序颠倒,记 \(A_i=a_{2(k-1)-i}\) 。
我们记 \(B_i\) 为 \(a_{2(k-1)-i}\) 在我们上面的过程中由于前几位的变换的增加量, \(C_{2(k-2)+i}\) 表示最终 \(a_k\dots a_{2(k-1)}\) 的分解后 \(a_i\) 的增加量。
则 \((A+B)*f=B+C\) ,由于 \(C\) 的 \(0\sim k-2\) 位均为 \(0\) ,也就是在只考虑 \(0\sim k-2\) 位时,我们知道 \(A*f+B*f=B\) 。
所以 \(B=A*\dfrac{f}{1-f}\) ,我们又知道 \(B\) 的 \(k-1\) 位及以后都是 \(0\) ,所以我们可以直接把 \(C\) 求出来了。
我们就可以不断地重复这个合并的过程,单次是 \(O(k\log k)\) 的,再套上一个快速幂就是 \(O(k\log k\log n)\) ,而且常数很大(洛谷上模板题每个点都是2.0s)
代码:
#include <bits/stdc++.h>
#define Mod 998244353
#define LIM (1<<17|5)
using namespace std;
int Qread()
{
int x=0;bool f=false;char ch=getchar();
while(ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+(ch^48),ch=getchar();
return f?-x:x;
}
long long qpow(long long a,long long p)
{
long long ret=1;
for(;p;p>>=1,a=a*a%Mod)
if(p&1) ret=ret*a%Mod;
return ret;
}
int rev[LIM];
void ntt(long long *num,int len,bool typ)
{
long long w,g,x,y;
for(int i=1;i<len;i++){rev[i]=rev[i>>1]>>1;if(i&1) rev[i]|=(len>>1);}
for(int i=1;i<len;i++) if(rev[i]<i) swap(num[rev[i]],num[i]);
for(int i=1;i<len;i<<=1)
{
if(typ) w=qpow(332748118,(Mod-1)/(i<<1));
else w=qpow(3,(Mod-1)/(i<<1));
for(int j=0;j<len;j+=(i<<1))
{
g=1;
for(int k=0;k<i;k++)
{
x=num[j+k],y=g*num[j+i+k]%Mod;
num[j+k]=(x+y)%Mod;
num[j+i+k]=(x+Mod-y)%Mod;
g=g*w%Mod;
}
}
}
if(typ)
{
long long ny=qpow(len,Mod-2);
for(int i=0;i<len;i++) num[i]=num[i]*ny%Mod;
}
return;
}
namespace P{
long long f[LIM],g[LIM];
void polyinv(int n,long long *F,long long *G)
{
G[0]=qpow(F[0],Mod-2);
for(int len=2;len<2*n;len<<=1)
{
for(int i=0;i<len;i++) f[i]=F[i],g[i]=G[i];
for(int i=len;i<(len<<1);i++) f[i]=g[i]=0;
ntt(f,len<<1,false),ntt(g,len<<1,false);
for(int i=0;i<(len<<1);i++) f[i]=f[i]*g[i]%Mod;
ntt(f,len<<1,true);
for(int i=0;i<len;i++) f[i]=f[i]?Mod-f[i]:0;
for(int i=len;i<(len<<1);i++) f[i]=0;
f[0]=(f[0]+2)%Mod;
ntt(f,len<<1,false);
for(int i=0;i<(len<<1);i++) f[i]=f[i]*g[i]%Mod;
ntt(f,len<<1,true);
for(int i=0;i<len;i++) G[i]=f[i];
if((len<<1)>2*n) for(int i=n;i<len;i++) G[i]=0;
}
}
}
long long tmp[LIM],F[LIM],G[LIM],A[LIM],B[LIM];
long long ans[LIM],pw[LIM],pr;
int n,k,len;
long long f[32010];
long long a[32010];
int main()
{
n=Qread(),k=Qread();
for(int i=1;i<=k;i++) f[i]=(Qread()%Mod+Mod)%Mod;
for(int i=0;i<k;i++) a[i]=(Qread()%Mod+Mod)%Mod;
for(len=1;len<4*k;len<<=1);
for(int i=1;i<=k;i++) F[i]=f[i]?Mod-f[i]:0;F[0]=1;
P::polyinv(2*k+1,F,G);
memset(F,0,sizeof(F));
for(int i=1;i<=k;i++) F[i]=f[i];
G[0]=0;
ntt(G,len,false),ntt(F,len,false);
ans[0]=1,pw[1]=1;
for(;n;n>>=1)
{
if(n&1)
{
ntt(ans,len,false),ntt(pw,len,false);
for(int i=0;i<len;i++) tmp[i]=ans[i]*pw[i]%Mod;
ntt(tmp,len,true);
memset(A,0,sizeof(A));
for(int i=k;i<2*k-1;i++) A[2*k-2-i]=tmp[i];
ntt(A,len,false);
for(int i=0;i<len;i++) B[i]=A[i]*G[i]%Mod;
ntt(B,len,true),ntt(A,len,true);
for(int i=0;i<k-1;i++) (A[i]+=B[i])%=Mod;
ntt(A,len,false);
for(int i=0;i<len;i++) A[i]=A[i]*F[i]%Mod;
ntt(A,len,true);
for(int i=k-1;i<2*k-1;i++) (tmp[2*k-2-i]+=A[i])%=Mod;
memset(ans,0,sizeof(ans));
for(int i=0;i<k;i++) ans[i]=tmp[i];
ntt(pw,len,true);
}
ntt(pw,len,false);
for(int i=0;i<len;i++) tmp[i]=pw[i]*pw[i]%Mod;
ntt(tmp,len,true);
memset(A,0,sizeof(A));
for(int i=k;i<2*k-1;i++) A[2*k-2-i]=tmp[i];
ntt(A,len,false);
for(int i=0;i<len;i++) B[i]=A[i]*G[i]%Mod;
ntt(B,len,true),ntt(A,len,true);
for(int i=0;i<k-1;i++) (A[i]+=B[i])%=Mod;
ntt(A,len,false);
for(int i=0;i<len;i++) A[i]=A[i]*F[i]%Mod;
ntt(A,len,true);
for(int i=k-1;i<2*k-1;i++) (tmp[2*k-2-i]+=A[i])%=Mod;
memset(pw,0,sizeof(pw));
for(int i=0;i<k;i++) pw[i]=tmp[i];
}
for(int i=0;i<k;i++)
pr=(pr+ans[i]*a[i])%Mod;
printf("%lld\n",pr);
return 0;
}
标签:dots,log,limits,int,sum,ia,齐次,线性,递推
From: https://www.cnblogs.com/Xun-Xiaoyao/p/17149549.html