加深了对分块算法的理解。
题目相当于求解一个区间内每种颜色出现次数平方和,这种题显然无法 polylog。
先尝试分块,将贡献拆成 散块 - 散块;散块 - 整块;整块 - 整块 三种。
散块 - 散块 是容易的,直接用桶计数就好。
整块 - 整块:设 \(d_{c, i}\) 表示颜色 \(c\) 在前 \(i\) 个块的出现次数,考虑维护 \(f(i, j) = \sum\limits_c d_{c, i} d_{c, j}\),拆掉平方,那么第 \(l\sim r\) 个块的答案为 \(f(r, r) + f(l - 1, l - 1) - 2\cdot f(l - 1, r)\)。
当第 \(k\) 块中颜色 \(c\) 的数量增加 \(w\) 时,则 \(d_{c, i}, d_{c, i + 1}, \dots\) 都会增加 \(w\),我们观察 \(f\) 的变化:
-
\(i\le j< k\):不变
-
\(i < k\le j\):\(d_{c, i} (d_{c, j} + w) = d_{c, i} d_{c, j} + w\cdot d_{c, i}\),增加了 \(w\cdot d_{c, i}\)。
-
\(k\le i, j\):\((d_{c, i} + w) (d_{c, j} + w) = d_{c, i} d_{c, j} + w\cdot d_{c, i} + w\cdot d_{c, j} + w^2\),增加了 \(w\cdot d_{c, i} + w\cdot d_{c, j} + w^2\)。
增加的贡献只和 \(i, j\) 其中一维有关,可以两个维度分别差分,做到单次修改或查询 \(\mathcal O(\sqrt n)\)。
散块 - 整块:也是容易的,枚举散块元素时,可以快速算出整块部分某一颜色的出现次数。
分析一下时间复杂度,观察到在“第 \(k\) 块颜色 \(c\) 数量增加 \(w\)”这一修改次数是 \(\mathcal O(m\sqrt n)\) 的,时间复杂度不对。
发现中间耗时间的主要是颜色全部相同的“纯色块”,考虑如果把纯色块拎出来扔进散块的部分,时间如何。
观察非纯色块颜色连续段数,每次区间覆盖操作只会在两边的散块中增加一个连续段,所以这玩意总数为 \(\mathcal O(m)\),均摊下来时间正确,为 \(\mathcal O((n + m)\sqrt n)\)。
- 启示:分块过程中的均摊时间分析;维护整块之间的信息。
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <int, int>
#define pb emplace_back
#define i128 __int128
using namespace std;
const int maxn = 2e5 + 10, mod = 998244353;
const ll inf = 1e18;
int power(int a, int b = mod - 2) {
int s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %mod;
a = 1ll * a * a %mod, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x > y? y : x; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
int n, m, B, bn, a[maxn], d[maxn][450], bl[maxn], cnt[maxn];
ll fx[450][450], fy[450][450], iscol[450], ans;
unsigned int lstans;
ll get_f(const int x, const int y) {
ll ans = 0;
for(int i = 1; i <= x; i++) ans += fy[i][y];
for(int i = x; i <= y; i++) ans += fx[i][x];
return ans;
}
void upd(const int u, const int c, const int w) {
for(int i = 1; i < u; i++) fx[u][i] += w * d[c][i];
for(int i = u; i <= bn; i++)
fx[i][i] += w * (w + d[c][i]), fy[u][i] += w * d[c][i];
for(int i = u; i <= bn; i++) d[c][i] += w;
}
void del(const int u, const int l, const int r) {
for(int i = l, j = l; i <= r; i++) {
if(a[j] ^ a[i]) upd(u, a[j], j - i), j = i;
if(i == r) upd(u, a[i], j - i - 1);
}
}
void modify(const int u, const int l, const int r, const int c) {
const int ul = (u - 1) * B + 1, ur = min(n, u * B);
if(l <= ul && ur <= r) {
if(!iscol[u]) del(u, ul, ur);
iscol[u] = c;
} else {
const int L = max(l, ul), R = min(r, ur);
if(iscol[u]) {
for(int i = ul; i <= ur; i++) a[i] = iscol[u];
upd(u, iscol[u], ur - ul - R + L), iscol[u] = 0;
} else del(u, L, R);
upd(u, c, R - L + 1);
for(int i = L; i <= R; i++) a[i] = c;
}
}
signed main() {
// freopen("p.in", "r", stdin);
// freopen("p.out", "w", stdout);
rd(n), rd(m); B = sqrt(n), bn = (n - 1) / B + 1;
for(int i = 1; i <= n; i++) rd(a[i]), bl[i] = (i - 1) / B + 1;
for(int i = 1; i <= n; i++) upd(bl[i], a[i], 1);
while(m--) {
int op, l, r; rd(op), rd(l), rd(r);
l ^= lstans, r ^= lstans;
if(op == 1) {
int c; rd(c), c ^= lstans;
// printf("modify %d %d %d\n", l, r, c);
for(int j = bl[l]; j <= bl[r]; j++) modify(j, l, r, c);
} else {
// printf("query %d %d\n", l, r);
if(bl[l] == bl[r]) {
ans = 0;
if(iscol[bl[l]]) {
ans = (r - l + 1) * (r - l + 1);
} else {
for(int i = l; i <= r; i++)
ans += 2 * cnt[a[i]] + 1, ++cnt[a[i]];
for(int i = l; i <= r; i++) cnt[a[i]] = 0;
}
lstans = ans = ans - (r - l + 1) >> 1;
printf("%lld\n", ans); continue;
}
ans = get_f(bl[r] - 1, bl[r] - 1) + get_f(bl[l], bl[l])
- 2 * get_f(bl[l], bl[r] - 1);
const int bl_l = bl[l], bl_r = bl[r];
const int pl = bl_l * B + 1, pr = (bl_r - 1) * B;
if(iscol[bl_l]) {
const int w = pl - l, x = iscol[bl_l];
ans += 1ll * (2 * (d[x][bl_r - 1] - d[x][bl_l] + cnt[x]) + w) * w;
cnt[x] += w;
} else
for(int i = l; i < pl; i++) {
const int x = a[i];
ans += 2 * (d[x][bl_r - 1] - d[x][bl_l] + cnt[x]) + 1;
++cnt[x];
}
if(iscol[bl_r]) {
const int w = r - pr, x = iscol[bl_r];
ans += 1ll * (2 * (d[x][bl_r - 1] - d[x][bl_l] + cnt[x]) + w) * w;
cnt[x] += w;
} else
for(int i = pr + 1; i <= r; i++) {
const int x = a[i];
ans += 2 * (d[x][bl_r - 1] - d[x][bl_l] + cnt[x]) + 1;
++cnt[x];
}
for(int i = bl_l + 1; i < bl_r; i++)
if(iscol[i]) {
const int x = iscol[i];
ans += 1ll * (2 * (d[x][bl_r - 1] - d[x][bl_l] + cnt[x]) + B) * B;
cnt[x] += B;
}
cnt[iscol[bl_l]] = cnt[iscol[bl_r]] = 0;
if(!iscol[bl_l])
for(int i = l; i < pl; i++) cnt[a[i]] = 0;
if(!iscol[bl_r])
for(int i = pr + 1; i <= r; i++) cnt[a[i]] = 0;
for(int i = bl_l + 1; i < bl_r; i++) cnt[iscol[i]] = 0;
lstans = ans = ans - (r - l + 1) >> 1;
printf("%lld\n", ans);
}
}
return 0;
}