三倍经验:CF1591F. Non-equal Neighbours,ARC115E - LEQ and NEQ。
提供一种力大砖飞的数据结构 \(O(n\log n)\) 做法,非常好写/好调,去掉数据结构部分只有 1k。
定义 \(f_{i,j}\) 表示前 \(i\) 个数,最后一个为 \(j\) 的方案数。显然第 1 维可以压掉,写成 \(f_j\) 的形式。
然后这个东西可以前缀和做到 \(O(\sum a)\)。更具体地说,对于前 \(i-1\) 个数,定义 \(s=\sum\limits_{k=1}^{a_{i-1}}f_k\),加上第 \(i\) 个数之后有 \(f_k=s-f_k\)。这个东西似乎不能优化了。
但是,我们可以发现,对于很多连续的 \(f_k\),他们的值是一样的:对于 \(a_i\ge a_{i-1}\),由于原来从 \(a_{i-1}+1\) 到 \(a_i\) 的这些位置都没有值,所以相当于在最后插入了值为 \(s\) 的一段;对于 \(a_i<a_{i-1}\),相当于舍弃后面一部分 dp 值。
当然,每次剩余的那些段都会把值从 \(v_i\) 变成 \(s-v_i\),但这并不影响。
每次最多加入一段,所以最多 \(n\) 段;每段最多加一次删一次,故时间复杂度 \(O(n\log n)\)。那个 \(\log n\) 是用线段树维护每段具体值的时间。
code:
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int inf=1e18,mod=998244353;
inline int read(){
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
int a[500005];
struct Node{
int l,r;
}q[500005];
struct segtree{
#define ls p<<1
#define rs p<<1|1
#define lson l,mid,ls
#define rson mid+1,r,rs
struct Node{
int s,add,mul;
}c[2000005];
void pushup(int p){
c[p].s=(c[ls].s+c[rs].s+mod)%mod;
}
void pushdown(int l,int r,int p){
if(c[p].mul!=1){
c[ls].s=(c[ls].s*c[p].mul%mod+mod)%mod;
c[rs].s=(c[rs].s*c[p].mul%mod+mod)%mod;
c[ls].add=(c[ls].add*c[p].mul%mod+mod)%mod;
c[rs].add=(c[rs].add*c[p].mul%mod+mod)%mod;
c[ls].mul=(c[ls].mul*c[p].mul%mod+mod)%mod;
c[rs].mul=(c[rs].mul*c[p].mul%mod+mod)%mod;
c[p].mul=1;
}
if(c[p].add!=0){
int siz=r-l+1,ln=siz-(siz>>1),rn=siz>>1;
c[ls].s=(c[ls].s+ln*c[p].add%mod+mod)%mod;
c[rs].s=(c[rs].s+rn*c[p].add%mod+mod)%mod;
c[ls].add=(c[ls].add+c[p].add+mod)%mod;
c[rs].add=(c[rs].add+c[p].add+mod)%mod;
c[p].add=0;
}
}
void build(int l,int r,int p){
c[p].add=0;
c[p].mul=1;
if(l==r){
c[p].s=0;
return;
}
int mid=(l+r)>>1;
build(lson);
build(rson);
pushup(p);
}
void mul(int l,int r,int p,int L,int R,int k){
if(L>R)return;
if(L<=l&&r<=R){
c[p].s=(c[p].s*k%mod+mod)%mod;
c[p].add=(c[p].add*k%mod+mod)%mod;
c[p].mul=(c[p].mul*k%mod+mod)%mod;
return;
}
int mid=(l+r)>>1;pushdown(l,r,p);
if(L<=mid)mul(lson,L,R,k);
if(R>mid)mul(rson,L,R,k);
pushup(p);
}
void add(int l,int r,int p,int L,int R,int k){
if(L>R)return;
if(L<=l&&r<=R){
c[p].s=(c[p].s+(r-l+1)*k%mod+mod)%mod;
c[p].add=(c[p].add+k+mod)%mod;
return;
}
int mid=(l+r)>>1;pushdown(l,r,p);
if(L<=mid)add(lson,L,R,k);
if(R>mid)add(rson,L,R,k);
pushup(p);
}
int query(int l,int r,int p,int L,int R){
if(L>R)return 0;
if(L<=l&&r<=R)return c[p].s;
int mid=(l+r)>>1,res=0;pushdown(l,r,p);
if(L<=mid)res=(res+query(lson,L,R)+mod)%mod;
if(R>mid)res=(res+query(rson,L,R)+mod)%mod;
return res;
}
#undef ls
#undef rs
#undef lson
#undef rson
}Tr;
void solve(){
int n=read(),L=1,R=0,sum=0;
for(int i=1;i<=n;i++)a[i]=read();
Tr.build(1,n,1);
q[++R]=(Node){1,a[1]},sum=(sum+a[1])%mod;
Tr.mul(1,n,1,R,R,0);Tr.add(1,n,1,R,R,1);
for(int i=2;i<=n;i++){
if(a[i]>=a[i-1]){
Tr.mul(1,n,1,L,R,-1);Tr.add(1,n,1,L,R,sum);
q[++R]=(Node){a[i-1]+1,a[i]};
Tr.mul(1,n,1,R,R,0);Tr.add(1,n,1,R,R,sum);
sum=(sum*a[i]%mod-sum+mod)%mod;
}
else{
int nsum=sum;
while(L<=R&&q[R].l>a[i])nsum=(nsum-Tr.query(1,n,1,R,R)*(q[R].r-q[R].l+1)%mod+mod)%mod,R--;
if(L<=R&&q[R].r>a[i])nsum=(nsum-Tr.query(1,n,1,R,R)*(q[R].r-a[i])%mod+mod)%mod,q[R].r=a[i];
Tr.mul(1,n,1,L,R,-1);Tr.add(1,n,1,L,R,sum);
sum=(sum*a[i]%mod-nsum+mod)%mod;
}
}
printf("%lld\n",sum);
}
signed main(){
int T=1;
while(T--)solve();
return 0;
}