educational 的。另一道类似的题是 [ABC269Ex] Antichain(但是我还没写)。
考虑令 \(b_u = a_u - \sum\limits_{v \in son_u} a_v\)。那么 \(\sum\limits_{i = 1}^n b_i = a_1 = x\),且 \(\forall i \in [1, n], b_i \ge 0\)。所以最后连通块内有 \(y\) 个点,那么贡献系数为 \(\binom{x + y - 1}{y - 1}\)。所以转为计算包含 \(1\) 的连通块有 \(i\) 个点的方案数。
考虑经典树形背包,设 \(f_{u, i}\) 为 \(u\) 子树内包含 \(u\) 的连通块点数为 \(i\) 的方案数。特别地 \(f_{u, 0} = 1\) 表示转移上去断掉这条边。记 \(L, R\) 分别为 \(u\) 的左右儿子,有:
\[f_{u, i + j + 1} \gets f_{u, i} f_{v, j} \]写成生成函数的形式,就是:
\[F_u(x) = x F_L(x) F_R(x) + 1 \]你发现这玩意直接做优化不了,因为 \(F_u(x)\) 的次数是 \(sz_u\) 级别的。这启发我们想到重链剖分。
具体地,考虑在重链顶处计算重链顶的多项式 \(F_u(x)\)。设重链上的点从浅到深依次为 \(a_1, a_2, \ldots, a_n\),\(a_i\) 的轻儿子为 \(b_i\)(为了方便若没有轻儿子则 \(b_i = 0\),\(F_0(x) = 1\)),我们有:
\[F_{a_n}(x) = x F_{b_n}(x) + 1 \]\[F_{a_{n - 1}}(x) = x F_{a_n}(x) F_{b_{n - 1}}(x) = x F_{b_{n - 1}}(x) (x F_{b_n}(x) + 1) + 1 \]以此类推,设 \(G_i = x F_{b_i}(x)\),那么 \(F_u(x) = G_1 (G_2(\ldots (G_n + 1)) \ldots + 1) + 1 = (\sum\limits_{i = 1}^n \prod\limits_{j = 1}^i G_j) + 1\)。
这个东西可以分治 NTT 计算。具体就是每次递归 \([l, r]\) 返回一个二元组 \((\sum\limits_{i = l}^r \prod\limits_{j = l}^i G_j, \prod\limits_{i = l}^r G_i)\),那么 \([l, mid]\) 和 \([mid + 1, r]\) 的信息就可以合并了。
考虑每次计算的 \(G_i\) 次数之和为一棵树所有轻儿子的子树大小 \(= O(n \log n)\),分治 NTT 再带两个 \(\log\),总时间复杂度就是 \(O(n \log^3 n)\)。实际运行效率还可以。
code
// Problem: F. Tree
// Contest: Codeforces - Codeforces Round 499 (Div. 1)
// URL: https://codeforces.com/contest/1010/problem/F
// Memory Limit: 256 MB
// Time Limit: 7000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 500100;
const ll mod = 998244353, gg = 3;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
typedef vector<ll> poly;
int r[maxn];
inline poly NTT(poly a, int op) {
int n = (int)a.size();
for (int i = 0; i < n; ++i) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < n; k <<= 1) {
ll wn = qpow(op == 1 ? gg : qpow(gg, mod - 2), (mod - 1) / (k << 1));
for (int i = 0; i < n; i += (k << 1)) {
ll w = 1;
for (int j = 0; j < k; ++j, w = w * wn % mod) {
ll x = a[i + j], y = w * a[i + j + k] % mod;
a[i + j] = (x + y) % mod;
a[i + j + k] = (x - y + mod) % mod;
}
}
}
if (op == -1) {
ll inv = qpow(n, mod - 2);
for (int i = 0; i < n; ++i) {
a[i] = a[i] * inv % mod;
}
}
return a;
}
inline poly operator * (poly a, poly b) {
a = NTT(a, 1);
b = NTT(b, 1);
int n = (int)a.size();
for (int i = 0; i < n; ++i) {
a[i] = a[i] * b[i] % mod;
}
a = NTT(a, -1);
return a;
}
inline poly operator + (poly a, poly b) {
int n = (int)a.size() - 1, m = (int)b.size() - 1;
poly res(max(n, m) + 1);
for (int i = 0; i <= n; ++i) {
res[i] = (res[i] + a[i]) % mod;
}
for (int i = 0; i <= m; ++i) {
res[i] = (res[i] + b[i]) % mod;
}
return res;
}
inline poly mul(poly a, poly b) {
int n = (int)a.size() - 1, m = (int)b.size() - 1, k = 0;
while ((1 << k) < n + m + 1) {
++k;
}
for (int i = 1; i < (1 << k); ++i) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
}
poly A(1 << k), B(1 << k);
for (int i = 0; i <= n; ++i) {
A[i] = a[i];
}
for (int i = 0; i <= m; ++i) {
B[i] = b[i];
}
poly res = A * B;
res.resize(n + m + 1);
return res;
}
ll n, m;
vector<int> G[maxn];
int sz[maxn], son[maxn], b[maxn], top[maxn];
poly F[maxn], a[maxn];
void dfs(int u, int fa) {
sz[u] = 1;
int mx = -1;
vector<int> S;
for (int v : G[u]) {
if (v == fa) {
continue;
}
S.pb(v);
dfs(v, u);
sz[u] += sz[v];
if (sz[v] > mx) {
son[u] = v;
mx = sz[v];
}
}
for (int v : S) {
if (son[u] != v) {
b[u] = v;
}
}
}
void dfs2(int u, int tp) {
top[u] = tp;
if (!son[u]) {
return;
}
dfs2(son[u], tp);
for (int v : G[u]) {
if (!top[v]) {
dfs2(v, v);
}
}
}
pair<poly, poly> calc(int l, int r) {
if (l == r) {
return mkp(a[l], a[l]);
}
int mid = (l + r) >> 1;
auto L = calc(l, mid), R = calc(mid + 1, r);
return mkp(L.fst + mul(L.scd, R.fst), mul(L.scd, R.scd));
}
void dfs3(int u, int fa) {
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs3(v, u);
}
if (u == top[u]) {
int K = 0;
for (int v = u; v; v = son[v]) {
a[++K] = poly(1, 0);
for (ll x : F[b[v]]) {
a[K].pb(x);
}
}
auto res = calc(1, K);
F[u] = res.fst;
F[u][0] = 1;
}
}
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
dfs(1, -1);
dfs2(1, 1);
F[0].pb(1);
dfs3(1, -1);
ll ans = 0, fac = 1, mul = 1;
for (int i = 1; i <= n; ++i) {
ans = (ans + mul * qpow(fac, mod - 2) % mod * F[1][i]) % mod;
fac = fac * i % mod;
mul = mul * ((m + i) % mod) % mod;
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}