好。
首先考虑怎么计算方案数。我们考虑按照 \(a_i\) 从小往大选,设排序后的下标为 \(b_i\),那么容易得出方案数为:
\[s = \prod_{i=1}^n (a_{b_i} - i + 1) \]我们设 \(c_i = a_{b_i} - i + 1\),这代表着某个数的选择方案数。
然后考虑经典拆贡献,枚举每一对 \((i, j)\),求 \(p_i > p_j\) 的方案数,这样累加起来就是答案。
首先假设 \(a_i < a_j\),我们直接枚举 \(p_i, p_j\) 的选择方案,容易得出为 \(\dbinom{c_i}{2} = \dfrac{c_i(c_i - 1)}{2}\)。然后考虑这会导致 \(a_i\) 与 \(a_j\) 之间的数的选择方案减少 \(1\),其他数的选择方案不变,那么可以写出式子:
\[\begin{aligned} f(i, j) &= \frac{c_i (c_i - 1)}{2} \times \frac{s}{c_i c_j} \times \prod_{a_i < a_k < a_j}\frac{c_k - 1}{c_k}\\ &=\frac{s (c_i - 1)}{2 c_j} \prod_{a_i < a_k < a_j}\frac{c_k - 1}{c_k} \end{aligned} \]假如说 \(a_i > a_j\),那么我们可以先对称的求出 \(p_i > p_j\) 的方案数,然后总方案数减去它即可,即:
\[f(i, j) =s - \frac{s (c_i - 1)}{2 c_j} \prod_{a_i < a_k < a_j}\frac{c_k - 1}{c_k} \]我们考虑维护这个东西。我们按照 \(a_i\) 从小往大的顺序依次计算答案,然后以 \(i\) 为下标建一棵线段树,这样我们就可以将 \(i < j\) 与 \(i > j\) 的分开来计算。对于后者,我们还需要统计有多少个数,可以拿树状数组维护。这个式子容易通过区间乘,区间求和计算。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 1000000007;
int qpow(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = 1ll * ans * a % P;
a = 1ll * a * a % P;
b >>= 1;
}
return ans;
}
int n, a[MAXN], b[MAXN], c[MAXN];
struct SegmentTree {
struct Node {
int val, tag;
Node() : val(0), tag(1) {}
} t[MAXN << 2];
#define lc (i << 1)
#define rc (i << 1 | 1)
void tag(int i, int v) {
t[i].tag = 1ll * t[i].tag * v % P;
t[i].val = 1ll * t[i].val * v % P;
}
void pushDown(int i) { tag(lc, t[i].tag), tag(rc, t[i].tag), t[i].tag = 1; }
void mul(int a, int b, int v, int i = 1, int l = 1, int r = n) {
if (a > b) return;
if (a <= l && r <= b) return tag(i, v);
int mid = (l + r) >> 1;
pushDown(i);
if (a <= mid) mul(a, b, v, lc, l, mid);
if (b > mid) mul(a, b, v, rc, mid + 1, r);
t[i].val = (t[lc].val + t[rc].val) % P;
}
void set(int d, int v, int i = 1, int l = 1, int r = n) {
if (l == r) {
t[i].val = v;
return;
}
int mid = (l + r) >> 1;
pushDown(i);
if (d <= mid) set(d, v, lc, l, mid);
else set(d, v, rc, mid + 1, r);
t[i].val = (t[lc].val + t[rc].val) % P;
}
int query(int a, int b, int i = 1, int l = 1, int r = n) {
if (a > b) return 0;
if (a <= l && r <= b) return t[i].val;
int mid = (l + r) >> 1;
pushDown(i);
if (b <= mid) return query(a, b, lc, l, mid);
if (a > mid) return query(a, b, rc, mid + 1, r);
return (query(a, b, lc, l, mid) + query(a, b, rc, mid + 1, r)) % P;
}
} st;
struct BinaryIndexTree {
int a[MAXN];
#define lowbit(x) (x & (-x))
void add(int d, int v) {
while (d <= n) {
a[d] += v;
d += lowbit(d);
}
}
int query(int d) {
if (!d) return 0;
int ret = 0;
while (d) {
ret += a[d];
d -= lowbit(d);
}
return ret;
}
} bit;
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
b[i] = i;
}
sort(b + 1, b + 1 + n, [&](int x, int y) { return a[x] < a[y]; });
int s = 1;
for (int i = 1; i <= n; i++) {
c[i] = a[b[i]] - i + 1;
if (c[i] <= 0) {
printf("0\n");
return 0;
}
s = 1ll * s * c[i] % P;
}
int ans = 0;
for (int i = 1; i <= n; i++) {
ans = (ans + 1ll * st.query(1, b[i] - 1) * qpow(2 * c[i], P - 2)) % P;
ans = (ans - 1ll * st.query(b[i] + 1, n) * qpow(2 * c[i], P - 2) % P + P) % P;
ans = (ans + 1ll * (i - bit.query(b[i]) - 1) * s) % P;
st.mul(1, n, 1ll * (c[i] - 1) * qpow(c[i], P - 2) % P);
st.set(b[i], 1ll * s * (c[i] - 1) % P);
bit.add(b[i], 1);
}
printf("%d\n", ans);
return 0;
}
标签:frac,val,int,mid,解题,MAXN,AGC023E,Inversions,return
From: https://www.cnblogs.com/apjifengc/p/17416195.html