好题。
容斥后插板,要计算的形如 \(\binom{Sum}{m}\) 的样子。这个 \(Sum\) 可能会很大,不能直接设进状态,但是我们 \(dp\) 需要 \(Sum\) 计算组合数。解决方法是用范德蒙德卷积
\[\sum_{i=0}^{k}{\binom{n}{i}\binom{m}{k-i}} = \binom{n+m}{k} \]设 \(dp_i\) 表示当前所有 \(\binom{Sum}{i}\) 的总和,如果要给 \(Sum\) 加上一个什么数的话,直接用如上的卷积就行了。
但是这样导致了一个严重的问题:我们使用范德蒙德卷积,将组合数拓展到了实数域上(也就是上标为负也可能不是 \(0\) 一类的),会计算出错。我们要保证 \(dp\) 过程中 \(Sum\) 始终不会变成负的,怎么办?
题目告诉我们,给的 \(a_i\) 以几次方几次方的形式出现,而且开始的总和不超过他们的总和,暗示我们数位dp。那我们相当于一些数可以减,一些数不能减,可以看成 \(01\) 序列,然后就是经典数位dp了。 \(f_i\) 表示没解除限制的,\(g_i\) 表示解除了的。相互转移即可。
卷积可以用多项式优化。
加强!!!
#include <bits/stdc++.h>
using namespace std;
namespace ZPoly
{
using LL=long long;
constexpr int MOD=998244353,G=114514,MAXN=1<<21;
inline int qpow(LL a,LL b) { int r=1;for(;b;(b&1)?r=r*a%MOD:0,a=a*a%MOD,b>>=1);return r; }
inline int madd(int x) { return x; }
inline int mmul(int x) { return x; }
inline int msub(int x,int y) { return (x-=y)<0?x+=MOD:x; }
inline int mdiv(int x,int y) { return (LL)x*qpow(y,MOD-2)%MOD; }
template<typename ...Args>inline int madd(int x,Args ...y) { return (x+=madd(y...))>=MOD?x-=MOD:x; }
template<typename ...Args>inline int mmul(int x,Args ...y) { return (LL)x*mmul(y...)%MOD; }
class Polynomial
{
private:
static constexpr int NTT_LIM=180;
static int g[MAXN+5],c1[MAXN+5],c2[MAXN+5];
int deg;
vector<int> c;
public:
static void init()
{
for(int i=2,gn;i<=MAXN;i<<=1)
{
g[i>>1]=1,gn=qpow(G,(MOD-1)/i);
for(int j=(i>>1)+1;j<i;j++) g[j]=mmul(g[j-1],gn);
}
}
static void DIT(int *a,int len)
{
for(int i=len>>1;i;i>>=1)
for(int j=0;j<len;j+=i<<1)
for(int k=0,x,y;k<i;k++)
x=a[j+k],y=a[i+j+k],a[j+k]=madd(x,y),a[i+j+k]=mmul(g[i+k],msub(x,y));
}
static void DIF(int *a,int len)
{
for(int i=1;i<len;i<<=1)
for(int j=0;j<len;j+=i<<1)
for(int k=0,x,y;k<i;k++)
x=a[j+k],y=mmul(g[i+k],a[i+j+k]),a[j+k]=madd(x,y),a[i+j+k]=msub(x,y);
int x=qpow(len,MOD-2);
for(int i=0;i<len;i++) a[i]=mmul(a[i],x);
reverse(a+1,a+len);
}
private:
static void __polyinv(const int *a,int *b,int len)
{
if(len==1) return b[0]=qpow(a[0],MOD-2),void();
__polyinv(a,b,(len+1)>>1);
int nn=1<<(__lg((len<<1)-1)+1);
memcpy(c1,a,len<<2);
memset(b+len,0,(nn-len)<<2);
memset(c1+len,0,(nn-len)<<2);
DIT(b,nn),DIT(c1,nn);
for(int i=0;i<nn;i++) b[i]=mmul(b[i],msub(2,mmul(b[i],c1[i])));
DIF(b,nn),memset(b+len,0,(nn-len)<<2);
}
static void __polyln(const int *a,int *b,int len)
{
__polyinv(a,b,len);
for(int i=1;i<len;i++) c1[i-1]=mmul(i,a[i]);
int nn=1<<(__lg((len<<1)-1)+1);
memset(b+len,0,(nn-len)<<2);
memset(c1+len,0,(nn-len)<<2);
DIT(b,nn),DIT(c1,nn);
for(int i=0;i<nn;i++) b[i]=mmul(b[i],c1[i]);
DIF(b,nn),memset(b+len,0,(nn-len)<<2);
for(int i=len-1;i>0;i--) b[i]=mdiv(b[i-1],i);
b[0]=0;
}
static void __polyexp(const int *a,int *b,int l,int r)
{
if(l==r-1) return b[l]=(l?mdiv(b[l],l):1),void();
int len=r-l,mid=(l+r)>>1;
__polyexp(a,b,l,mid);
for(int i=0;i<len;i++) c1[i]=a[i];
memcpy(c2,b+l,(mid-l)<<2);
memset(c2+mid-l,0,(r-mid)<<2);
if(len<=NTT_LIM) for(int i=len-1;i>=0;i--)
{
c1[i]=mmul(c1[i],c2[0]);
for(int j=0;j<i;j++) c1[i]=madd(c1[i],mmul(c1[j],c2[i-j]));
}
else
{
DIT(c1,len),DIT(c2,len);
for(int i=0;i<len;i++) c1[i]=mmul(c1[i],c2[i]);
DIF(c1,len);
}
for(int i=mid;i<r;i++) b[i]=madd(b[i],c1[i-l]);
__polyexp(a,b,mid,r);
}
public:
Polynomial(): deg(1),c(1){}
Polynomial(const Polynomial &p): deg(p.deg),c(p.c){}
Polynomial(Polynomial &&p): deg(p.deg),c(move(p.c)){}
explicit Polynomial(int d): deg(d),c(d){}
explicit Polynomial(const vector<int> &v): deg(v.size()),c(v){}
explicit Polynomial(const initializer_list<int> &l): deg(l.size()),c(l){}
inline int &operator [](int i) { return c[i]; }
inline int operator [](int i)const { return c[i]; }
inline int degree()const { return deg; }
inline void resize(int d) { c.resize(deg=d); }
inline Polynomial &operator +=(const Polynomial &p)
{
if(deg<p.deg) resize(p.deg);
for(int i=0;i<deg;i++) c[i]=madd(c[i],p[i]);
return *this;
}
inline Polynomial &operator -=(const Polynomial &p)
{
if(deg<p.deg) resize(p.deg);
for(int i=0;i<deg;i++) c[i]=msub(c[i],p[i]);
return *this;
}
inline Polynomial &operator *=(const Polynomial &p)
{
int n=deg,m=p.deg;resize(n+m-1);
if(n+m<NTT_LIM)
{
memcpy(c1,c.data(),n<<2);
memset(c2,0,(n+m-1)<<2);
for(int i=0;i<n;i++)
for(int j=0;j<m;j++)
c2[i+j]=madd(c2[i+j],mmul(c1[i],p[j]));
memcpy(c.data(),c2,(n+m-1)<<2);
}
else
{
int nn=1<<(__lg(n+m-1)+1);
memcpy(c1,c.data(),n<<2),memcpy(c2,p.c.data(),m<<2);
memset(c1+n,0,(nn-n)<<2),memset(c2+m,0,(nn-m)<<2);
DIT(c1,nn),DIT(c2,nn);
for(int i=0;i<nn;i++) c1[i]=mmul(c1[i],c2[i]);
DIF(c1,nn),memcpy(c.data(),c1,deg<<2);
}
return *this;
}
friend inline Polynomial derivative(const Polynomial &p)
{
Polynomial q(p.deg-1);
for(int i=1;i<p.deg;i++) q[i-1]=mmul(p[i],i);
return q;
}
friend inline Polynomial integral(const Polynomial &p)
{
Polynomial q(p.deg+1);
for(int i=1;i<p.deg;i++) q[i+1]=mdiv(p[i],i+1);
return q;
}
inline Polynomial inv()const
{
if(c[0]==0) cerr<<"[x^0]f(x)=0, f(x)^-1 doesn't exist.\n",abort();
int nn=1<<(__lg((deg<<1)-1)+1);
Polynomial q(nn);
__polyinv(c.data(),q.c.data(),deg);
return q.resize(deg),q;
}
friend inline Polynomial ln(const Polynomial &p)
{
if(p[0]!=1) cerr<<"[x^0]f(x)!=1, ln(f(x)) doesn't exist.\n",abort();
int nn=1<<(__lg((p.deg<<1)-1)+1);
Polynomial q(nn);
__polyln(p.c.data(),q.c.data(),p.deg);
return q.resize(p.deg),q;
}
friend inline Polynomial exp(const Polynomial &p)
{
if(p[0]!=0) cerr<<"[x^0]f(x)!=0, exp(f(x)) doesn't exist.\n",abort();
static int c[MAXN];
int nn=1<<(__lg(p.deg-1)+1);
for(int i=0;i<p.deg;i++) c[i]=mmul(i,p[i]);
Polynomial q(nn);
__polyexp(c,q.c.data(),0,nn);
return q.resize(p.deg),q;
}
friend inline pair<Polynomial,Polynomial> div(const Polynomial &f,const Polynomial &g)
{
if(f.deg<g.deg) return make_pair(Polynomial{0},f);
int n=f.deg-1,m=g.deg-1;
Polynomial fr(n+1),gr(m+1);
for(int i=0;i<=n;i++) fr[i]=f[n-i];
for(int i=0;i<=m;i++) gr[i]=g[m-i];
fr.resize(n-m+1),gr.resize(n-m+1),fr*=gr.inv();
fr.resize(n-m+1),reverse(fr.c.begin(),fr.c.end());
gr=f-fr*g,gr.resize(m);
return make_pair(fr,gr);
}
inline Polynomial &operator =(const Polynomial &p)
{ return deg=p.deg,c=p.c,*this; }
inline Polynomial &operator =(Polynomial &&p)
{ return deg=p.deg,c=move(p.c),*this; }
inline Polynomial &operator *=(int k)
{ for(auto &i: c) i=mmul(i,k);return *this; }
inline Polynomial &operator /=(const Polynomial &rhs)
{ return (*this)*=rhs.inv(); }
inline Polynomial &operator %=(const Polynomial &rhs)
{ return (*this)=div(*this,rhs).second; }
inline Polynomial operator +(const Polynomial &rhs)const
{ return Polynomial(*this)+=rhs; }
inline Polynomial operator -(const Polynomial &rhs)const
{ return Polynomial(*this)-=rhs; }
inline Polynomial operator *(const Polynomial &rhs)const
{ return Polynomial(*this)*=rhs; }
inline Polynomial operator /(const Polynomial &rhs)const
{ return Polynomial(*this)/=rhs; }
inline Polynomial operator %(const Polynomial &rhs)const
{ return div(*this,rhs).second; }
friend inline Polynomial operator *(const Polynomial &p,int k)
{ return Polynomial(p)*=k; }
friend inline Polynomial operator *(int k,const Polynomial &p)
{ return Polynomial(p)*=k; }
};
int Polynomial::g[]={},Polynomial::c1[]={},Polynomial::c2[]={};
};
using namespace ZPoly;
int m,b,c,pw[405];
struct bignum{
int a[405],len;
void trim(){
int r=0;
for(int i=0;i<=len;i++)
{a[i]+=r;r=floor(1.0*a[i]/b);a[i]-=r*b;}
while(r!=0)
{++len;a[len]=r;r=floor(1.0*a[len]/b);a[len]-=r*b;}
while(len>=0&&a[len]==0)len--;
}
int getval()
{
int ret=0;
for(int i=0;i<=len;i++)
ret=(ret+1ll*pw[i]*a[i]%MOD)%MOD;
return ret;
}
void sub(bignum &o)
{
for(int i=0;i<=o.len;i++)
a[i]-=o.a[i];
trim();
}
}val[405],N;
bool comp(bignum a,bignum b)
{
if(a.len!=b.len)return a.len<b.len;
for(int i=a.len;i>=0;i--)if(a.a[i]!=b.a[i])return a.a[i]<b.a[i];
return 0;
}
char s[100000];
int len;
int w[100000],tmp[100000],t[405],inv[405];
Polynomial f[2],g[2],h,ini;
int main()
{
Polynomial::init();
scanf("%d%d%d",&m,&b,&c);
pw[0]=1;for(int i=1;i<=400;i++)pw[i]=1ll*pw[i-1]*b%MOD;
inv[1]=1;for(int i=2;i<=400;i++)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
for(int i=1;i<=m;i++)
val[i].len=i,val[i].a[i]=1,val[i].a[0]-=c-1,val[i].trim();
scanf("%s",s);
len=strlen(s);reverse(s,s+len);len--;
for(int i=0;i<=len;i++)w[i]=s[i]-'0';
N.len=-1;
while(len>=0)
{
int r=0;
for(int i=len;i>=0;i--)
{tmp[i]=(w[i]+1ll*r*10)/b;r=(1ll*r*10+w[i])%b;}
for(int i=0;i<=len;i++)w[i]=tmp[i],tmp[i]=0;
while(len>=0&&w[len]==0)len--;
N.a[++N.len]=r;
}
if(N.len==-1)
{
puts("0");
return 0;
}
N.a[0]+=m-1;N.trim();
int n=N.getval();
f[0].resize(m+5),f[1].resize(m+5);
g[0].resize(m+5),g[1].resize(m+5);
ini.resize(m+5);
h.resize(m+5);
f[0][0]=1;
for(int val=1,i=1;i<=m;i++)
{
val=1ll*val*(n-i+1)%MOD;
val=1ll*val*inv[i]%MOD;
f[0][i]=val;
}
for(int i=m;i>=1;i--)if(comp(val[i],N))t[i]=1,N.sub(val[i]);
for(int i=m;i>=1;i--)
{
f[0].resize(m+5);f[1].resize(m+5);
int nowa=(MOD-val[i].getval())%MOD;
h[0]=1;
for(int val=1,i=1;i<=m;i++)
{
val=1ll*val*(nowa-i+1)%MOD;
val=1ll*val*inv[i]%MOD;
h[i]=val;
}
if(t[i])
{
g[0]=ini-f[0]*h;
g[1]=f[1]-f[1]*h+f[0];
}
else
{
g[0]=f[0];
g[1]=f[1]-f[1]*h;
}
f[0]=g[0],f[1]=g[1];
}
int ans=(f[0][m]+f[1][m])%MOD;
printf("%d\n",ans);
return 0;
}
标签:UNR,return,UOJ312,题面,--,len,int,inline,resize
From: https://www.cnblogs.com/hikkio/p/17601927.html