这题难度是middle,但是确实有点强思维的味道,赛时思考了许久,没想到好方向,最后想了个线段树的解法。。当然最后超时了 861 / 884,二十多个用例过不去;
简单说下一开始用线段树来解的思路,因为是需要统计所有的子数组,数组 n 是 4e4,然后其实我当时想法是,可以递推的方式,每次向后挪一位,这样每次n + 1,规模就扩大一次,这样一次操作其实就两类,也就是要么是0,要么是 1,那么我们维护一棵线段树,树的子节点一开始是 [i, i],维护每棵子树的 max 和 min,以及每个节点的0和1个数,那么每次规模扩大,也就是新增一个1或者0,就是子节点变成 [i, i + x] 这个范围中 0 的个数和 1的个数,换句话说新增1 就是范围内加1,新增0就是范围内 减去原来0个数的平方,增加目前0个数的平方
这里可以使用 lazy 下放的优化方式;如果当前新增的是 0,且当前范围子树内 max 小于 0 ,那么可以认定全部都不符合,直接在父节点加上一个 0,后续lazy 下放,同理如果当前新增的是 1,且当前范围子树内 min 大于等于 -1,那么直接在父节点上加上一个1,后续 lazy 下方
具体看代码实现:
class Solution {
public:
const static int maxn = 4e4 + 1;
struct node {
int max;
int min;
int zero;
int first;
};
node seg[maxn * 4];
node value[maxn];
long long ans = 0;
void initTree(int l, int r, int pos) {
if (l == r) {
seg[pos].max = value[l].max;
seg[pos].min = value[l].min;
seg[pos].zero = value[l].zero;
seg[pos].first = value[l].first;
return ;
}
int mid = (l + r) >> 1;
initTree(l, mid, pos << 1);
initTree(mid + 1, r, pos << 1 | 1);
seg[pos].max = max(seg[pos << 1].max, seg[pos << 1 | 1].max);
seg[pos].min = min(seg[pos << 1].min, seg[pos << 1 | 1].min);
seg[pos].zero = 0;
seg[pos].first = 0;
}
void addZero(int l, int r, int left, int right, int pos) {
if (l == r) {
seg[pos].zero ++;
seg[pos].max = seg[pos].first - 1L * seg[pos].zero * seg[pos].zero;
seg[pos].min = seg[pos].max;
if (seg[pos].max >= 0) ans ++;
return ;
}
if (l >= left && r <= right && seg[pos].max + seg[pos].first < 0) {
seg[pos].zero ++ ;
return ;
}
if (seg[pos].zero > 0) {
seg[pos << 1].zero += seg[pos].zero;
seg[pos << 1 | 1].zero += seg[pos].zero;
seg[pos].zero = 0;
}
if (seg[pos].first > 0) {
seg[pos << 1].first += seg[pos].first;
seg[pos << 1 | 1].first += seg[pos].first;
seg[pos].first = 0;
}
int mid = (l + r) >> 1;
if (left <= mid) {
addZero(l, mid, left, right, pos << 1);
}
if (right > mid) {
addZero(mid + 1, r, left, right, pos << 1 | 1);
}
seg[pos].max = max(seg[pos << 1].max, seg[pos << 1 | 1].max);
seg[pos].min = min(seg[pos << 1].min, seg[pos << 1 | 1].min);
}
void addFirst(int l, int r, int left, int right, int pos) {
if (l == r) {
seg[pos].first ++;
seg[pos].max = seg[pos].first - 1L * seg[pos].zero * seg[pos].zero;
seg[pos].min = seg[pos].max;
if (seg[pos].max >= 0) ans ++;
return ;
}
if (l >= left && r <= right && seg[pos].min + seg[pos].first >= -1) {
seg[pos].first ++;
ans += r - l + 1;
return ;
}
if (seg[pos].zero > 0) {
seg[pos << 1].zero += seg[pos].zero;
seg[pos << 1 | 1].zero += seg[pos].zero;
seg[pos].zero = 0;
}
if (seg[pos].first > 0) {
seg[pos << 1].first += seg[pos].first;
seg[pos << 1 | 1].first += seg[pos].first;
seg[pos].first = 0;
}
int mid = (l + r) >> 1;
if (left <= mid) {
addFirst(l, mid, left, right, pos << 1);
}
if (right > mid) {
addFirst(mid + 1, r, left, right, pos << 1 | 1);
}
seg[pos].max = max(seg[pos << 1].max, seg[pos << 1 | 1].max);
seg[pos].min = min(seg[pos << 1].min, seg[pos << 1 | 1].min);
}
int numberOfSubstrings(string s) {
for (int i = 0; i < maxn * 4; ++ i) {
seg[i].max = 0;
seg[i].min = 0;
seg[i].zero = 0;
seg[i].first = 0;
}
int zero = 0;
for (int i = 0; i < s.length(); ++ i) {
if (s[i] == '0') zero ++;
int first = i - zero + 1;
if (s[i] == '0') {
value[i + 1].max = -1;
value[i + 1].min = -1;
value[i + 1].zero = 1;
value[i + 1].first = 0;
} else {
value[i + 1].zero = 0;
value[i + 1].first = 1;
value[i + 1].max = 1;
value[i + 1].min = 1;
ans ++;
}
}
initTree(1, s.length(), 1);
for (int i = 1; i < s.length(); ++ i) {
if (s[i] == '0') {
addZero(1, s.length(), 1, i, 1);
} else {
addFirst(1, s.length(), 1, i, 1);
}
}
return ans;
}
};
下面是正确的思路
考虑到其实平方的增速是很大的,同样是每次子数组规模扩大1,会有两个操作,新增0,新增1,所以对于一个区间如果当前 1数量 - 0数量平方少于0,那么至少需要向后移动这么多位,才可能大于等于0(全部是1情况),那么可以考虑贪心的优化方式,也就是对于从 i 开始的子区间,j 从 i 到字符串末尾,定义 value = 1数量 - 0数量平方,如果value小于0,可以向后移动value位,如果value 大于等于0,那么考虑目前需要增加多少个0 可以让value 小于0,也就是需要 [sqrt(1数量) 向上取整 - 当前0数量] 个,那么也就可以向后移动 [sqrt(1数量) 向上取整 - 当前0数量] 位。这样的贪心思路,可以大大优化执行耗时
具体看代码逻辑
class Solution {
public:
const static int maxn = 4e4 + 1;
int sum[maxn][2];
int numberOfSubstrings(string s) {
long long ans = 0;
sum[0][0] = sum[0][1] = 0;
if (s[0] == '1') sum[1][1] = 1, sum[1][0] = 0;
if (s[0] == '0') sum[1][0] = 1, sum[1][1] = 0;
for (int i = 1; i < s.length(); ++ i) {
sum[i + 1][1] = sum[i][1] + (s[i] == '1');
sum[i + 1][0] = sum[i][0] + (s[i] == '0');
}
for (int i = 0; i < s.length(); ++ i) {
int j = i + 1;
while (j <= s.length()) {
int zero = sum[j][0] - sum[i][0];
int one = sum[j][1] - sum[i][1];
int zero2 = zero * zero;
int diff = one - zero2;
if (diff >= 0) {
int need_zero = sqrt(one);
if (need_zero * need_zero <= one) need_zero ++;
int diff_zero = need_zero - zero;
int l = s.length();
ans += 1L * min(diff_zero, l - j + 1);
j += diff_zero;
} else {
j -= diff;
}
}
}
return ans;
}
};
标签:周赛,Q3,int,sum,pos,value,seg,zero,LeetCode
From: https://www.cnblogs.com/wanshe-li/p/18328999