D - LIS 2
难搞的地方在于取min,考虑比较\((a_i \oplus a_j,b_i \oplus b_j)\)两者的过程:是在它们第一位不一样的地方比较,取该位为0的那个。
而判断两个数在某位是否相等,可以想到异或操作,然后把这两者异或起来后,由异或运算的交换律可得等价于\((a_i \oplus b_i) \oplus (a_j \oplus b_j)\),这样就转成两个位置独立的式子的异或值,然后枚举这个第一个为1的位置,在trie树上记一些东西,再类似地查一下就行。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e5+5,M=18;
int n,a[N],b[N],ch[N][2],cnt=1;
ll tot[N][2],suma[N][2][M],sumb[N][2][M];
int main() {
cin>>n;
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n;i++) scanf("%d",&b[i]);
for(int i=1;i<=n;i++){
int z=(a[i]^b[i]);
int u=1;
for(int p=M-1;p>=0;p--){
int q=(z&(1<<p))>>p,v=(a[i]&(1<<p))>>p;
if(!ch[u][q]) ch[u][q]=++cnt;
u=ch[u][q];
for(int k=0;k<M;k++){
if(a[i]&(1<<k)) suma[u][v][k]++;
if(b[i]&(1<<k)) sumb[u][v][k]++;
}
tot[u][v]++;
}
}
ll ans=0;
for(int i=1;i<=n;i++){
//cout<<"i="<<i<<endl;
int z=(a[i]^b[i]);
int u=1;
for(int p=M-1;p>=0;p--){
int q=(z&(1<<p))>>p;
int cur=ch[u][!q];
int v=(a[i]&(1<<p))>>p;
//v=(!v);
for(int k=0;k<M;k++){
if(a[i]&(1<<k)) ans+=(tot[cur][v]-suma[cur][v][k])*(1<<k);
else ans+=suma[cur][v][k]*(1<<k);
}
//cout<<ans<<endl;
v=(!v);
for(int k=0;k<M;k++){
if(b[i]&(1<<k)) ans+=(tot[cur][v]-sumb[cur][v][k])*(1<<k);
else ans+=sumb[cur][v][k]*(1<<k);
}
u=ch[u][q];
//cout<<p<<" "<<ans<<endl;
}
int cur=u;
for(int v=0;v<2;v++)
for(int k=0;k<M;k++){
if(a[i]&(1<<k)) ans+=(tot[cur][v]-suma[cur][v][k])*(1<<k);
else ans+=suma[cur][v][k]*(1<<k);
}
//cout<<"ans="<<ans<<endl;
}
cout<<ans/2<<endl;
return 0;
}
E - Priority Queue
一般的思路是,考虑如何判断某个最终集合是否合法,但这样会得到非常复杂的转移,不是很可做。
从边界的角度,考虑令留下的集合最小和最大的情况(非严格定义,感性理解):最小显然可以做到删掉最大的那些元素;最大就是从小到大加入元素。然后发现这样做之后,对一个数\(v\),最小的做法使得小于它的元素最多;而最大的做法使得大于它的元素最多。
然后考虑把最终留下的元素升序排序,则对于某个合法的答案,每个元素必须要在,上面得出的对应位置的值作为上下界的区间内。(由上面的结论可以反证出来)
然后可以证明,这个条件是充分的。因为对于留下的最大的集合(删去的最小),一定可以通过修改,将本来要删掉的集合换成更大的(直接一一对应过去即可)。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N=10005,P=998244353;
void inc(int& x,int y){
x+=y;
if(x>=P) x-=P;
if(x<0) x+=P;
}
int sum(int x,int y){
x+=y;
if(x>=P) x-=P;
if(x<0) x+=P;
return x;
}
void mul(int& x,int y){
x=1ll*x*y%P;
}
int prd(int x,int y){
return 1ll*x*y%P;
}
int n,m,a[N],mx[N];
set<int>st;
int pos[N],f[N][N];
int main() {
cin>>n>>m;
for(int i=1;i<=n;i++) mx[i]=1;
int tot=0;
for(int i=1;i<=n+m;i++){
int x;
scanf("%d",&x);
if(x==1){
st.insert(++tot);
}
else{
set<int>:: iterator it=st.end();
it--;
mx[*it]=0;
st.erase(it);
}
}
for(int i=1,j=0;i<=n;i++) if(mx[i]) pos[++j]=i;//cout<<pos[j]<<" "; puts("");
f[0][0]=1;
for(int j=1;j<=n;j++) inc(f[0][j],f[0][j-1]);
for(int i=1;i<=n-m;i++){
for(int j=1;j<=pos[i];j++) f[i][j]=f[i-1][j-1];
for(int j=1;j<=n;j++) inc(f[i][j],f[i][j-1]);
}
cout<<f[n-m][n]<<endl;
return 0;
}