容斥神仙题 qwq。
题意
给定一个长度为 \(n\) 的序列 \(a_1,a_2,\cdots ,a_n\),输出满足如下条件的序列 \(x\) 的方案数:
- \(1\leq x_i\leq a_i\)
- \(x_i\neq x_{i+1}(1\leq i\leq n-1)\)
- \(2\le n\le 5\times 10^5,1\le a_i\le 10^9\)
Solution
首先我们规定形如 \(x_{i} = x_{i + 1}\) 的点为坏点。
令 \(f_{i}\) 表示至少有 \(i\) 个坏点的方案数,那么套上容斥的式子:
\[ans = \sum_{i=0}^{n-1} (-1) ^ i\times f_i \]那么我们令所有的坏点向左合并,即 $x_i \gets x_i,x_{i + 1} $ 其中 \(x_i = x_{i + 1}\),假设有 \(k\) 个坏点,那么恰好减少 \(k\) 个数,而这等价于将序列划分成 \(n - k\) 段,每段内都是相同的数。
注意这里只是强制限定了每一段内相同,不代表相邻段之间必须不同,因为我们是至少。
所以我们考虑用 dp 求出这样划分的方案数,令 \(dp[i][j]\) 表示填了 \(i\) 个数,划分出 \(j\) 段的方案数。
那么对答案进行改写:
\[ans = \sum_{i = 0}^{n - 1} (-1)^i \times dp[n][n - i] \]先考虑朴素转移:
\[dp[i][j] = \sum_{k = 0}^{i - 1} dp[k][j - 1] \times \min_{o = k + 1}^{i}a_o \]代表添加一段 \((k + 1,i)\) 的线段,有 \(\min_{o = k + 1}^{i}a_o\) 种填法。
复杂度还不够优秀。
我们发现最后只关心划分出的段数的奇偶,所以可以直接改写成 \(dp[i][0/1]\) 代表划分成 偶数/奇数
段的方案数。
方程又变成了:
\[dp[i][op] = \sum_{k = 0}^{i - 1} dp[k][op\text{ xor }1] \times \min_{o = k + 1}^{i}a_o \]那么前面那个东西可以前缀和优化。
后面那一坨,我们维护单调栈,找到左边第一个小于 \(a[i]\) 的。
- 如果不存在这样的,那么 \(a[i]\) 就是最小值,直接前缀和乘上即可。
- 否则考虑令前面的位置为 \(p\) ,则 \([p + 1...i,i]\) 这些区间都是以 \(a[i]\) 为最小值的。也就是说你可以从 \(dp[p...i - 1]\) 转移过来,因为我们的转移相当于在这个位置后面插入一段。在这之前的最小值都和 \(a[i]\) 没关系!!所以直接继承 \(dp[p][op]\) 就是贡献,因为这个玩意就等价于前面的 $\sum dp[k][op \text{ xor } 1] \times \min $。
最后根据 \(n\) 的奇偶性,我们可以得出 \(dp[n][0/1]\) 对应的奇偶性,也就是 \(n - 2\times k\) 和 \(n - (2\times k + 1)\) 的奇偶性,用偶的减去奇的即可。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define Rep(i,a,b) for(int i=(a);i<(b);++i)
#define rrep(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
template <typename T>
inline void read(T &x){
x=0;char ch=getchar();bool f=0;
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=-x;
}
template <typename T,typename ...Args>
inline void read(T &tmp,Args &...tmps){read(tmp);read(tmps...);}
const int N = 5e5 + 5;
const int mod = 998244353;
inline void add(int &x,int y){
if((x += y) >= mod)x -= mod;
}
vector<int>S;
int dp[N][2],s[N][2];
int a[N],n;
signed main(){
read(n);
rep(i,1,n)read(a[i]);
dp[0][0] = s[0][0] = 1;
rep(i,1,n){
while(!S.empty() && a[S.back()] >= a[i])S.pop_back();
if(S.empty()){
for(int x : {0,1})dp[i][x] = 1ll * s[i - 1][x ^ 1] * a[i] % mod;
}
else{
for(int x : {0,1})dp[i][x] = (dp[S.back()][x] + 1ll * (s[i - 1][x ^ 1] - s[S.back() - 1][x ^ 1] + mod) * a[i] % mod) % mod;
}
S.push_back(i);
s[i][0] = (s[i - 1][0] + dp[i][0]) % mod;
s[i][1] = (s[i - 1][1] + dp[i][1]) % mod;
}
int ans = (dp[n][0] - dp[n][1] + mod) % mod;
if(n & 1)ans = 1ll * ans * (mod - 1) % mod;
printf("%d",ans);
}
都看到这了留个赞呗qwq。