题目
点这里看题目。
有一排 \(N\) 个格子,有 \(M\) 个人,初始都在 \(1\) 号格。
每个人可以选择往前跳一格或者跳两格,跳一格的方法数为 \(p\),跳两格的方法数为 \(q\),跳出 \(N\) 个格子则停止,注意在第 \(N\) 个格子仍然能选择跳一或两格。
你需要计算有多少种方法使得每个格子都至少被一个人踩过。
分析
容易得到一个容斥做法。设以下的一堆东西:
- \(f_n\) 表示 \(m\) 个人,从 \(1\) 跳到 \(n\),且所有格子均被至少经过一次的方案数。
- \(h_n\) 表示一个人,从 \(1\) 跳到 \(n\),但不要求所有格子均被至少经过一次的方案数。
- \(\hat f_n,\hat h_n\) 与其去掉尖号的定义基本一致,但是要求从 \(1\) 起跳跳出 \(n\)。
容易得到以下关系:
\[\begin{aligned} h_n&=\begin{cases}ph_{n-1}+qh_{n-2}&n\ge 3\\p&n=2\\1&n=1\end{cases}\\ \hat h_n&=(p+q)h_n+qh_{n-1},\forall n\ge 2 \end{aligned} \]对于 \(f_n,\hat f_n\),列容斥式子:
\[\begin{aligned} f_n&=(h_n)^m-q^m\sum_{k=2}^{n-1}f_{k-1}(h_{n-k})^m\\ \hat f_n&=(\hat h_n)^m-q^m\sum_{k=2}^{n-1}f_{k-1}(\hat h_{n-k})^m-q^mf_{n-1} \end{aligned} \]假设 \(F(x),\hat F(x),G(x),\hat G(x)\) 分别为 \(\{f_n\},\{\hat f_n\},\{(h_n)^m\},\{(\hat h_n)^m\}\) 的生成函数,我们可以得到:
\[\begin{aligned} F(x)&=\frac{G(x)}{1+q^mxG(x)}\\ \hat F(x)&=\frac{\hat G(x)-q^mxG(x)}{1+q^mxG(x)} \end{aligned} \]现在已经可以用多项式求逆做到 \(O(n\log n)\) 了,但是明显远远不够。
我们注意到,虽然 \(H(x),\hat H(x)\)(假设它们分别为 \(\{h_n\},\{\hat h_n\}\) 的生成函数)是容易求得并且计算的,但是 \(G(x),\hat G(x)\) 并非如此。如果要进一步优化,这个思路势必要求我们解出来 \(G(x),\hat G(x)\)。
另一方面,\(H(x)\) 是二阶齐次线性递推数列的生成函数,所以 \(h_n\) 应该形如 \(c_1\lambda^n+c_2\mu^n\);如果再把 \((h_n)^m\) 表示出来,我们看到的是一个可以展开的部分,展开后恰好有 \(m+1\) 项。结合 \(m\le 6\times 10^4\) 的条件,这个思路应该是可以一试的。
具体来说,我们有:
\[H(x)=\frac{x}{1-px-qx^2} \]对其分解,两个特征根分别为 \(\lambda=\frac{p+\sqrt{\Delta}}{2},\mu=\frac{p-\sqrt{\Delta}}{2}(\Delta=p^2+4q)\),结果为:
\[h_n=\frac{1}{\sqrt \Delta}(\lambda^n-\mu^n) \]现在,当我们尝试计算 \((h_n)^m\) 时,我们做一个展开:
\[\begin{aligned} (h_n)^m&=\frac{1}{(\sqrt\Delta)^m}(\lambda^n-\mu^n)^m\\ &=\frac{1}{(\sqrt\Delta)^m}\sum_{k=0}^m\binom{m}{k}(-1)^{m-k}(\lambda^k)^n(\mu^{m-k})^n \end{aligned} \]很神奇的是,我们发现 \(n\) 仅作为指数出现,所以在计算 \(G(x)\) 的过程中,这些项可以被收起来:
\[G(x)=\frac{1}{(\sqrt\Delta)^m}\sum_{k=0}^m\frac{\binom{m}{k}(-1)^{m-k}}{1-\lambda^k\mu^{m-k}x} \]得出的结论是:\(G(x)\) 其实是一个有理分式。我们可以用分治 NTT 通分,而后用 Bostan-Mori 计算远项。
Note.
生成函数的基础思路之一就是根据通项公式直接计算。所以如果需要得到生成函数,并且已知通项,那么应该考虑这个思路。
对于 \(\hat G(x),\hat H(x)\) 也可以类似处理。首先对于 \(\hat H(x)\) 分解,引入参数 \(A=(p+q)\lambda+q,B=(p+q)\mu+q\),结果是:
\[\begin{aligned} \hat H(x)&=\frac{(p+q)x+x^2}{1-px-q^x}\\&=\frac{x}{\sqrt\Delta}\left(\frac{A}{1-\lambda x}-\frac{B}{1-\mu x}\right)\\ \hat h_n&=\frac{1}{\sqrt \Delta}\left(A\lambda^{n-1}-B\mu^{n-1}\right) \end{aligned} \]类似地,可以得到:
\[\begin{aligned} (\hat h_n)^m&=\frac{1}{(\sqrt\Delta)^m}\sum_{k=0}^m\binom{m}{k}(-1)^{m-k}A^kB^{m-k}(\lambda^k)^{n-1}(\mu^{m-k})^{n-1}\\ \hat G(x)&=\frac{x}{(\sqrt\Delta)^m}\sum_{k=0}^m\frac{\binom{m}{k}(-1)^{m-k}A^kB^{m-k}}{1-\lambda^k\mu^{m-k}x} \end{aligned} \]所以 \(\hat G(x)\) 也是一个有理分式。
通分过后,我们可以得到 \(G(x)=\frac{U(x)}{W(x)},\hat G(x)=\frac{V(x)}{W(x)}\),于是最终的结果就是:
\[\hat F(x)=\frac{V(x)-q^mxU(x)}{W(x)+q^mxU(x)} \]还是用 Bostan-Mori 计算即可。复杂度为 \(O(m\log^2m+n\log m\log n)\)。
计算上有一些细节。如果 \(\Delta\not\equiv 0\pmod{998244353}\),那么我们可以直接在 \(F_{998244353}(\sqrt\Delta)\) 中计算。但是,如果 \(\Delta\equiv 0\pmod {998244353}\),那么就有一点点复杂了。此时可能无法直接用这个做法解决问题,需要其它做法介入,限于篇幅不再赘述。
幸运的是,数据里没有出现这种情况,毕竟随机情况下出现这种情况的概率为 \(p^{-1}\)。
代码
完整代码
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
#define rep( i, a, b ) for( int i = (a) ; i <= (b) ; i ++ )
#define per( i, a, b ) for( int i = (a) ; i >= (b) ; i -- )
const int mod = 998244353, inv2 = ( mod + 1 ) >> 1;
const int MAXN = ( 1 << 18 ) + 5;
template<typename _T>
inline void Read( _T &x ) {
x = 0; char s = getchar(); bool f = false;
while( s < '0' || '9' < s ) { f = s == '-', s = getchar(); }
while( '0' <= s && s <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar(); }
if( f ) x = -x;
}
template<typename _T>
inline void Write( _T x ) {
if( x < 0 ) putchar( '-' ), x = -x;
if( 9 < x ) Write( x / 10 );
putchar( x % 10 + '0' );
}
struct Complex {
int r, i;
Complex(): r( 0 ), i( 0 ) {}
Complex( int R, int I = 0 ): r( R % mod ), i( I % mod ) {}
};
typedef std :: vector<Complex> Poly;
struct RetType {
Poly U, V, W;
RetType(): U(), V(), W() {}
RetType( const Poly &u, const Poly &v, const Poly &w ): U( u ), V( v ), W( w ) {}
};
RetType bas[MAXN];
Complex laPw[MAXN], muPw[MAXN], AhPw[MAXN], BhPw[MAXN];
Complex la, mu, Ah, Bh, cns;
int C[MAXN];
int N, M, p, q, Delt;
inline int Qkpow( int, int );
inline int Inv( const int &a ) { return Qkpow( a, mod - 2 ); }
inline int Mul( int x, const int &v ) { return 1ll * x * v % mod; }
inline int Sub( int x, const int &v ) { return ( x -= v ) < 0 ? x + mod : x; }
inline int Add( int x, const int &v ) { return ( x += v ) >= mod ? x - mod : x; }
inline int& MulEq( int &x, const int &v ) { return x = 1ll * x * v % mod; }
inline int& SubEq( int &x, const int &v ) { return ( x -= v ) < 0 ? ( x += mod ) : x; }
inline int& AddEq( int &x, const int &v ) { return ( x += v ) >= mod ? ( x -= mod ) : x; }
inline Complex operator + ( const Complex &a, const Complex &b ) {
return Complex( Add( a.r, b.r ), Add( a.i, b.i ) );
}
inline Complex operator - ( const Complex &a, const Complex &b ) {
return Complex( Sub( a.r, b.r ), Sub( a.i, b.i ) );
}
inline Complex operator * ( const Complex &a, const Complex &b ) {
return Complex( Add( Mul( a.r, b.r ), Mul( Mul( a.i, b.i ), Delt ) ), Add( Mul( a.r, b.i ), Mul( a.i, b.r ) ) );
}
inline Complex operator / ( const Complex &a, const Complex &b ) {
int c = Inv( Sub( Mul( b.r, b.r ), Mul( Delt, Mul( b.i, b.i ) ) ) );
return a * Complex( Mul( b.r, c ), Mul( Mul( mod - 1, c ), b.i ) );
}
inline Complex& operator += ( Complex &a, const Complex &b ) { return a = a + b; }
inline Complex& operator -= ( Complex &a, const Complex &b ) { return a = a - b; }
inline Complex& operator *= ( Complex &a, const Complex &b ) { return a = a * b; }
inline int Qkpow( int base, int indx ) {
int ret = 1;
while( indx ) {
if( indx & 1 ) MulEq( ret, base );
MulEq( base, base ), indx >>= 1;
}
return ret;
}
inline Complex Qkpow( Complex base, int indx ) {
Complex ret( 1 );
while( indx ) {
if( indx & 1 ) ret *= base;
base *= base, indx >>= 1;
}
return ret;
}
namespace Basics {
const int L = 18, g = 3, phi = mod - 1;
int w[MAXN];
inline void NTTInit( const int &n = 1 << L ) {
w[0] = 1, w[1] = Qkpow( g, phi >> L );
rep( i, 2, n - 1 ) w[i] = Mul( w[i - 1], w[1] );
}
inline void DIF( Complex *coe, const int &n ) {
int *wp, p; Complex e, o;
for( int s = n >> 1 ; s ; s >>= 1 )
for( int i = 0 ; i < n ; i += s << 1 ) {
p = ( 1 << L ) / ( s << 1 ), wp = w;
for( int j = 0 ; j < s ; j ++, wp += p ) {
e = coe[i + j], o = coe[i + j + s];
coe[i + j] = e + o;
coe[i + j + s] = ( e - o ) * *wp;
}
}
}
inline void DIT( Complex *coe, const int &n ) {
int *wp, p; Complex k;
for( int s = 1 ; s < n ; s <<= 1 )
for( int i = 0 ; i < n ; i += s << 1 ) {
p = ( 1 << L ) / ( s << 1 ), wp = w;
for( int j = 0 ; j < s ; j ++, wp += p )
k = coe[i + j + s] * *wp,
coe[i + j + s] = coe[i + j] - k,
coe[i + j] = coe[i + j] + k;
}
std :: reverse( coe + 1, coe + n );
int inv = Inv( n ); rep( i, 0, n - 1 ) coe[i] = coe[i] * inv;
}
}
inline Poly operator + ( const Poly &a, const Poly &b ) {
int n = a.size(), m = b.size();
Poly ret( std :: max( n, m ) );
for( int i = 0 ; i < n || i < m ; i ++ ) {
if( i < n ) ret[i] += a[i];
if( i < m ) ret[i] += b[i];
}
return ret;
}
inline Poly operator * ( const Poly &a, const Poly &b ) {
static Complex P[MAXN], Q[MAXN];
int n = a.size(), m = b.size(), L;
for( L = 1 ; L <= n + m - 2 ; L <<= 1 );
rep( i, 0, L - 1 ) P[i] = Q[i] = Complex();
rep( i, 0, n - 1 ) P[i] = a[i];
rep( i, 0, m - 1 ) Q[i] = b[i];
Basics :: DIF( P, L );
Basics :: DIF( Q, L );
rep( i, 0, L - 1 ) P[i] *= Q[i];
Basics :: DIT( P, L );
return Poly( P, P + n + m - 1 );
}
inline void Init() {
Basics :: NTTInit();
Delt = Add( Mul( p, p ), Mul( 4, q ) );
C[0] = 1; rep( i, 1, M ) C[i] = Mul( C[i - 1], Mul( M - i + 1, Inv( i ) ) );
la = Complex( Mul( p, inv2 ), inv2 );
mu = Complex( Mul( p, inv2 ), mod - inv2 );
Ah = la * Add( p, q ) + q;
Bh = mu * Add( p, q ) + q;
cns = Qkpow( 1 / ( la - mu ), M );
laPw[0] = Complex( 1 ); rep( i, 1, M ) laPw[i] = laPw[i - 1] * la;
muPw[0] = Complex( 1 ); rep( i, 1, M ) muPw[i] = muPw[i - 1] * mu;
AhPw[0] = Complex( 1 ); rep( i, 1, M ) AhPw[i] = AhPw[i - 1] * Ah;
BhPw[0] = Complex( 1 ); rep( i, 1, M ) BhPw[i] = BhPw[i - 1] * Bh;
}
RetType Divide( const int &l, const int &r ) {
if( l == r ) return bas[l];
int mid = ( l + r ) >> 1;
RetType tmpL = Divide( l, mid ),
tmpR = Divide( mid + 1, r );
return RetType( tmpL.U * tmpR.W + tmpL.W * tmpR.U,
tmpL.V * tmpR.W + tmpL.W * tmpR.V,
tmpL.W * tmpR.W );
}
inline Complex Evaluation( Poly U, Poly V, int n ) {
static Poly P, Q, R;
for( ; n ; n >>= 1 ) {
int m = V.size(); R = V;
for( int i = 1 ; i < m ; i += 2 ) R[i] *= Complex( mod - 1 );
P = U * R, Q = V * R, U.clear(), V.clear();
m = Q.size(); for( int i = 0 ; i * 2 < m ; i ++ ) V.push_back( Q[i << 1] );
m = P.size(); for( int i = 0 ; i * 2 + n % 2 < m ; i ++ ) U.push_back( P[i * 2 + n % 2] );
}
return U[0] / V[0];
}
int main() {
Read( N ), Read( M ), Read( p ), Read( q );
Init();
rep( i, 0, M ) {
Complex c = ( ( M - i ) & 1 ? mod - C[i] : C[i] ) * cns;
bas[i] = RetType( { c },
{ 0, c * AhPw[i] * BhPw[M - i] },
{ 1, laPw[i] * muPw[M - i] * ( mod - 1 ) } );
}
RetType tmp( Divide( 0, M ) );
Poly A = tmp.V + tmp.U * ( Poly ) { 0, mod - Qkpow( q, M ) },
B = tmp.W + tmp.U * ( Poly ) { 0, Qkpow( q, M ) };
Write( Evaluation( A, B, N ).r ), putchar( '\n' );
// #ifdef _DEBUG
// rep( i, 1, N ) Write( Evaluation( tmp.U, D, i ).r ), putchar( " \n"[i == N] );
// #endif
return 0;
}