首先将 \(a\) 从小到大排序,设 \(p_i\) 为排序后的 \(a_i\) 位于原序列第 \(p_i\) 个位置,\(x_i\) 为要填的排列的第 \(i\) 个数。
设 \(A = \prod\limits_{i = 1}^n (a_i - i + 1)\),则 \(A\) 为排列的总方案数(考虑按 \(a_i\) 从小到大填即得)。
套路地,统计每对 \((i, j), i < j\) 造成的逆序对贡献。设 \(f(i, j)\) 为 \((p_i, p_j)\) 在排列中构成逆序对的方案。
若 \(p_i < p_j\),则 \(x_i > x_j\) 有:
\[\begin{aligned}f(i, j) & = \frac{(a_i - i + 1)(a_i - i)}{2} \times \frac{A}{(a_i - i + 1)(a_j - j + 1)} \times \prod\limits_{k = i + 1}^{j - 1} \frac{a_k - k}{a_k - k + 1} \\ & = \frac{(a_i - i)A}{2(a_j - j + 1)} \times \prod\limits_{k = i + 1}^{j - 1} \frac{a_k - k}{a_k - k + 1}\end{aligned} \]考虑在 \([1, a_i]\) 中选出两个数分配给 \(x_i\) 和 \(x_j\),在总方案数中去除 \(x_i, x_j\) 造成的贡献,对于 \(k \in [i + 1, j - 1]\),\(x_k\) 能选的数少了 \(1\) 个,故减去。然后约分化简得上式。
若 \(p_i > p_j\),我们计算 \((i, j)\) 构成顺序对的方案数再减去,有:
\[f'(i, j) = A - \frac{(a_i - i)A}{2(a_j - j + 1)} \times \prod\limits_{k = i + 1}^{j - 1} \frac{a_k - k}{a_k - k + 1} \]看到式子有个 product 很不顺眼,考虑设 \(b_i = \prod\limits_{j = 1}^i \frac{a_j - j}{a_j - j + 1}\),\(c_i = \frac{1}{b_i} = \prod\limits_{j = 1}^i \frac{a_j - j + 1}{a_j - j}\)。那么:
\[f(i, j) = A \times ((a_i - i) \times c_i) \times \frac{b_{j - 1}}{a_j - j + 1} \]这是一个二维偏序的形式(\(i < j \land p_i < p_j\))。树状数组维护 \((a_i - i) \times c_i\) 的和,在 \(j\) 处乘上 \(\frac{b_{j - 1}}{a_j - j + 1}\) 并加入最终答案即可。
对于 \(f'(i, j)\),我们还需要计算 \(i < j \land p_i > p_j\) 的数量,可以再开一个树状数组。
但是这样有个问题,可能存在 \(a_i - i = 0\),因此可能存在 \(b_i = 0\)。为了不影响前缀积,考虑强制把 \(a_i - i = 0\) 的位置当作 \(1\) 乘进去,然后规定计算 \(f(i, j)\) 时,若 \(\exists k \in [i + 1, j - 1], a_k - k = 0\),就使 \(f(i, j) = 0\)。那我们可以把 \(a_k - k = 0\) 的位置看作一个挡板,把序列分成若干个块,每次只计算块内互相贡献的答案即可。
目前是 AtCoder 最优解。
code
// Problem: E - Inversions
// Contest: AtCoder - AtCoder Grand Contest 023
// URL: https://atcoder.jp/contests/agc023/tasks/agc023_e
// Memory Limit: 256 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int read() {
char c = getchar();
int x = 0;
for (; !isdigit(c); c = getchar()) ;
for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
return x;
}
const int maxn = 200100;
const ll mod = 1000000007;
const ll inv2 = (mod + 1) / 2;
ll n, inv[maxn], b[maxn], c[maxn], d[maxn], f[maxn], g[maxn];
struct node {
ll x, i;
} a[maxn];
inline void upd(ll &x, ll y) {
((x += y) >= mod) && (x -= mod);
}
struct BIT {
ll c[maxn];
inline void update(int x, ll d) {
for (int i = x; i <= n; i += (i & (-i))) {
upd(c[i], d);
}
}
inline ll query(int x) {
ll res = 0;
for (int i = x; i; i -= (i & (-i))) {
upd(res, c[i]);
}
return res;
}
inline ll query(int l, int r) {
return (query(r) - query(l - 1) + mod) % mod;
}
} t1, t2;
void solve() {
n = read();
inv[0] = inv[1] = 1;
for (int i = 2; i <= n; ++i) {
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
for (int i = 1; i <= n; ++i) {
a[i].x = read();
a[i].i = i;
}
sort(a + 1, a + n + 1, [&](const node &a, const node &b) {
return a.x < b.x;
});
ll A = 1;
for (int i = 1; i <= n; ++i) {
A = A * (a[i].x - i + 1) % mod;
}
if (!A) {
puts("0");
return;
}
ll B = A * inv2 % mod;
b[0] = c[0] = 1;
for (int i = 1; i <= n; ++i) {
b[i] = b[i - 1] * max(a[i].x - i, 1LL) % mod * inv[a[i].x - i + 1] % mod;
c[i] = c[i - 1] * inv[a[i].x - i] % mod * (a[i].x - i + 1) % mod;
f[i] = b[i - 1] * inv[a[i].x - i + 1] % mod;
g[i] = (a[i].x - i) * c[i] % mod;
}
ll ans = 0;
for (int i = 1, j = 1; i <= n; ++i) {
ans = (ans + B * t1.query(a[i].i - 1) % mod * f[i] % mod) % mod;
ll res = B * t1.query(a[i].i + 1, n) % mod * f[i] % mod;
ans = (ans + A * t2.query(a[i].i + 1, n) % mod - res + mod) % mod;
t1.update(a[i].i, g[i]);
t2.update(a[i].i, 1);
if (a[i].x == i) {
while (j < i) {
t1.update(a[j].i, mod - g[j]);
++j;
}
}
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}