首页 > 其他分享 >【XSY3528】解密(多项式)

【XSY3528】解密(多项式)

时间:2022-10-30 12:36:56浏览次数:30  
标签:frac overrightarrow int 多项式 sum 解密 XSY3528 prod neq

给出 \(\{a\},\{c\}\),满足矩阵 \(A_{i,j}=a_j^i\) 且 \(A\cdot \overrightarrow{b}=\overrightarrow{c}\),求 \(\{b\}\)。

移项变为 \(\overrightarrow{b}=A^{-1}\cdot \overrightarrow{c}\),考虑先求出 \(A^{-1}\)。

发现 \(\overrightarrow{f}\cdot A=\overrightarrow{g}\) 相当于将 \(a_0,\cdots,a_{n-1}\) 代入系数为 \(f_0,\cdots,f_{n-1}\) 的 \(n\) 次多项式 \(f(x)\) 中得到点值 \(g_0,\cdots,g_{n-1}\)。由于 \(\overrightarrow{f}=\overrightarrow{g}\cdot A^{-1}\) 所以只要我们知道怎么用 \(g\) 线性表示 \(f\) 就可以求出 \(A^{-1}\),这可以使用拉格朗日插值法:

\[f(x)=\sum_{i=0}^{n-1}g_i\prod_{j\neq i}\frac{x-a_j}{a_i-a_j} \]

于是:

\[A^{-1}= \begin{bmatrix} [x^0]\prod_{j\neq 0}\frac{x-a_j}{a_0-a_j}&[x^1]\prod_{j\neq 0}\frac{x-a_j}{a_0-a_j}&\cdots&[x^{n-1}]\prod_{j\neq 0}\frac{x-a_j}{a_0-a_j}\\ [x^0]\prod_{j\neq 1}\frac{x-a_j}{a_1-a_j}&[x^1]\prod_{j\neq 1}\frac{x-a_j}{a_1-a_j}&\cdots&[x^{n-1}]\prod_{j\neq 1}\frac{x-a_j}{a_1-a_j}\\ \vdots&\vdots&&\vdots\\ [x^0]\prod_{j\neq {n-1}}\frac{x-a_j}{a_{n-1}-a_j}&[x^1]\prod_{j\neq {n-1}}\frac{x-a_j}{a_{n-1}-a_j}&\cdots&[x^{n-1}]\prod_{j\neq {n-1}}\frac{x-a_j}{a_{n-1}-a_j}\\ \end{bmatrix} \]

可以暴力求出 \(A^{-1}\),再求 \(A^{-1}\cdot \overrightarrow{c}\) 即可得到 \(\overrightarrow{b}\),时间复杂度 \(O(n^2)\)。

接下来是神奇的优化做法。注意到 \(A^{-1}\) 的不同行之间只有 \(a_i\) 不同,所以考虑将 \(b_i\) 转为有关 \(a_i\) 的式子。

现在要求 \(\prod_{j\neq i}\frac{x-a_j}{a_i-a_j}\),然后将其系数和 \(\{c\}\) 点乘并累加即为 \(b_i\)。

先看分母,使用多项式快速插值中类似的套路,可以将 \(\prod_{j\neq i}(a_i-a_j)\) 转化为有关 \(a_i\) 的函数:(设 \(Q(y)=\prod_{j}(y-a_j)\))

