首页 > 其他分享 >【LGR125D】【JRKSJ R5】Concvssion(多项式,长链剖分)

【LGR125D】【JRKSJ R5】Concvssion(多项式,长链剖分)

时间:2022-11-13 09:44:57浏览次数:52  
标签:R5 return 剖分 长链 int poly mod size

Sub 1:\(a_i=(i+1)\bmod n\)

即图只有一个环。

设 \(g_u\) 表示原来 \(u\) 上有多少个点,\(f_u=u\) 表示 \(u\) 的点权。

那么对于某个 \(k\in [1,n]\),\(ans_k=\sum_{u}g_uf_{u+k}\)(这里需要将 \(f\) 倍长),这是个减法卷积,可以 \(O(n\log n)\) 求出。

Sub 2:\(a_i\) 是排列

那图是若干个环。

对于每个大小为 \(c\) 的环,它对答案的贡献有大小为 \(c\) 的循环节,循环节具体是什么样可以用上面的方法 \(O(c\log c)\) 求出。

同种环长的所有环可以一起考虑,而总共只有 \(O(\sqrt n)\) 种环长,于是对每种环长都暴力地加到总答案上,可以做到 \(O(n\sqrt n)\) 的复杂度。

Sub 3:\(a_1=1\),\(\forall_{i\in[2,n]},a_i<i\)

即图是一棵内向树。

将树长链剖分。一开始我们先将所有点都当成在最长长链 A 上(所有深度为 \(i\) 的点初始都在长链的第 \(i\) 个点上),然后用减法卷积求出此时的答案。

接着我们考虑长链 A 上挂的一棵子树,以及该子树对应的最长长链 B。显然该子树内的点在时间的前半部分他们初始都不在长链 A 上,但却被我们当做是在 A 上了,这里的贡献我们要剪掉。具体来说,我们先将该子树内的点当成都在长链 B 上,然后让长链 B 上的点权减等于长链 A 上的对应点权,再类似刚刚用卷积计算一遍答案。这样长链 B 上的点的贡献是算对了的,但长链 B 上挂着的点的贡献算错了(前半部分算到了长链 B 上)。容易发现这是一个递归的过程,直接递归即可做到 \(O(n\log n)\)。

可以说是运用了类似差分的思想。

Sub 4:无特殊限制

即图是内向基环森林。

对于一棵内向基环树,可以用类似 Sub 3 的方法考虑:一开始我们先将树中的所有点都当成在环上,然后计算答案。然后对于环上挂的每一棵树继续用 Sub 3 的方法做即可,只不过树的第一条长链的点权要剪掉环的对应点权。

对于内向基环森林,用 Sub2 的方法合并多个环的答案即可。

总时间复杂度 \(O(n\log n+n\sqrt n)\),瓶颈在合并多个环的答案而非多项式。

代码很好写。

