来源
来自 EI 的 2020 年的论文《浅谈函数最值的动态维护》。
适用范围
给出一些形如 \(k_ix_i+b_i\) 的一次函数且 \(x_i\) 为已知值,支持动态对一次函数的 \(x_i\) 或 \(b_i\) 区间加,并快速查询一次函数的结果最值。
思想与实现
使用线段树,记录一个阈值 \(\Delta x\) 表示 “当前区间的 \(x\) 在整体加 \(\Delta x\) 后,存在某个线段树结点,其对应区间最大值位置发生改变”。
为了方便计算,我们线段树上的点存的最大值改为用直线存而不是一个数字。一条直线 \((k,b)\) 表示 “最大值为 \(b\),且他的 \(x\) 每 \(+1\) 这个值会增加 \(k\)”,注意我们维护的 \(b\) 会随着修改而变化。
考虑如何求 \(\Delta x\),取左右儿子的 \(\Delta x\) 的较小值,然后考虑当前点的最大值取的是左儿子还是右儿子,考虑两条直线,并计算出 “\(b\) 比较小的直线超过 \(b\) 比较大的直线需要增加的阈值”,和前两个一起取最小值。
定义 “转移” 为将两条直线按上述 “合并” 的过程。不难发现,\(\Delta x\) 需要涉及到线段树的当前结点子树内所有结点的 “转移”。
对于修改,打标记的时候将 \(\Delta x\) 减去打的标记,并修改直线的截距 \(b\)。当维护的 \(\Delta x\) 为负数时,暴力下传标记,然后递归下去继续更新,合并两个儿子时重新刷新 \(\Delta x\)。
根据 EI 的论文,可以证明,时间复杂度为 \(O((n+m)\log ^3n)\)。在一般情况下卡不到上界,差不多是 \(2\log\) 的。
点击查看代码
#define pil pair<line,ll>
struct line{
ll k,b;
line(){k=b=0;}
line(ll K,ll B){k=K, b=B;}
void add(ll v){b+=k*v;}
};
pil merg(line x,line y){
if(x.k<y.k||(x.k==y.k&&x.b<y.b)) swap(x,y);
if(x.b>=y.b) return mkp(x,inf);
return mkp(y,(y.b-x.b)/(x.k-y.k));
}
const pil operator+ (const pil &x,const pil &y){
pil tmp=merg(x.fi,y.fi);
return mkp(tmp.fi,min(min(x.se,y.se),tmp.se));
}
struct KTT{
ll tag[maxn<<2]; pil a[maxn<<2];
void addtag(ll p,ll v){
tag[p]+=v, a[p].se-=v;
a[p].fi.add(v);
}
void pushdown(ll p){
if(!tag[p]) return;
addtag(p<<1,tag[p]);
addtag(p<<1|1,tag[p]);
tag[p]=0;
}
void upd(ll p){
if(a[p].se>=0) return;
pushdown(p);
upd(p<<1), upd(p<<1|1);
a[p]=a[p<<1]+a[p<<1|1];
}
void build(ll p,ll l,ll r){
if(l==r){
a[p]=mkp(line(0,-inf),inf);
return;
} ll mid=l+r>>1;
build(p<<1,l,mid), build(p<<1|1,mid+1,r);
a[p]=a[p<<1]+a[p<<1|1];
}
void modify(ll p,ll l,ll r,ll ql,ll qr,ll v){
if(ql<=l&&r<=qr){
addtag(p,v);
upd(p); return;
}
ll mid=l+r>>1; pushdown(p);
if(ql<=mid) modify(p<<1,l,mid,ql,qr,v);
if(mid<qr) modify(p<<1|1,mid+1,r,ql,qr,v);
a[p]=a[p<<1]+a[p<<1|1];
}
void rep(ll p,ll l,ll r,ll x,line w){
if(l==r){
a[p].fi=w, a[p].se=inf;
return;
} ll mid=l+r>>1;
pushdown(p);
if(x<=mid) rep(p<<1,l,mid,x,w);
else rep(p<<1|1,mid+1,r,x,w);
a[p]=a[p<<1]+a[p<<1|1];
}
}T;
入门题
设 \(f[i]\) 表示激活第 \(i\) 个点,且考虑了前 \(i\) 个点和所包含的区间的答案。
枚举上一个点 \(j\),那么 \(j\) 贡献的区间满足左端点 \(\le j\),右端点 \(\ge j\)。我们根据右端点不断加入新的区间,记 \(c_i\) 表示点 \(i\) 被区间覆盖的次数。
\[f[i]=\max_{j=0}^{i-1}\{f[j]+c_j\cdot a_j\} \]\(j=0\) 表示 \(1\sim i-1\) 没有被激活的点,答案为 \(f[n+1]\),注意是先转移再加入区间。
每次加入区间会把 \(c_{l...r}\) 区间 \(+1\),\(f[j]\) 和 \(a_j\) 都是定值,这需要使用 KTT 维护。
具体操作看代码
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define pir pair<ll,ll>
#define pil pair<line,ll>
#define fi first
#define se second
#define mkp make_pair
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
#define pb push_back
#define ls(p) a[p].lc
#define rs(p) a[p].rc
#define ad(a,b) ((a)+(b)>=mod? (a)+(b)-mod:(a)+(b))
using namespace std;
const ll maxn=1e6+10, inf=1e17;
struct line{
ll k,b;
line(){k=b=0;}
line(ll K,ll B){k=K, b=B;}
void add(ll v){b+=k*v;}
};
pil merg(line x,line y){
if(x.k<y.k||(x.k==y.k&&x.b<y.b)) swap(x,y);
if(x.b>=y.b) return mkp(x,inf);
return mkp(y,(y.b-x.b)/(x.k-y.k));
}
const pil operator+ (const pil &x,const pil &y){
pil tmp=merg(x.fi,y.fi);
return mkp(tmp.fi,min(min(x.se,y.se),tmp.se));
}
struct KTT{
ll tag[maxn<<2]; pil a[maxn<<2];
void addtag(ll p,ll v){
tag[p]+=v, a[p].se-=v;
a[p].fi.add(v);
}
void pushdown(ll p){
if(!tag[p]) return;
addtag(p<<1,tag[p]);
addtag(p<<1|1,tag[p]);
tag[p]=0;
}
void upd(ll p){
if(a[p].se>=0) return;
pushdown(p);
upd(p<<1), upd(p<<1|1);
a[p]=a[p<<1]+a[p<<1|1];
}
void build(ll p,ll l,ll r){
if(l==r){
a[p]=mkp(line(0,-inf),inf);
return;
} ll mid=l+r>>1;
build(p<<1,l,mid), build(p<<1|1,mid+1,r);
a[p]=a[p<<1]+a[p<<1|1];
}
void modify(ll p,ll l,ll r,ll ql,ll qr,ll v){
if(ql<=l&&r<=qr){
addtag(p,v);
upd(p); return;
}
ll mid=l+r>>1; pushdown(p);
if(ql<=mid) modify(p<<1,l,mid,ql,qr,v);
if(mid<qr) modify(p<<1|1,mid+1,r,ql,qr,v);
a[p]=a[p<<1]+a[p<<1|1];
}
void rep(ll p,ll l,ll r,ll x,line w){
if(l==r){
a[p].fi=w, a[p].se=inf;
return;
} ll mid=l+r>>1;
pushdown(p);
if(x<=mid) rep(p<<1,l,mid,x,w);
else rep(p<<1|1,mid+1,r,x,w);
a[p]=a[p<<1]+a[p<<1|1];
}
}T;
ll t,n,m,l,r,f[maxn],a[maxn];
vector<ll>vec[maxn];
int main()
{
scanf("%lld",&t);
while(t--){
scanf("%lld%lld",&m,&n);
T.build(1,1,n);
for(ll i=1;i<=m;i++){
scanf("%lld%lld",&l,&r);
vec[r].pb(l);
}
for(ll i=1;i<=n+1;i++){
f[i]=max(0ll,T.a[1].fi.b);
if(i>n) break;
scanf("%lld",a+i);
T.rep(1,1,n,i,line(a[i],f[i]));
for(ll j:vec[i])
T.modify(1,1,n,j,i,1);
vec[i].clear();
}
printf("%lld\n",f[n+1]);
}
return 0;
}
维护转移
这个应该算是比较正规的板子。
考虑平常我们是如何维护最大子段和的?维护四个数 \(sum,lmx,rmx,mx\) 表示区间和、最大前缀和、最大后缀和、最大子段和。
其实和之前说的维护区间最大值是一样的,只不过变成了维护四个数。两个东西取 \(\max\) 也变成了直线合并。
然后直线的斜率就是所维护的段的长度,显然区间加 \(v\) 那么对应的值需要加 长度\(\times v\)。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define pir pair<ll,ll>
#define pil pair<line,ll>
#define fi first
#define se second
#define mkp make_pair
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
#define pb push_back
#define ls(p) a[p].lc
#define rs(p) a[p].rc
#define ad(a,b) ((a)+(b)>=mod? (a)+(b)-mod:(a)+(b))
using namespace std;
const ll maxn=4e5+10, inf=1e17;
ll n,m,op,l,r,k;
struct line{
ll k,b;
line(){k=b=0;}
line(ll K,ll B){k=K,b=B;}
const line operator+ (const line t) const{
return line(k+t.k,b+t.b);
}
void add(ll v){b+=k*v;}
};
pil merg(line x,line y){
if(x.k<y.k||(x.k==y.k&&x.b<y.b)) swap(x,y);
if(x.b>=y.b) return mkp(x,inf);
return mkp(y,(y.b-x.b)/(x.k-y.k));
}
struct node{
line sum,lmx,rmx,mx; ll x;
node(){x=inf;}
node(line a,line b,line c,line d,ll e){sum=a, lmx=b, rmx=c, mx=d, x=e;}
const node operator+ (const node t) const{
node res; pil tmp;
res.x=min(x,t.x);
res.sum=sum+t.sum;
tmp=merg(lmx,sum+t.lmx), res.lmx=tmp.fi, res.x=min(res.x,tmp.se);
tmp=merg(t.rmx,rmx+t.sum), res.rmx=tmp.fi, res.x=min(res.x,tmp.se);
tmp=merg(mx,t.mx), res.mx=tmp.fi, res.x=min(res.x,tmp.se);
tmp=merg(res.mx,rmx+t.lmx), res.mx=tmp.fi, res.x=min(res.x,tmp.se);
return res;
}
}a[maxn<<2];
struct KTT{
ll tag[maxn<<2];
void addtag(ll p,ll v){
tag[p]+=v; a[p].x-=v;
a[p].sum.add(v), a[p].lmx.add(v), a[p].rmx.add(v), a[p].mx.add(v);
}
void pushdown(ll p){
if(!tag[p]) return;
addtag(p<<1,tag[p]);
addtag(p<<1|1,tag[p]);
tag[p]=0;
}
void upd(ll p){
if(a[p].x>=0) return;
pushdown(p);
upd(p<<1), upd(p<<1|1);
a[p]=a[p<<1]+a[p<<1|1];
}
void build(ll p,ll l,ll r){
if(l==r){
ll x;
scanf("%lld",&x);
line d(1,x);
a[p]=node(d,d,d,d,inf);
return;
}
ll mid=l+r>>1;
build(p<<1,l,mid), build(p<<1|1,mid+1,r);
a[p]=a[p<<1]+a[p<<1|1];
}
void modify(ll p,ll l,ll r,ll ql,ll qr,ll v){
if(ql<=l&&r<=qr){
addtag(p,v);
upd(p); return;
} pushdown(p);
ll mid=l+r>>1;
if(ql<=mid) modify(p<<1,l,mid,ql,qr,v);
if(mid<qr) modify(p<<1|1,mid+1,r,ql,qr,v);
a[p]=a[p<<1]+a[p<<1|1];
}
node query(ll p,ll l,ll r,ll ql,ll qr){
if(ql<=l&&r<=qr){
upd(p);
return a[p];
}
pushdown(p);
ll mid=l+r>>1;
if(qr<=mid) return query(p<<1,l,mid,ql,qr);
if(mid<ql) return query(p<<1|1,mid+1,r,ql,qr);
return query(p<<1,l,mid,ql,qr)+query(p<<1|1,mid+1,r,ql,qr);
}
}T;
int main()
{
scanf("%lld%lld",&n,&m);
T.build(1,1,n);
while(m--){
scanf("%lld%lld%lld",&op,&l,&r);
if(op==1){
scanf("%lld",&k);
T.modify(1,1,n,l,r,k);
} else{
node ret=T.query(1,1,n,l,r);
printf("%lld\n",max(0ll,ret.mx.b));
}
}
return 0;
}