题目描述
现在有三个长度为 \(n\) 的序列 \(a,b,c\) 。你需要提取一个子序列 \(p_1,p_2,\dots,p_m\) ,满足如下条件:
- \(\forall i \in (1,m] ,p_i>p_{i-1}\) 。
- \(\forall i \in (1,m] ,a_i\geq a_{i-1}\) 。
- \(b_{p_1},b_{p_2},\dots,b_{p_m}\) 是互不相同的。
在此基础上最大化 \(\sum_{i=1}^m c_{p_1}\) ,或者报告无解。
\(n \leq 3\times 10^3,m \leqslant 5\) 。
思路点拨
看了一眼这个题目的官解,十分牛逼。但是这里要给一个正确性 \(100 \%\) 的 \(O(n^2 \log n)\) 做法。
考虑对于每一个 \(m\) ,都分类讨论一下:
\(m=1\)
直接枚举 \(p_1\) 就可以。时间复杂度 \(O(n)\) 。
\(m=2\)
直接枚举二元组就可以。时间复杂度 \(O(n^2)\) 。
\(m=3\)
考虑枚举 \(p_1,p_2\) ,接下来选择一个最优的 \(p_3\) 。一种不对的想法就是在 \((p_2,n]\) 下标范围内选择 \(c\) 最大的那个,但是这样可能导致它的 \(b\) 是 \(b_{p_1},b_{p_2}\) 中的一个。
我们记录 \(suf_{i,j}\) 表示考虑 \((i,n]\) 下标内满足 \(a \geq a_i\) 的数中,第 \(j\) 大的元素。这里的第 \(j\) 大要求 \(b\) 不可以与第 \(k(k<j)\) 大的元素相同。因为我们的目标是避免和 \(b_{p_1},b_{p_2}\) 相同,所以只需要记录前三大。对于每一个 \(i\) 都 \(O(n)\) 做一遍。
接下来在枚举 \(p_1,p_2\) 的过程中,找到 \((p_2,n]\) 下标范围内 \(c\) 最大的,且 \(b\) 与 \(b_{p_1},b_{p_2}\) 不同那一个元素,一定出现在 \(suf_{p_2,j}(1 \leq j \leq 3)\) 中。
时间复杂度 \(O(n^2)\) 。
\(m=4\)
考虑枚举 \(p_2,p_3\) ,接下来就是找到一个 \(p_1,p_4\) ,他们的 \(b\) 不和 \(b_{p_2},b_{p_3}\) 。还是考虑 \(m=3\) 时的错误想法,直接找 \(c\) 最大的,然后处理 \(b\) 相同的情况。
我们沿用 \(suf_{i,j}\) 的定义,但是 \(j\) 要取到 \(4\) ,因为有 \(4\) 个数。
再定义 \(pre_{i,j}\) 表示 \([1,i)\) 下标范围内,满足 \(a \leq a_i\) 的数中,\(c\) 第 \(j\) 大的元素,要求 \(b\) 不可以第 \(k(k<j)\) 大的相同。
接下来在枚举 \(p_2,p_3\) 之后,一定存在一组解满足 \(p_1\in pre_i,p_4\in suf_j\) 。时间复杂度 \(O(n^2)\) 。
\(m=5\)
首先沿用 \(pre,suf\) 数组,但是维护到第 \(5\) 大,因为会选出 \(5\) 个数。
这个时候我们枚举 \(p_2,p_4\) ,那么 \(p_1,p_5\) 就可以利用 \(pre_{p_2},suf_{p_4}\) 选出来,问题在于 \(p_3\) 。
实际上,想要求出 \(p_3\) 可能得取值,就需要维护一个数组 \(f_{i,j,k_1}\) 表示下标在 \((i,j)\) 范围内满足 \(a_i \leq a \leq a_j\) 的情况下,\(c\) 第 \(k_1\) 大的元素,但是要求 \(b\) 不与 \(k_2<k_1\) 的相同。如果对于每一个 \(i,j\) 都维护出这个数组就是时间复杂度 \(O(n^3)\) 的了,不可以接受。我们考虑固定一个 \(i\) ,然后 \(j\) 从小到大枚举,这个时候下标的限制就可以满足,还剩下 \(a\) 的限制。可以利用权值线段树,将 \(f\) 看做一个结构体就可以。
这样子可以通过 \(f_{p_1,p_2}\) 选出 \(p_3\) 。
表面上,这个时间复杂度是光鲜的 \(O(n^2 \log n)\) ,但是我不会告诉你实际时间是 \(O(m^3n^2+m^2n^2 \log n)\) 。
#include<bits/stdc++.h>
//#define int long long
using namespace std;
namespace fastIO{
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int buf[20],TOT;
inline void print(int x,char ch=' '){
if(x<0) putchar('-'),x=-x;
else if(x==0) buf[++TOT]=0;
for(int i=x;i;i/=10) buf[++TOT]=i%10;
do{putchar(buf[TOT]+'0');}while(--TOT);
putchar(ch);
}
}
using namespace fastIO;
const int MAXN=3e3+5,inf=1e9;
int n,m,a[MAXN],b[MAXN],c[MAXN];
int ans=-1;
struct node{
int pos[6],val[6];
node(){
for(int i=1;i<=5;i++)
pos[i]=0,val[i]=-inf;
}
void insert(int p,int v){
for(int i=1;i<=5;i++){
if(v>val[i]){
int flag=5;
for(int j=i;j<=5;j++)
if(pos[j]==p) flag=j;
for(int j=flag;j>i;j--)
pos[j]=pos[j-1],val[j]=val[j-1];
pos[i]=p,val[i]=v;
break;
}
else if(p==pos[i])
break;
}
}
node friend operator+(const node &a,const node &b){
node c;
for(int i=1;i<=5;i++)
c.insert(a.pos[i],a.val[i]);
for(int i=1;i<=5;i++)
c.insert(b.pos[i],b.val[i]);
return c;
}
};
node pre[MAXN],suf[MAXN],t[MAXN];
#define lowbit(x) (x&(-x))
void update(int x,pair<int,int> w){
for(int i=x;i<=n;i+=lowbit(i))
t[i].insert(w.first,w.second);
}
node query(int x){
node ans;
for(int i=x;i;i-=lowbit(i))
ans=ans+t[i];
return ans;
}
signed main(){
n=read(),m=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<=n;i++) b[i]=read();
for(int i=1;i<=n;i++) c[i]=read();
if(m==1){
for(int i=1;i<=n;i++)
ans=max(ans,c[i]);
cout<<ans;
}
else if(m==2){
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
if(a[i]<=a[j]&&b[i]!=b[j])
ans=max(ans,c[i]+c[j]);
cout<<ans;
}
else if(m==3){
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
if(a[i]<=a[j]) suf[i].insert(b[j],c[j]);
for(int i=1;i<=n;i++){
for(int j=i+1;j<=n;j++){
if(a[i]<=a[j]&&b[i]!=b[j]){
for(int k=1;k<=5;k++){
if(b[i]==suf[j].pos[k]||b[j]==suf[j].pos[k])
continue;
ans=max(ans,c[i]+c[j]+suf[j].val[k]);
}
}
}
}
cout<<ans;
}
else if(m==4){
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
if(a[i]<=a[j]) suf[i].insert(b[j],c[j]);
for(int i=1;i<=n;i++)
for(int j=1;j<i;j++)
if(a[j]<=a[i]) pre[i].insert(b[j],c[j]);
for(int i=1;i<=n;i++){
for(int j=i+1;j<=n;j++){
if(a[i]<=a[j]&&b[i]!=b[j]){
for(int p=1;p<=5;p++){
if(pre[i].pos[p]==b[i]||pre[i].pos[p]==b[j]) continue;
for(int q=1;q<=5;q++){
if(pre[i].pos[p]!=suf[j].pos[q]&&suf[j].pos[q]!=b[i]&&suf[j].pos[q]!=b[j])
ans=max(ans,c[i]+c[j]+pre[i].val[p]+suf[j].val[q]);
}
}
}
}
}
cout<<ans;
}
else if(m==5){
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
if(a[i]<=a[j]) suf[i].insert(b[j],c[j]);
for(int i=1;i<=n;i++)
for(int j=1;j<i;j++)
if(a[j]<=a[i]) pre[i].insert(b[j],c[j]);
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++) t[j]=node();
for(int j=i+2;j<=n;j++){
if(a[i]<=a[j-1])
update(a[j-1],make_pair(b[j-1],c[j-1]));
if(a[i]<=a[j]&&b[i]!=b[j]){
node tmp=query(a[j]);
for(int x=1;x<=5;x++){
int p1=pre[i].pos[x];
if(!p1) continue;
if(p1==b[i]||p1==b[j]) continue;
for(int y=1;y<=5;y++){
int p2=tmp.pos[y];
if(!p2) continue;
if(p2==p1||p2==b[i]||p2==b[j]) continue;
for(int z=1;z<=5;z++){
int p3=suf[j].pos[z];
if(!p3) continue;
if(p3==p1||p3==p2||p3==b[i]||p3==b[j]) continue;
ans=max(ans,c[i]+c[j]+pre[i].val[x]+tmp.val[y]+suf[j].val[z]);
break;
}
}
}
}
}
}
cout<<ans;
}
return 0;
}
标签:suf,CF2003F,报告,int,复杂度,ch,leq,枚举,解题
From: https://www.cnblogs.com/-Aurore-/p/18379966