其实是因为我有多项式板子(逃

#include<bits/stdc++.h>

#define N 300010

using namespace std;

namespace modular
{
	const int mod=998244353,inv2=(mod+1)>>1;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
	inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
	inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
	inline void Mul(int &x,int y){x=1ll*x*y%mod;}
	inline int poww(int a,int b){int ans=1;for(;b;Mul(a,a),b>>=1)if(b&1)Mul(ans,a);return ans;}
}using namespace modular;

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

namespace Poly
{
	void NTT(vector<int> &a,int limit,int opt)
	{
		static vector<int> rev;
		static vector<int> w[25][2];
		static int _bit=0,_limit=1;
		if(limit>_limit)
		{
			rev.resize(limit);
			for(;_limit<limit;_bit++,_limit<<=1)
			{
				int len=_limit<<1;
				int gn=poww(3,(mod-1)/len);
				int ign=poww(gn,mod-2);
				for(int j=0,g=1,ig=1;j<_limit;j++,Mul(g,gn),Mul(ig,ign))
					w[_bit][0].push_back(g),w[_bit][1].push_back(ig);
			}
		}
		opt=(opt<0);
		if(limit>(int)a.size()) a.resize(limit);
		for(int i=0;i<limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(limit>>1):0);
		for(int i=0;i<limit;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
		for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
		{
			for(int i=0,len=mid<<1;i<limit;i+=len)
			{
				for(int j=0;j<mid;j++)
				{
					int x=a[i+j],y=mul(w[bit][opt][j],a[i+mid+j]);
					a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
				}
			}
		}
		if(opt) for(int i=0,x=poww(limit,mod-2);i<limit;i++) Mul(a[i],x);
	}
	struct poly:vector<int>
	{
		using vector<int>::vector;
		void resmax(int s){if(s>(int)size())resize(s);}
		poly mod(int s)const{return poly(begin(),begin()+min(s,(int)size()));}
	};
	poly operator << (poly a,int b)
	{
		a.resize((int)a.size()+b);
		move_backward(a.begin(),a.end()-b,a.end());
		for(int i=0;i<b;i++) a[i]=0;
		return a;
	}
	poly operator >> (poly a,int b)
	{
		if(b>(int)a.size()){return poly();}
		move(a.begin()+b,a.end(),a.begin());
		a.resize((int)a.size()-b);
		return a;
	}
	poly operator * (poly a,int b)
	{
		for(int &x:a) Mul(x,b);
		return a;
	}
	poly operator + (poly a,const poly &b)
	{
		a.resmax(b.size());
		for(int i=0,s=b.size();i<s;i++) Add(a[i],b[i]);
		return a;
	}
	poly operator - (poly a,const poly &b)
	{
		a.resmax(b.size());
		for(int i=0,s=b.size();i<s;i++) Dec(a[i],b[i]);
		return a;
	}
	poly dot(const poly &a,const poly &b)
	{
		const int s=min(a.size(),b.size());
		poly c(s); for(int i=0;i<s;i++) c[i]=mul(a[i],b[i]);
		return c;
	}
	poly operator * (const poly &a,const poly &b)
	{
		const int sa=a.size(),sb=b.size();
		if(!sa&&!sb) return poly();
		if(sa<=50||sb<=50)
		{
			poly c(sa+sb-1);
			for(int i=0;i<sa;i++)
				for(int j=0;j<sb;j++)
					Add(c[i+j],mul(a[i],b[j]));
			return c;
		}
		poly A(a),B(b);
		int limit=1; while(limit<sa+sb-1) limit<<=1;
		NTT(A,limit,1),NTT(B,limit,1);
		for(int i=0;i<limit;i++) Mul(A[i],B[i]);
		NTT(A,limit,-1); A.resize(sa+sb-1);
		return A;
	}
	poly dmul(const poly &a,const poly &b)
	{
		if(b.empty()) return poly();
		poly rb=b; reverse(rb.begin(),rb.end());
		return (a*rb)>>(b.size()-1);
	}
	poly inv(const poly &f,int n)
	{
		poly g(n<<2);
		assert(f[0]); g[0]=poww(f[0],mod-2);
		for(int now=2;now<(n<<1);now<<=1)
		{
			int limit=now<<1;
			poly ff=f.mod(now);
			NTT(ff,limit,1),NTT(g,limit,1);
			for(int i=0;i<limit;i++) Mul(g[i],dec(2,mul(ff[i],g[i])));
			NTT(g,limit,-1);
			for(int i=now;i<limit;i++) g[i]=0;
		}
		g.resize(n); return g;
	}
	poly sqrt(const poly &f,int n)
	{
		assert(f[0]==1); poly g{1};
		for(int now=2;now<(n<<1);now<<=1)
			g=((f.mod(now)*inv(g,now)).mod(now)+g)*inv2;
		g.resize(n); return g;
	}
	poly D(const poly &f)
	{
		const int n=f.size(); poly g(n-1);
		for(int i=1;i<n;i++) g[i-1]=mul(i,f[i]);
		return g;
	}
	poly Int(const poly &f)
	{
		static vector<int> inv{0,1};
		const int n=f.size(); poly g(n+1);
		for(int i=inv.size();i<=n;i++) inv.push_back(mul(mod-mod/i,inv[mod%i]));
		for(int i=1;i<=n;i++) g[i]=mul(inv[i],f[i-1]);
		return g;
	}
	poly ln(const poly &f,int n)
	{
		assert(f[0]==1);
		poly Dg=(D(f)*inv(f,n)).mod(n-1);
		return Int(Dg);
	}
	poly exp(const poly &f,int n)
	{
		assert(!f[0]); poly g{1};
		for(int now=2;now<(n<<1);now<<=1)
			g=(g*(poly{1}-ln(g,now)+f.mod(now))).mod(now);
		g.resize(n); return g;
	}
	//poww: check if z*k>=n before k module
	poly poww(poly f,int k,int k2,int n)//k: k % p, k2: k % phi(p)
	{
		int z=0; while(!f.at(z)) z++;
		if(1ll*z*k>=n) return poly(n);
		int c=f[z],ic=modular::poww(c,mod-2),ck=modular::poww(c,k2);
		f=f>>z; for(int &x:f) Mul(x,ic);
		f=ln(f,n-z*k); for(int &x:f) Mul(x,k);
		f=exp(f,n-z*k); for(int &x:f) Mul(x,ck);
		f=f<<(z*k); return f;
	}
}using namespace Poly;

int n,a[N],f[N],g[N],pre[N],ans[N];
bool vis[N];

vector<int> e[N],s[N];

vector<int> findring(int x)
{
	static bool ins[N];
	int u=x;
	do
	{
		ins[u]=1,u=a[u];
	}while(!ins[u]);
	x=u;
	vector<int> ring;
	do
	{
		ring.push_back(u);
		u=a[u];
	}while(u!=x);
	return ring;
}

void calc1(int l,int r)
{
	poly F,G;
	while(r)
	{
		Add(g[l],g[r]),Dec(f[r],f[l]);
		F.push_back(f[r]),G.push_back(g[r]);
		l=pre[l],r=pre[r];
	}
	reverse(F.begin(),F.end());
	reverse(G.begin(),G.end());
	poly H(dmul(F,G));
	for(unsigned i=0;i<H.size();i++) Add(ans[i],H[i]);
}

int dfs(int u)
{
	vis[u]=1;
	int son=0,sonl=0;
	for(int v:e[u])
	{
		int vl=dfs(v);
		if(vl>sonl) son=v,sonl=vl;
	}
	pre[u]=son;
	for(int v:e[u]) if(v!=son)
		calc1(pre[u],v);
	return sonl+1;
}

void solve(vector<int> ring)
{
	int rs=ring.size();
	for(int i=0;i<rs;i++)
		vis[ring[i]]=1,pre[ring[(i+1)%rs]]=ring[i];
	for(int u:ring)
		for(int v:e[u]) if(v!=pre[u])
			dfs(v),calc1(pre[u],v);
	poly F(rs<<1),G(rs);
	for(int i=0;i<rs;i++)
		G[i]=g[ring[i]],F[i]=F[rs+i]=f[ring[i]];
	poly res(dmul(F,G));
	s[rs].resize(rs);
	for(int i=0;i<rs;i++) Add(s[rs][i],res[i]);
}

int main()
{
	n=read();
	for(int i=1;i<=n;i++) e[a[i]=read()].push_back(i);
	for(int i=1;i<=n;i++) f[i]=i;
	for(int i=1;i<=n;i++) g[read()]++;
	for(int i=1;i<=n;i++)
		if(!vis[i]) solve(findring(i));
	for(int i=1;i<=n;i++)
	{
		if(s[i].empty()) continue;
		for(int j=0;j<=n;j++) Add(ans[j],s[i][j%i]);
	}
	for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
	return 0;
}

标签:R5,return,剖分,长链,int,poly,mod,size
From: https://www.cnblogs.com/ez-lcw/p/16885429.html

相关文章