简要题意
给定 \(n\),\(k\) 和值域 \([1,k]\) 的 \(n\) 个整数 \(h_i\),求有多少个长为 \(n\) 的整数序列 \(a\) 满足值域 \([1,k]\),且 \(\sum\limits_{i=1}^n[a_i=h_i]<\sum\limits_{i=1}^n[a_i=h_{(i\bmod{n})+1}]\)。答案对 \(998244353\) 取模。
\(1 \leq n \leq 2\times 10^5,1 \leq k \leq 10^9\)
思路
如果你不会这道题的 Easy Version,可以先看看 这篇博客。这道题与简单版的最大区别是这道题无法使用 \(O(n^2)\) 的 DP。
简单版的解法是:令 \(f_{i,j}\) 为截止到第 \(i\) 位,比原来的少错 \(j\) 个数的方案数。然后可以 \(O(1)\) 转移。总时间复杂度 \(O(n^2)\)。
- 结论 1: \(f_{n,i}=f_{n,-i}\)。因为对 \(j\) 道题与错 \(j\) 道题的方案本质上是对称的。
- 结论 2:\(\sum_{i=1}^{n}f_{n,i}=\dfrac{k^n-f_0}{2}\)。因为 \(f_{n,i}=f_{n,-i}\),所以 \(k^n-f_0=\sum_{i=-n}^{-1}f_i+\sum_{i=1}^{n}f_i=2\sum_{i=1}^{n}f_{n,i}\)。
所以我们其实要求的是 \(f_{n,0}\)。也就是不改变正确个数的方案总数。
不改变正确的个数的方案总数一定是变换时正确了 \(k\) 个,错了 \(k\) 个。且 \(0\leq k\leq\lfloor\dfrac{M}{2}\rfloor\),其中 \(M\) 是 \(h\) 中变换后不同的个数。
所以:
\[f_{n,0}=\sum_{i=1}^{\frac{M}{2}}{F(i)} \]\(F(i)\) 是给出正确 / 错误的个数,求方案数。我们考虑如何求 \(F(i)\)。
考虑 \(F(i)\) 的来自三个部分的贡献:
- 正确的 \(i\) 个,为 \(\binom{M}{i}\)。
- 错误的 \(i\) 个,为 \(\binom{M-i}{i}(k-2)^{M-2i}\)。注意原来的和错误的都要错,所以是 \(M-2i\)。要排除正确的和本身的,所以是 \(k-2\)。
- 没有贡献的,为 \(k^{M}\)。
所以 \(f_{n,0}\) 就求出来了:
\[f_{n,0}=\sum_{i=1}^{\frac{M}{2}}\binom{M}{i}\binom{M-i}{i}(k-2)^{M-2i}k^{M} \]预处理阶乘、阶乘逆元、\(k\) 的幂和 \(k-2\) 的幂,可以做到严格线性。
时间复杂度 \(O(n)\)。
代码
#include <bits/stdc++.h>
#pragma GCC optimize("Ofast", "inline", "-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#define int long long
using namespace std;
const int mod = 998244353;
const int N = 2e5+5;
int n,k,m,h[N],fact[N]={1}, inv[N]={1,1}, ret, k2p[N]={1}, kp[N]={1};
int M(const int x){return (x%mod+mod)%mod;}
int c(int m, int n){assert(n-m>=0);return M(fact[n]*M(inv[n-m]*inv[m]));}
signed main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>n>>k;
if(k==1){cout<<0;return 0;}
for(int i=1;i<=n;i++) cin>>h[i];
for(int i=1;i<=n;i++) m+=(h[i]!=h[i%n+1]);
for(int i=1;i<=n;i++) fact[i]=M(fact[i-1]*i);
for(int i=2;i<=n;i++) inv[i]=M(M(mod-mod/i)*inv[mod%i]);
for(int i=2;i<=n;i++) inv[i]=M(inv[i-1]*inv[i]);
for(int i=1;i<=n;i++) kp[i]=M(kp[i-1]*k);
for(int i=1;i<=n;i++) k2p[i]=M(k2p[i-1]*(k-2));
for(int i=0;i<=(m>>1);i++){
ret=M(ret+M(M(c(i,m)*c(i,m-i))*M(k2p[m-(i<<1)]*kp[n-m])));
}
cout<<M((kp[n]-ret)*inv[2])<<'\n';;
return 0;
}
标签:int,sum,Hard,个数,leq,Version,test,binom,mod
From: https://www.cnblogs.com/zheyuanxie/p/cf1227f2.html