题目描述
简要描述:给定一个长度为 \(N\) 的数组,求数组的子数组满足最大值为 \(X\) 且最小值为 \(Y\) 的子区间的个数。
做法
1. ST表 + 二分
时间复杂度: \(O(n \log n)\)
对于每个位置,二分出以它为左端点最大值为 \(X\) 的最远和最近的位置,以及以它为左端点最小值为 \(Y\) 的最远和最近的位置,然后对两个区间求区间交即可统计答案
代码
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }
//使用的时候注意判断如果 多个最值 取左边下标还是右边下标
//分别对应 >= > <= <
template<typename T>
class ST{
public:
ST(vector<T> a, int _n) : a(a), n(_n) { // cope with in [0,n-1]
lg.resize(n + 1); lg[1] = 0;
for (int i = 2; i <= n; i ++ ) lg[i] = lg[i >> 1] + 1;
int m = lg[n] + 1;
maxv.resize(m); minv.resize(m);
for (int i = 0; i < m; i ++ ) maxv[i].resize(n), minv[i].resize(n);
for (int i = 0; i < n; i ++ ) maxv[0][i] = minv[0][i] = a[i];
for (int i = 1; i < m; i ++ ) {
for (int j = 0; j <= n - (1 << i); j ++ ) {
maxv[i][j] = max(maxv[i - 1][j], maxv[i - 1][j + (1 << (i - 1))]);
minv[i][j] = min(minv[i - 1][j], minv[i - 1][j + (1 << (i - 1))]);
}
}
}
T getmax(int l,int r){
int k = lg[r - l + 1];
return max(maxv[k][l], maxv[k][r - (1 << k) + 1]);
}
T getmin(int l,int r){
int k = lg[r - l + 1];
return min(minv[k][l], minv[k][r - (1 << k) + 1]);
}
private:
int n;
vector<T> a;
vector<int> lg;
vector<vector<T>> maxv, minv;
};
void solve() {
int n, l, r; cin >> n >> r >> l;
vector<int> a(n); for (int &x: a) cin >> x;
ST st(a, n);
ll ans = 0;
for (int i = 0; i < n; i ++ ) if (a[i] >= l && a[i] <= r) {
int L = i, R = n - 1;
while (L < R) {
int MID = (L + R + 1) / 2;
if (st.getmax(i, MID) > r) R = MID - 1;
else L = MID;
}
//此时的L就是最大值的右边界
if (st.getmax(i, L) != r) continue;
int maxr = L;
L = i, R = R;
while (L < R) {
int MID = (L + R) / 2;
if (st.getmax(i, MID) >= r) R = MID;
else L = MID + 1;
}
int maxl = L;
if (st.getmax(i, L) != r) continue;
if (maxl > maxr) continue;
L = i, R = n - 1;
while (L < R) {
int MID = (L + R + 1) / 2;
if (st.getmin(i, MID) < l) R = MID - 1;
else L = MID;
}
int minr = L;
if (st.getmin(i, L) != l) continue;
L = i, R = R;
while (L < R) {
int MID = (L + R) / 2;
if (st.getmin(i, MID) <= l) R = MID;
else L = MID + 1;
}
int minl = L;
if (st.getmin(i, L) != l) continue;
if (minl > minr) continue;
if (minl > maxr || maxl > minr) continue;
ans += min(maxr, minr) - max(maxl, minl) + 1;
}
cout << ans << "\n";
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
solve();
return 0;
}
2. 三指针
时间复杂度: \(O(n)\)
对于每个位置,统计以它为右端点的区间个数。只需要知道当前位置左边,距离它最近的且值为 \(X\) 的位置 \(i_1\),距离它最近的且值为 \(Y\) 的位置 \(i_2\) ,距离它最近的且不在区间 \([Y, X]\) 的位置 \(i_0\) ,然后个数就是 \(\max(0, min(i_1, i_2) - i_0)\)
代码
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
class Solution {
public:
long long countSubarrays(vector<int> &nums, int min_k, int max_k) {
long long ans = 0L;
int n = nums.size(), min_i = -1, max_i = -1, i0 = -1;
for (int i = 0; i < n; ++i) {
int x = nums[i];
if (x == min_k) min_i = i;
if (x == max_k) max_i = i;
if (x < min_k || x > max_k) i0 = i; // 子数组不能包含 nums[i0]
ans += max(min(min_i, max_i) - i0, 0);
}
return ans;
}
};
void solve() {
int n; cin >> n;
int x, y; cin >> x >> y;
vector<int> a(n); for (int &x: a) cin >> x;
Solution s;
cout << s.countSubarrays(a, y, x) << "\n";
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
solve();
return 0;
}
3. \(DP\)
时间复杂度: \(O(n)\)
定义 dp[i][j][k]
表示以 \(i\) 为右端点且状态为 \(j\) 和 \(k\) 的方案数。
第二维和第三维的 \(0 / 1\) 表示是否含最大值 或者 最小值
代码
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int N = 2E5 + 10;
int n, x, y;
int a[N], dp[N][2][2];
void solve() {
cin >> n >> x >> y;
for (int i = 1; i <= n; i ++ ) {
cin >> a[i];
}
ll ans = 0;
for (int i = 1; i <= n; i ++ ) {
if (a[i] > x || a[i] < y) continue;
bool fx = a[i] == x;
bool fy = a[i] == y;
dp[i][fx][fy] ++;
for (int p = 0; p < 2; p ++ ) {
for (int q = 0; q < 2; q ++ ) {
dp[i][p | fx][q | fy] += dp[i - 1][p][q];
}
}
ans += dp[i][1][1];
}
cout << ans << "\n";
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
solve();
return 0;
}