这题做得真艰难。
先考虑第一问。
一眼看上去并没有什么复杂度脱离值域的办法。考虑枚举一个 \(x\) 表示最小值,那么点权只能在 \([x, x + K]\) 中。
点权最小值不一定为 \(x\),减去点权在 \([x + 1, x + K]\) 中的答案即可,也就是把 \(K\) 减 \(1\) 后再算一遍。
那么可以得出每个点权的取值范围为 \([\max(x, l_i), \min(x + K, r_i)]\)。
设第 \(i\) 个点有 \(a_u\) 种取值。答案就是树上所有简单路径的 \(a_u\) 乘积之和。
那么很容易做一个 dp,可以算出 \(f_u\) 表示 \(u\) 子树内延伸到 \(u\) 的路径中,每条路径的取值之和。
合并儿子时有 \(f_u \gets a_u \times f_v\)。
然后考虑所有 \(\text{LCA}\) 为 \(u\) 的点对的贡献,相当于选 \(v_1 \in son_u, v_2 \in son_u, v1 \ne v2\),能产生 \(f_{v_1} \times f_{v_2} \times a_u\) 的贡献。
我们发现,\([\max(x, l_i), \min(x + K, r_i)]\) 只能组合出来 \(4\) 种取值范围:\([x, x + K], [x, r_i], [l_i, x + K], [l_i, r_i]\),并且只和 \(x\) 和 \(l_i - K, l_i, r_i - K, r_i\) 的大小关系有关。
所以我们有 \(O(n)\) 个断点,在每相邻两个断点组成的左开右闭区间 \([L, R)\) 内,\(a_u\) 可以表示成关于 \(x\) 的至多一次项式 \(Ax + B\)。
我们现在希望计算树上所有简单路径的多项式 \(a_u\) 的乘积之和,可以使用上述的 dp 做法求出,多项式乘法暴力就行。
设我们最后求出来的树上所有简单路径的多项式 \(a_u\) 的乘积之和为 \(\sum\limits_{i = 0}^n A_i x^i\)。答案就是 \(\sum\limits_{i = 0}^n A_i \sum\limits_{x = L}^{R - 1} x^i\)。也就是说要快速算 \(\sum\limits_{x = 0}^N x^M\)。
这就是 CF622F The Sum of the k-th Powers。这个和就是一个 \(M + 1\) 次多项式,直接拉格朗日插值即可。注意讨论 \(L, R\) 中有负数的情况。
然后考虑第二问。
仍然先考虑暴力。设第 \(u\) 个点所有取值之和为 \(b_u\),那么对于树上一条简单路径 \(p_1, p_2, \ldots, p_k\),我们希望求 \(\sum\limits_{i = 1}^k b_{p_i} \prod\limits_{j \ne i} a_{p_j}\)。
这个也可以 dp 求出。设 \(f_{u, 0/1}\) 表示一条从 \(u\) 子树内延伸到 \(u\) 的路径,中间是否有一个点乘的是 \(b_i\) 而不是 \(a_i\)。
合并儿子时有转移 \(f_{u, 1} \gets a_u f_{v, 1} + b_u f_{v, 0}\)。
然后仍然考虑所有 \(\text{LCA}\) 为 \(u\) 的点对的贡献,相当于选 \(v_1 \in son_u, v_2 \in son_u, v1 \ne v2\),能产生 \(f_{v_1, 0} \times f_{v_2, 0} \times b_u + f_{v_1, 1} \times f_{v_2, 0} \times a_u + f_{v_1, 0} \times f_{v_2, 1} \times a_u\) 的贡献。
然后也可以像第一问一样,先分段,然后把 \(b_u\) 表示成关于 \(x\) 的至多二次项式,dp 后拉格朗日插值算 \(\sum\limits_{x = 0}^N x^M\) 解决。
时间复杂度 \(O(n^3)\),但是好像跑得比大多数做法都快?
code
// Problem: P8290 [省选联考 2022] 填树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P8290
// Memory Limit: 512 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 2020;
const ll mod = 1000000007;
const ll inv2 = (mod + 1) / 2;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, a[maxn], b[maxn], lsh[maxn], tot, pw[maxn][maxn], fac[maxn], ifac[maxn];
vector<int> G[maxn];
typedef vector<ll> poly;
inline poly operator + (const poly &a, const poly &b) {
int n = (int)a.size(), m = (int)b.size();
poly res(max(n, m));
for (int i = 0; i < max(n, m); ++i) {
if (i < n) {
res[i] += a[i];
}
if (i < m) {
res[i] += b[i];
}
(res[i] >= mod) && (res[i] -= mod);
}
return res;
}
inline poly operator * (const poly &a, const poly &b) {
if (a.empty() || b.empty()) {
return poly();
}
int n = (int)a.size() - 1, m = (int)b.size() - 1;
poly res(n + m + 1);
for (int i = 0; i <= n; ++i) {
for (int j = 0; j <= m; ++j) {
res[i + j] = (res[i + j] + a[i] * b[j]) % mod;
}
}
return res;
}
ll pre[maxn], suf[maxn];
// 0 ^ m + 1 ^ m + 2 ^ m + ... + n ^ m
inline ll calc(ll n, ll m) {
if (n <= 0) {
return 0;
}
if (n <= m + 5) {
ll ans = 0;
for (int i = 0; i <= n; ++i) {
ans = (ans + pw[i][m]) % mod;
}
return ans;
}
pre[0] = 1;
for (int i = 1; i <= m + 2; ++i) {
pre[i] = pre[i - 1] * (n - i) % mod;
}
suf[m + 3] = 1;
for (int i = m + 2; i; --i) {
suf[i] = suf[i + 1] * (n - i) % mod;
}
ll s = 0, ans = 0;
for (int i = 1; i <= m + 2; ++i) {
s = (s + pw[i][m]) % mod;
ll coef = pre[i - 1] * suf[i + 1] % mod;
coef = coef * ifac[i - 1] % mod * ifac[m + 2 - i] % mod;
if ((m + 2 - i) & 1) {
coef = (mod - coef) % mod;
}
ans = (ans + coef * s) % mod;
}
return ans;
}
// l ^ m + (l + 1) ^ m + (l + 2) ^ m + ... + r ^ m
inline ll calc(ll l, ll r, ll m) {
if (!m) {
return (r - l + 1) % mod;
}
if (l <= 0 && r <= 0) {
return (mod + ((m & 1) ? (-calc(-l, m) + calc(-r - 1, m)) : (calc(-l, m) - calc(-r - 1, m)))) % mod;
} else if (l <= 0 && r > 0) {
return (mod + mod + ((m & 1) ? -calc(-l, m) : calc(-l, m)) + calc(r, m)) % mod;
} else {
return (calc(r, m) - calc(l - 1, m) + mod) % mod;
}
}
poly A[maxn], B[maxn], P, Q, F[maxn][2];
void dfs(int u, int fa) {
F[u][0] = A[u];
F[u][1] = B[u];
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs(v, u);
F[u][0] = F[u][0] + A[u] * F[v][0];
F[u][1] = F[u][1] + A[u] * F[v][1] + B[u] * F[v][0];
}
P = P + F[u][0];
Q = Q + F[u][1];
}
void dfs2(int u, int fa) {
poly a, b;
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs2(v, u);
P = P + F[v][0] * A[u] * a;
Q = Q + F[v][0] * A[u] * b + F[v][1] * A[u] * a + F[v][0] * B[u] * a;
a = a + F[v][0];
b = b + F[v][1];
}
}
inline pii calc(ll m) {
tot = 0;
for (int i = 1; i <= n; ++i) {
lsh[++tot] = a[i];
lsh[++tot] = a[i] - m;
lsh[++tot] = b[i] - m;
lsh[++tot] = b[i] + 1;
}
sort(lsh + 1, lsh + tot + 1);
tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
ll ans1 = 0, ans2 = 0;
for (int _ = 1; _ < tot; ++_) {
ll L = lsh[_], R = lsh[_ + 1];
for (int i = 1; i <= n; ++i) {
A[i] = B[i] = poly();
ll l = a[i], r = b[i];
if (max(l, L) > min(r, L + m)) {
continue;
}
if (L >= l && L >= r - m) {
A[i] = poly(2);
A[i][0] = r + 1;
A[i][1] = mod - 1;
B[i] = poly(3);
B[i][0] = (r * r + r) % mod * inv2 % mod;
B[i][1] = inv2;
B[i][2] = (mod - inv2) % mod;
} else if (L >= l && L < r - m) {
A[i] = poly(1);
A[i][0] = m + 1;
B[i] = poly(2);
B[i][0] = m * (m + 1) % mod * inv2 % mod;
B[i][1] = m + 1;
} else if (L < l && L >= r - m) {
A[i] = poly(1);
A[i][0] = r - l + 1;
B[i] = poly(1);
B[i][0] = calc(l, r, 1);
} else if (L < l && L < r - m) {
A[i] = poly(2);
A[i][0] = (m - l + 1 + mod) % mod;
A[i][1] = 1;
B[i] = poly(3);
ll p = (l + m) % mod, q = (m - l + 1 + mod) % mod;
B[i][0] = p * q % mod * inv2 % mod;
B[i][1] = inv2 * (p + q) % mod;
B[i][2] = inv2;
}
}
P = Q = poly();
dfs(1, -1);
dfs2(1, -1);
for (int i = 0; i < (int)P.size(); ++i) {
ans1 = (ans1 + P[i] * calc(L, R - 1, i)) % mod;
}
for (int i = 0; i < (int)Q.size(); ++i) {
ans2 = (ans2 + Q[i] * calc(L, R - 1, i)) % mod;
}
// printf("%lld %lld %lld\n", L, R, ans2);
// for (int i = 1; i <= n; ++i) {
// printf("i: %d\n", i);
// for (ll x : B[i]) {
// printf("%lld ", x);
// }
// putchar('\n');
// }
}
return mkp(ans1, ans2);
}
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%lld%lld", &a[i], &b[i]);
}
int up = n * 2 + 5;
for (int i = 0; i <= up; ++i) {
pw[i][0] = 1;
for (int j = 1; j <= up; ++j) {
pw[i][j] = pw[i][j - 1] * i % mod;
}
}
fac[0] = 1;
for (int i = 1; i <= up; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[up] = qpow(fac[up], mod - 2);
for (int i = up - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
pii x = calc(m), y = calc(m - 1);
printf("%lld\n%lld\n", (x.fst - y.fst + mod) % mod, (x.scd - y.scd + mod) % mod);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}