题意
给定一个\(1\sim N\)的排列,在这个排列中选出两段互不重叠的区间,求使选出的元素排序后构成公差为1的等差数列的方案数。选出的两段区间中元素构成的集合相同时视为同一种方案。\(1\le N\le 3\times 10^5\)。
分析
如果考虑怎么优化枚举的两个区间的话,发现不太好搞(反正我只会暴力)。
于是考虑枚举连续的值域区间,再判断一下连续的值域区间是由原排列中几段连续的区间构成,如果 \(\le 2\),就是可行的方案。
对于这种区间问题,一般套路是确定一个点,然后对其他点算贡献。
设\(f[l][r]\) 表示值域 \([l,r]\) 是由几段构成的,\(pos[i]\) 表示 \(i\) 这个值在原序列的位置,我们从 \(1\) 到 \(n\) 依次枚举右端点 \(i\),考虑从 \(i - 1\) 转移到 \(i\),那如何在 \(O(i)\) 的时间内转移呢?可以找到如下规律:
- 如果原序列中 \(i\) 在的位置左右两个数都 \(\le i\) ,那么肯定在之前加入了,而现在加入 \(i\) 会使 \([l, i], l \in [1, \min(a[pos[i]-1], a[pos[i] + 1])]\) 值域的段数 \(-1\),\([l, i], l \in (\min(a[pos[i]-1], a[pos[i] + 1]), \max(a[pos[i]-1], a[pos[i] + 1])]\)值域的段数不变,\([l, i], l \in (\max(a[pos[i]-1], a[pos[i] + 1]), i - 1]\) 值域的段数 \(+1\)。
- 如果只有一个,设那个数的位置为 \(x\),那么对于 \([l, i] ,l\in [1, x]\)值域的段数不变,\([l, i], l \in (x, i - 1]\) 的段数 \(+1\)。
- 如果没有,那么 \([l, i] ,l\in i - 1\)的值域的段数会 \(+1\)。
这样是 \(O(n ^ 2)\),区间加,区间减,很容易想到用线段树优化。
设枚举\(i\),线段树的区间 \([l,r]\),表示 \([x, i], x\in [l,r]\) 的各种信息。
我们需要线段树维护区间的最小段数的值,是这个最小段数的值的区间个数,和是次小段数这个值的区间个数。
最后询问$1∼i−1 $需要分的段数是否小于等于 \(2\) 即可
如何维护详见代码注释。
#include<bits/stdc++.h>
#define N 300005
#define int long long
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, ans;
int a[N], pos[N];
int minn[N << 2], cnt0[N << 2], cnt1[N << 2], lazy[N << 2];
struct Segment{
void pushup(int u){
minn[u] = min(minn[ls], minn[rs]);
cnt0[u] = (minn[ls] == minn[u]) * cnt0[ls] + (minn[rs] == minn[u]) * cnt0[rs];
cnt1[u] = (minn[ls] == minn[u]) * cnt1[ls] + (minn[ls] == minn[u] + 1) * cnt0[ls];
cnt1[u] += (minn[rs] == minn[u]) * cnt1[rs] + (minn[rs] == minn[u] + 1) * cnt0[rs];
//如果左/右区间的最小值等于整个区间的最小值,那么左/右区间次小值就是整个区间的次小值,统计个数
//如果左/右区间的最小值等于整个区间的最小值 + 1,那么最小值就是整个区间的次小值,因为每次枚举 $i$ 时,
//值的变化最多加减1,所以次小值就是最小值 + 1。
}
void pushdown(int u){
minn[ls] += lazy[u], lazy[ls] += lazy[u];
minn[rs] += lazy[u], lazy[rs] += lazy[u];
lazy[u] = 0;
}
void build(int u, int l, int r){
if(l == r) return cnt0[u] = 1, void(); //初始都为 1
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int L, int R, int val){
if(L <= l && r <= R) return minn[u] += val, lazy[u] += val, void();
pushdown(u);
int mid = (l + r) >> 1;
if(L <= mid) update(ls, l, mid, L, R, val);
if(R > mid) update(rs, mid + 1, r, L, R, val);
pushup(u);
}
int query(int u, int l, int r, int L, int R){
if(L <= l && r <= R) return cnt0[u] * (minn[u] <= 2) + cnt1[u] * (minn[u] <= 1);
//如果最小值小于等于 2 ,说明最小值是符合的,统计进去。
//如果最小值小于等于 1 , 次小值 = 最小值 + 1 , 次小值也符合。
pushdown(u);
int mid = (l + r) >> 1, res = 0;
if(L <= mid) res += query(ls, l, mid, L, R);
if(R > mid) res += query(rs, mid + 1, r, L, R);
return res;
}
}tr;
signed main(){
n = read();
for(int i = 1; i <= n; ++i) a[i] = read(), pos[a[i]] = i;
tr.build(1, 1, n);
for(int i = 1; i <= n; ++i){
tr.update(1, 1, n, 1, i, 1);
if(a[pos[i] - 1] < i && a[pos[i] - 1]) tr.update(1, 1, n, 1, a[pos[i] - 1], -1);
if(a[pos[i] + 1] < i && a[pos[i] + 1]) tr.update(1, 1, n, 1, a[pos[i] + 1], -1);
if(i >= 2) ans += tr.query(1, 1, n, 1, i - 1);//[i,i]这段是不能算进去的
}
printf("%lld\n", ans);
return 0;
}
标签:mid,int,题解,Segments,段数,Two,pos,值域,区间
From: https://www.cnblogs.com/jiangchen4122/p/17417767.html