AT_agc030_d [AGC030D] Inversion Sum 题解
题目大意
给你一个长度为 \(n\) 的数列,然后给你 \(q\) 次交换操作,你每次可以选择操作或者不操作,问所有情况下逆序对的总和。(\(n, q \le 3000\))
分析
很容易想到 \(dp\),但是发现不好直接算方案。所以我们用一个小技巧,将求方案数转化为求期望,然后求出最终答案。那么就来考虑怎么设计期望 \(dp\)。
题解
我们先来想怎么求某一情况下期望的逆序对数。
第一反应想到计数,用 \(dp_{i, j}\) 来代表数对 \((i, j)\) 对最终答案作出的贡献。但是在写转移时发现有一点麻烦,于是浅改一下思路,用 \(dp_{i, j}\) 来代表 位置 \(i\) 上的数比位置 \(j\) 上的数大的概率。
我们很容易得到,某一情况下期望的逆序对数就是 \(\displaystyle \sum_{i = 1}^{n} \sum_{j = i + 1}^{n} dp_{i, j}\)。
\(dp\) 的初始值比较好想,就是 \(O(n^2)\) 跑一遍,对于每一个 \(i\) 位置上的数大于 \(j\) 位置上的数的情况,把 \(dp_{i, j}\) 赋成 \(1\) 即可。接下来我们考虑操作一次会使 \(dp\) 怎样变化。假设我们这一次操作所交换的两个位置是 \(x\) 和 \(y\),\(x\) 位置上的数为 \(a_x\),\(y\) 位置上的数为 \(a_y\)。
首先列一个非常显然的式子:
\[\begin{aligned} dp_{x, t} &= 1 - dp_{t, x}\\ dp_{x, t} &= 1 - dp_{t, x} \end{aligned} \]有了这个式子,我们就能容易地想到 \(dp_{x, y}\) 和 \(dp_{y, x}\) 的转化。首先我们想到,\(a_x\) 和 \(a_y\) 的关系有两种可能,相等或不相等。如果不相等,交换 \(x\) 和 \(y\) 以后,\(dp_{x, y}\) 和 \(dp_{y, x}\) 都会被直接赋值成 \(0.5\),这很好理解。如果相等,那么两者都是 \(0\)。又因为上面的那两个式子,所以我们可以直接简写成 \(\displaystyle dp_{x, y} = dp_{y, x} = \frac{dp_{x, y} + dp_{y, x}}{2}\)。
而对于其他的位置上的数,我们再设一个位置 \(t(1 \le t \le n \land t \neq x \land t \neq y)\)(\(\land\) 是“且”的意思,\(\lor\) 是“或”的意思),其位置上的数为 \(a_t\),那么一定有:
\[\begin{aligned} dp_{x, t} &= 1 - dp_{t, x}\\ dp_{x, t} &= 1 - dp_{t, x}\\ dp_{t, x} = dp_{t, y} &= \frac{dp_{t, x} + dp_{t, y}}{2}\\ dp_{x, t} = dp_{y, t} &= \frac{dp_{x, t} + dp_{y, t}}{2} \end{aligned} \]前两个式子显然成立(但是好像这个题没用上),我们考虑怎么去证明后两个式子。其实也很好证,因为如果两者换位成功,则 \(dp_{t, x} = dp_{t, y}\),否则不变。而我们又知道,成功的几率为 \(0.5\),那么 \(\displaystyle dp_{t, x} = \frac{dp_{t, x} + dp_{t, y}}{2}\) 显然成立。其他的同理。
既然有了这些结论,那么我们就可以对于每一次操作,枚举一遍 \(t\),最终求得某一情况下期望的逆序对数。
得到这个我们就能轻松地算出最终答案。因为一共有 \(2^q\) 种情况,所以我们用刚刚得到的再乘上 \(2^q\) 就是最终答案。
需要注意的是,计算过程中的除法要写成逆元。
时间复杂度: \(O(n^2 + qn)\)。
代码
//https://www.luogu.com.cn/problem/AT_agc030_d AT_agc030_d [AGC030D] Inversion Sum
#include <bits/stdc++.h>
#define M 3005
#define int long long
#define mod 1000000007
using namespace std;
inline int read() {
int x = 0, s = 1;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-')
s = -s;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
x = (x << 3) + (x << 1) + ch - '0';
ch = getchar();
}
return x * s;
}
void write(int x) {
if(x < 0) {
x = ~(x - 1);
putchar('-');
}
if(x > 9)
write(x / 10);
putchar(x % 10 + 48);
}
int n, q, a[M], dp[M][M], ans;
const int inv2 = 500000004;
inline int quick_pow(int base, int P) {
int ji = 1;
while(P) {
if(P & 1)
ji = (ji * base) % mod;
P >>= 1;
base = (base * base) % mod;
}
return ji;
}
signed main() {
n = read();
q = read();
for(int i = 1; i <= n; ++ i)
a[i] = read();
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j)
if(a[i] > a[j])
dp[i][j] = 1;
for(int i = 1; i <= q; ++ i) {
int x = read(), y = read();
if(x == y)
continue;
for(int j = 1; j <= n; ++ j) {
if(j != x && j != y) {
dp[y][j] = dp[x][j] = (dp[x][j] + dp[y][j]) % mod * inv2 % mod;
dp[j][y] = dp[j][x] = (dp[j][x] + dp[j][y]) % mod * inv2 % mod;
}
}
dp[x][y] = dp[y][x] = (dp[x][y] + dp[y][x]) % mod * inv2 % mod;
}
for(int i = 1; i <= n; ++ i)
for(int j = i + 1; j <= n; ++ j)
ans = (ans + dp[i][j]) % mod;
ans = ans * quick_pow(2, q) % mod;
write(ans);
}