题意
有 \(n\) 个物品,和一个背包容量上限 \(m\)。每个物品有价值 \(v_i\) 和体积 \(a_i\)。
你需要选择一段区间 \([l,r]\),将这个区间内的体积变为 \(b_i\),剩下的不变。然后你对这 \(n\) 个物品做背包,设背包容量结果为 \(f(i)\),需要求出有多少段区间使得 \(\dfrac{\sum_{i=1}^m f(i)}{m}\le E\)。
\(n,k\le 2\times10^5,nk\le 10^7\)。
分析
令 \(p_i\) 为最小的满足 \([i,p_i]\) 合法的数。那么答案就是 \(\sum_i n-p_i+1\)。
首先,需要注意到 \(p_i\) 单调不降。暴力的话直接双指针背包即可,\(O(n^2k)\),飞了。
由于 \(p_i\) 满足决策单调性那样的性质,考虑套路性地分治,考虑设 \(\operatorname{solve}(l,r,L,R)\) 表示计算 \([l,r]\) 这段区间中的 \(p\),\(p\) 的取值范围落在 \([L,R]\),且不在 \([l,r]\cup[L,R]\) 的物品已经被加入背包。令 \(M=\lfloor\frac{L+R}{2}\rfloor\),考虑二分答案找到最大的满足 \(p_i\le M\) 的下标,记作 \(m\),然后我们就把问题划分成了 \(\operatorname{solve}(l,m,L,M),\operatorname{solve}(m+1,r,M+1,R)\) 两个子问题,分别递归求解即可。
分析时间复杂度:若每次二分都暴力将 \([l,r]\cup[L,R]\) 内的物品加入,每一层中物品都要加入 \(O(n\log n)\) 次,分治一共 \(O(\log n)\) 层,每次加入物品的复杂度显然 \(O(k)\),故复杂度 \(O(nk\log^2n)\),飞了。
考虑优化,发现实际上很多情况下物品都被重复加入了。考虑在二分前 \([L,R]\) 的取值(取 \(a_i\) 或 \(b_i\))就已经确定了,提前将这些物品不在 \([l,r]\) 中的部分加入。考虑二分 \(mid\) 时实际上就是把 \([mid,r]\) 中的物品归为 \(b_i\),\([l,mid)\) 归为 \(a_i\),所以考虑在二分指针右移时(即 \(l=mid+1\))时 \([l,mid]\) 中的物品就永远是 \(a_i\) 类的了,直接把这些物品加入背包即可。二分指针左移同理。这样物品加入次数就降为了 \(O(n)\),复杂度就是 \(O(nk\log n)\),看上去还是飞了但就是能过。
小细节:举个例子,比如往左递归时需要将在 \([l,r]\cup[L,R]\) 但不在 \([l,m]\cup[L,M]\) 的物品加入,根据推导我们应该将 \([M+1,R]\) 划给 \(a_i\),将 \([m+1,r]\) 划给 \(b_i\),但这两段区间可能会有交集,需要分类讨论取哪一个。自行画图不难理解。
小细节 2:注意特殊处理一下 \(p_i>n\) 的情况,即 \([i,n]\) 不合法。
点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#include<cassert>
#define IOS ios::sync_with_stdio(false)
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define mp make_pair
#define fi first
#define se second
#define pc putchar
#define pb emplace_back
#define un using namespace
#define popc __builtin_popcountll
#define all(x) x.begin(),x.end()
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
//#define int long long
//#define int __int128
using namespace std;
using i64=long long;
using u64=unsigned long long;
using pii=pair<int,int>;
inline int rd(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}return x*f;
}
template<typename T>
inline void write(T x,char ch='\0'){
if(x<0){x=-x;putchar('-');}
int y=0;char z[40];
while(x||!y){z[y++]=x%10+48;x/=10;}
while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=2e5+5,maxm=4e5+5,inf=0x3f3f3f3f;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n;
i64 m,E;
int val[maxn],wa[maxn],wb[maxn];
i64 ans;
vector<i64>f;
stack<vector<i64> >sta;
inline void add(int x,int typ){
int w=typ?wb[x]:wa[x];
per(i,m,w)f[i]=max(f[i],f[i-w]+val[x]);
}
inline bool ck(int l,int r,int mid){
sta.emplace(f);
rep(i,l,mid-1)add(i,0);
rep(i,mid,r)add(i,1);
i64 sum=0;
rep(i,1,m)sum+=f[i];
bool ok=sum<=m*E;
// write(mid,32),write(r,32),write(sum,10);
f=sta.top();sta.pop();
if(ok){
rep(i,l,mid)add(i,0);
}else{
rep(i,mid,r)add(i,1);
}
return ok;
}
inline void solve(int l,int r,int ll,int rr,bool valid=false){
// write(l,32),write(r,32),write(ll,32),write(rr,10);
if(l>r||ll>rr)return;
if(ll==rr&&valid){
ans+=1ll*(r-l+1)*(n-ll+1);
return;
}
const int mm=(ll+rr)>>1;
sta.emplace(f);
int L=l,R=min(mm,r),res=L-1;
rep(i,ll,mm)if(!(L<=i&&i<=R))add(i,1);
rep(i,mm+1,rr)if(!(L<=i&&i<=R))add(i,0);
while(L<=R){
int mid=(L+R)>>1;
if(ck(L,R,mid))res=mid,L=mid+1;
else R=mid-1;
}
f=sta.top();
rep(i,mm+1,rr)if(!(l<=i&&i<=res))add(i,0);
rep(i,res+1,r)if(!(ll<=i&&i<=rr))add(i,1);
solve(l,res,ll,mm,1);
f=sta.top();sta.pop();
rep(i,l,res)if(!(mm<i&&i<=rr))add(i,0);
rep(i,ll,mm)if(!(l<=l&&i<=r))add(i,1);
solve(res+1,r,mm+1,rr,0);
}
inline void solve_the_problem(){
n=rd(),m=rd(),E=rd();
rep(i,1,n)val[i]=rd(),wa[i]=rd(),wb[i]=rd();
f.resize(m+1,0);
solve(1,n,1,n);
write(ans);
}
bool Med;
signed main(){
// freopen(".in","r",stdin);freopen(".out","w",stdout);
// fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
int _=1;
while(_--)solve_the_problem();
}
/*
*/