\[\prod_{j\neq i}(a_i-a_j)=\lim_{y\to a_i}\frac{Q(y)}{y-a_i}=\lim _{y\to a_i}\frac{Q(y)'}{(y-a_i)'}=\lim_{y\to a_i}\frac{Q(y)'}{1}=\lim_{y\to a_i}Q(y)'=Q(a_i)' \]

再看分子,设 \(S(x)=\prod_{j}(x-a_j)=\sum_{k=0}^{n}s_ix^i\),那么 \(\prod_{j\neq i}(x-a_j)=\frac{S(x)}{x-a_i}\),注意它一定是一个 \(n-1\) 次多项式,考虑暴力展开它:

\[P(x,a_i)=\frac{S(x)}{x-a_i}=-\sum_{j=0}^{n-1}x^j\sum_{k=0}^j s_k \frac{1}{a_i^{j-k+1}} \]

于是:

\[\begin{aligned} b_i=F(a_i)=&-\sum_{j=0}^{n-1}c_j[x^j]\frac{P(x,a_i)}{Q(a_i)'}\\ =&-\frac{1}{Q(a_i)'}\sum_{j=0}^{n-1}c_j\sum_{k=0}^js_k\frac{1}{a_i^{j-k+1}} \end{aligned} \]

换成 \(y\) 好看一点:

\[F(y)=-\frac{1}{Q(y)'}\sum_{j=0}^{n-1}c_j\sum_{k=0}^j s_k\frac{1}{y^{j-k+1}} \]

后面那部分是减法卷积,于是我们就求出了 \(F(y)\),但我们无法真正算出来 \(F(y)\),因为 \(\frac{1}{Q(y)'}\) 无法处理。所以不妨把 \(a_0,\cdots,a_{n-1}\) 分别代入 \(Q(y)'\) 和 \(\sum\limits_{j=0}^{n-1}c_j\sum\limits_{k=0}^js_k\frac{1}{y^{j-k+1}}\) 多点求值,然后再得到 \(a_0,\cdots,a_{n-1}\) 代入 \(F(y)\) 得到的点值 \(b_0,\cdots,b_{n-1}\)。

时间复杂度 \(O(n\log ^2n)\)。

#include<bits/stdc++.h>

#define LN 18
#define N 70010

using namespace std;

namespace modular
{
	const int mod=998244353;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
	inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
	inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
	inline void Mul(int &x,int y){x=1ll*x*y%mod;}
}using namespace modular;

inline int poww(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=mul(ans,a);
		a=mul(a,a);
		b>>=1;
	}
	return ans;
}

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

namespace Poly
{
	const int NN=N<<1;
	typedef vector<int> poly;
	void print(const poly &a,string s="")
	{
		cout<<s;
		for(int i=0,s=a.size();i<s;i++)
			printf("%d ",a[i]);
		puts("");
	}
	vector<int>w[LN][2];
	void init(int limit)
	{
		for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
		{
			int len=mid<<1;
			int gn=poww(3,(mod-1)/len);
			int ign=poww(gn,mod-2);
			int g=1,ig=1;
			for(int j=0;j<mid;j++,g=mul(g,gn),ig=mul(ig,ign))
				w[bit][0].push_back(g),w[bit][1].push_back(ig);
		}
	}
	void NTT(int *a,int limit,int opt)
	{
		static int rev[NN];
		opt=(opt<0);
		for(int i=0;i<limit;i++)
			rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
		for(int i=0;i<limit;i++)
			if(i<rev[i]) swap(a[i],a[rev[i]]);
		for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
		{
			for(int i=0,len=mid<<1;i<limit;i+=len)
			{
				for(int j=0;j<mid;j++)
				{
					int x=a[i+j],y=mul(w[bit][opt][j],a[i+mid+j]);
					a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
				}
			}
		}
		if(opt)
		{
			int tmp=poww(limit,mod-2);
			for(int i=0;i<limit;i++) Mul(a[i],tmp);
		}
	}
	poly pmul(const poly &a,const poly &b)
	{
		static int A[NN],B[NN];
		const int sa=a.size(),sb=b.size();
		for(int i=0;i<sa;i++) A[i]=a[i];
		for(int i=0;i<sb;i++) B[i]=b[i];
		int limit=1;
		while(limit<(sa+sb-1)) limit<<=1;
		NTT(A,limit,1),NTT(B,limit,1);
		for(int i=0;i<limit;i++) Mul(A[i],B[i]);
		NTT(A,limit,-1);
		poly c(sa+sb-1);
		for(int i=0;i<sa+sb-1;i++) c[i]=A[i];
		for(int i=0;i<limit;i++) A[i]=B[i]=0;
		return c;
	}
	poly dmul(const poly &a,const poly &b)
	{
		static int A[NN],B[NN];
		const int sa=a.size(),sb=b.size();
		for(int i=0;i<sa;i++) A[i]=a[i];
		for(int i=0;i<sb;i++) B[i]=b[i];
		reverse(B,B+sb);
		int limit=1;
		while(limit<(sa+sb-1)) limit<<=1;
		NTT(A,limit,1),NTT(B,limit,1);
		for(int i=0;i<limit;i++) Mul(A[i],B[i]);
		NTT(A,limit,-1);
		poly c(sa);
		for(int i=0;i<sa;i++) c[i]=A[sb+i-1];
		for(int i=0;i<limit;i++) A[i]=B[i]=0;
		return c;
	}
	poly getinv(const poly &f,int n)
	{
		static int ff[NN],g[NN];
		assert(f[0]);
		g[0]=poww(f[0],mod-2);
		int now=2;
		for(;now<(n<<1);now<<=1)
		{
			int limit=now<<1;
			for(int i=0;i<now;i++) ff[i]=f[i];
			NTT(ff,limit,1),NTT(g,limit,1);
			for(int i=0;i<limit;i++)
				g[i]=mul(dec(2,mul(ff[i],g[i])),g[i]);
			NTT(g,limit,-1);
			for(int i=now;i<limit;i++) g[i]=0;
		}
		poly res(g,g+n);
		for(int i=0;i<now;i++) ff[i]=g[i]=0;
		return res;
	}
	poly getder(const poly &f)
	{
		assert(!f.empty());
		if((int)f.size()==1) return poly{0};
		poly g((int)f.size()-1);
		for(int i=1,s=g.size();i<=s;i++) g[i-1]=mul(i,f[i]);
		return g;
	}
	namespace Eva
	{
		int n,m;
		poly a,res;
		poly A[N<<2],f[N<<2];
		void solve1(int k,int l,int r)
		{
			if(l==r)
			{
				A[k]=poly{1,dec(0,a[l])};
				return;
			}
			int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
			solve1(lc,l,mid),solve1(rc,mid+1,r);
			A[k]=pmul(A[lc],A[rc]);
		}
		void solve2(int k,int l,int r)
		{
			if(l>=m) return;
			if(l==r)
			{
				res.push_back(f[k][0]);
				return;
			}
			int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
			f[lc]=dmul(f[k],A[rc]),f[rc]=dmul(f[k],A[lc]);
			f[lc].resize(mid-l+1),f[rc].resize(r-mid);
			solve2(lc,l,mid),solve2(rc,mid+1,r);
		}
		poly geteva(const poly &_f,const poly &_a)
		{
			f[1]=_f,a=_a;
			res.clear();
			m=a.size(),n=max(f[1].size(),a.size());
			f[1].resize(n),a.resize(n);
			solve1(1,0,n-1);
			poly q=getinv(A[1],n);
			f[1]=dmul(f[1],q);
			solve2(1,0,n-1);
			return res;
		}
	}using Eva::geteva;
}using Poly::poly;

int n;
poly a,c,s[N<<2],q;

void solve(int k,int l,int r)
{
	if(l==r)
	{
		s[k]=poly{dec(0,a[l]),1};
		return;
	}
	int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
	solve(lc,l,mid),solve(rc,mid+1,r);
	s[k]=Poly::pmul(s[lc],s[rc]);
}

int main()
{
	n=read();
	for(int i=0;i<n;i++) a.push_back(read());
	for(int i=0;i<n;i++) c.push_back(read());
	int limit=1;
	while(limit<(n<<1)) limit<<=1;
	Poly::init(limit);
	solve(1,0,n-1);
	q=Poly::getder(s[1]);
	s[1].resize(n);
	poly f=Poly::dmul(c,s[1]);
	f.resize(n+1);
	for(int i=n;i>=1;i--) f[i]=f[i-1];
	f[0]=0;
	poly qv=Poly::geteva(q,a);
	for(int &v:a) v=poww(v,mod-2);
	poly fv=Poly::geteva(f,a);
	for(int i=0;i<n;i++)
		printf("%d ",dec(0,mul(poww(qv[i],mod-2),fv[i])));
	return 0;
}

标签:frac,overrightarrow,int,多项式,sum,解密,XSY3528,prod,neq
From: https://www.cnblogs.com/ez-lcw/p/16840988.html

相关文章