给出 \(\{a\},\{c\}\),满足矩阵 \(A_{i,j}=a_j^i\) 且 \(A\cdot \overrightarrow{b}=\overrightarrow{c}\),求 \(\{b\}\)。
移项变为 \(\overrightarrow{b}=A^{-1}\cdot \overrightarrow{c}\),考虑先求出 \(A^{-1}\)。
发现 \(\overrightarrow{f}\cdot A=\overrightarrow{g}\) 相当于将 \(a_0,\cdots,a_{n-1}\) 代入系数为 \(f_0,\cdots,f_{n-1}\) 的 \(n\) 次多项式 \(f(x)\) 中得到点值 \(g_0,\cdots,g_{n-1}\)。由于 \(\overrightarrow{f}=\overrightarrow{g}\cdot A^{-1}\) 所以只要我们知道怎么用 \(g\) 线性表示 \(f\) 就可以求出 \(A^{-1}\),这可以使用拉格朗日插值法:
\[f(x)=\sum_{i=0}^{n-1}g_i\prod_{j\neq i}\frac{x-a_j}{a_i-a_j} \]于是:
\[A^{-1}= \begin{bmatrix} [x^0]\prod_{j\neq 0}\frac{x-a_j}{a_0-a_j}&[x^1]\prod_{j\neq 0}\frac{x-a_j}{a_0-a_j}&\cdots&[x^{n-1}]\prod_{j\neq 0}\frac{x-a_j}{a_0-a_j}\\ [x^0]\prod_{j\neq 1}\frac{x-a_j}{a_1-a_j}&[x^1]\prod_{j\neq 1}\frac{x-a_j}{a_1-a_j}&\cdots&[x^{n-1}]\prod_{j\neq 1}\frac{x-a_j}{a_1-a_j}\\ \vdots&\vdots&&\vdots\\ [x^0]\prod_{j\neq {n-1}}\frac{x-a_j}{a_{n-1}-a_j}&[x^1]\prod_{j\neq {n-1}}\frac{x-a_j}{a_{n-1}-a_j}&\cdots&[x^{n-1}]\prod_{j\neq {n-1}}\frac{x-a_j}{a_{n-1}-a_j}\\ \end{bmatrix} \]可以暴力求出 \(A^{-1}\),再求 \(A^{-1}\cdot \overrightarrow{c}\) 即可得到 \(\overrightarrow{b}\),时间复杂度 \(O(n^2)\)。
接下来是神奇的优化做法。注意到 \(A^{-1}\) 的不同行之间只有 \(a_i\) 不同,所以考虑将 \(b_i\) 转为有关 \(a_i\) 的式子。
现在要求 \(\prod_{j\neq i}\frac{x-a_j}{a_i-a_j}\),然后将其系数和 \(\{c\}\) 点乘并累加即为 \(b_i\)。
先看分母,使用多项式快速插值中类似的套路,可以将 \(\prod_{j\neq i}(a_i-a_j)\) 转化为有关 \(a_i\) 的函数:(设 \(Q(y)=\prod_{j}(y-a_j)\))
\[\prod_{j\neq i}(a_i-a_j)=\lim_{y\to a_i}\frac{Q(y)}{y-a_i}=\lim _{y\to a_i}\frac{Q(y)'}{(y-a_i)'}=\lim_{y\to a_i}\frac{Q(y)'}{1}=\lim_{y\to a_i}Q(y)'=Q(a_i)' \]再看分子,设 \(S(x)=\prod_{j}(x-a_j)=\sum_{k=0}^{n}s_ix^i\),那么 \(\prod_{j\neq i}(x-a_j)=\frac{S(x)}{x-a_i}\),注意它一定是一个 \(n-1\) 次多项式,考虑暴力展开它:
\[P(x,a_i)=\frac{S(x)}{x-a_i}=-\sum_{j=0}^{n-1}x^j\sum_{k=0}^j s_k \frac{1}{a_i^{j-k+1}} \]于是:
\[\begin{aligned} b_i=F(a_i)=&-\sum_{j=0}^{n-1}c_j[x^j]\frac{P(x,a_i)}{Q(a_i)'}\\ =&-\frac{1}{Q(a_i)'}\sum_{j=0}^{n-1}c_j\sum_{k=0}^js_k\frac{1}{a_i^{j-k+1}} \end{aligned} \]换成 \(y\) 好看一点:
\[F(y)=-\frac{1}{Q(y)'}\sum_{j=0}^{n-1}c_j\sum_{k=0}^j s_k\frac{1}{y^{j-k+1}} \]后面那部分是减法卷积,于是我们就求出了 \(F(y)\),但我们无法真正算出来 \(F(y)\),因为 \(\frac{1}{Q(y)'}\) 无法处理。所以不妨把 \(a_0,\cdots,a_{n-1}\) 分别代入 \(Q(y)'\) 和 \(\sum\limits_{j=0}^{n-1}c_j\sum\limits_{k=0}^js_k\frac{1}{y^{j-k+1}}\) 多点求值,然后再得到 \(a_0,\cdots,a_{n-1}\) 代入 \(F(y)\) 得到的点值 \(b_0,\cdots,b_{n-1}\)。
时间复杂度 \(O(n\log ^2n)\)。
#include<bits/stdc++.h>
#define LN 18
#define N 70010
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
inline void Mul(int &x,int y){x=1ll*x*y%mod;}
}using namespace modular;
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
namespace Poly
{
const int NN=N<<1;
typedef vector<int> poly;
void print(const poly &a,string s="")
{
cout<<s;
for(int i=0,s=a.size();i<s;i++)
printf("%d ",a[i]);
puts("");
}
vector<int>w[LN][2];
void init(int limit)
{
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
int gn=poww(3,(mod-1)/len);
int ign=poww(gn,mod-2);
int g=1,ig=1;
for(int j=0;j<mid;j++,g=mul(g,gn),ig=mul(ig,ign))
w[bit][0].push_back(g),w[bit][1].push_back(ig);
}
}
void NTT(int *a,int limit,int opt)
{
static int rev[NN];
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
int x=a[i+j],y=mul(w[bit][opt][j],a[i+mid+j]);
a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
}
}
}
if(opt)
{
int tmp=poww(limit,mod-2);
for(int i=0;i<limit;i++) Mul(a[i],tmp);
}
}
poly pmul(const poly &a,const poly &b)
{
static int A[NN],B[NN];
const int sa=a.size(),sb=b.size();
for(int i=0;i<sa;i++) A[i]=a[i];
for(int i=0;i<sb;i++) B[i]=b[i];
int limit=1;
while(limit<(sa+sb-1)) limit<<=1;
NTT(A,limit,1),NTT(B,limit,1);
for(int i=0;i<limit;i++) Mul(A[i],B[i]);
NTT(A,limit,-1);
poly c(sa+sb-1);
for(int i=0;i<sa+sb-1;i++) c[i]=A[i];
for(int i=0;i<limit;i++) A[i]=B[i]=0;
return c;
}
poly dmul(const poly &a,const poly &b)
{
static int A[NN],B[NN];
const int sa=a.size(),sb=b.size();
for(int i=0;i<sa;i++) A[i]=a[i];
for(int i=0;i<sb;i++) B[i]=b[i];
reverse(B,B+sb);
int limit=1;
while(limit<(sa+sb-1)) limit<<=1;
NTT(A,limit,1),NTT(B,limit,1);
for(int i=0;i<limit;i++) Mul(A[i],B[i]);
NTT(A,limit,-1);
poly c(sa);
for(int i=0;i<sa;i++) c[i]=A[sb+i-1];
for(int i=0;i<limit;i++) A[i]=B[i]=0;
return c;
}
poly getinv(const poly &f,int n)
{
static int ff[NN],g[NN];
assert(f[0]);
g[0]=poww(f[0],mod-2);
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
for(int i=0;i<now;i++) ff[i]=f[i];
NTT(ff,limit,1),NTT(g,limit,1);
for(int i=0;i<limit;i++)
g[i]=mul(dec(2,mul(ff[i],g[i])),g[i]);
NTT(g,limit,-1);
for(int i=now;i<limit;i++) g[i]=0;
}
poly res(g,g+n);
for(int i=0;i<now;i++) ff[i]=g[i]=0;
return res;
}
poly getder(const poly &f)
{
assert(!f.empty());
if((int)f.size()==1) return poly{0};
poly g((int)f.size()-1);
for(int i=1,s=g.size();i<=s;i++) g[i-1]=mul(i,f[i]);
return g;
}
namespace Eva
{
int n,m;
poly a,res;
poly A[N<<2],f[N<<2];
void solve1(int k,int l,int r)
{
if(l==r)
{
A[k]=poly{1,dec(0,a[l])};
return;
}
int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
solve1(lc,l,mid),solve1(rc,mid+1,r);
A[k]=pmul(A[lc],A[rc]);
}
void solve2(int k,int l,int r)
{
if(l>=m) return;
if(l==r)
{
res.push_back(f[k][0]);
return;
}
int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
f[lc]=dmul(f[k],A[rc]),f[rc]=dmul(f[k],A[lc]);
f[lc].resize(mid-l+1),f[rc].resize(r-mid);
solve2(lc,l,mid),solve2(rc,mid+1,r);
}
poly geteva(const poly &_f,const poly &_a)
{
f[1]=_f,a=_a;
res.clear();
m=a.size(),n=max(f[1].size(),a.size());
f[1].resize(n),a.resize(n);
solve1(1,0,n-1);
poly q=getinv(A[1],n);
f[1]=dmul(f[1],q);
solve2(1,0,n-1);
return res;
}
}using Eva::geteva;
}using Poly::poly;
int n;
poly a,c,s[N<<2],q;
void solve(int k,int l,int r)
{
if(l==r)
{
s[k]=poly{dec(0,a[l]),1};
return;
}
int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
solve(lc,l,mid),solve(rc,mid+1,r);
s[k]=Poly::pmul(s[lc],s[rc]);
}
int main()
{
n=read();
for(int i=0;i<n;i++) a.push_back(read());
for(int i=0;i<n;i++) c.push_back(read());
int limit=1;
while(limit<(n<<1)) limit<<=1;
Poly::init(limit);
solve(1,0,n-1);
q=Poly::getder(s[1]);
s[1].resize(n);
poly f=Poly::dmul(c,s[1]);
f.resize(n+1);
for(int i=n;i>=1;i--) f[i]=f[i-1];
f[0]=0;
poly qv=Poly::geteva(q,a);
for(int &v:a) v=poww(v,mod-2);
poly fv=Poly::geteva(f,a);
for(int i=0;i<n;i++)
printf("%d ",dec(0,mul(poww(qv[i],mod-2),fv[i])));
return 0;
}
标签:frac,overrightarrow,int,多项式,sum,解密,XSY3528,prod,neq
From: https://www.cnblogs.com/ez-lcw/p/16840988.html