给定一棵树,给每个点填 \(0\) 或 \(1\),给每条边填 \(\text{AND}\) 或 \(\text{OR}\),求有多少种填法满足存在一种缩边的顺序,使得每次把一条边的两个端点缩成一个点,权为原端点与边的运算值,最终点的权为 \(1\)。答案对 \(998244353\) 取模。
\(n \leq 10^5\)。
不妨钦定 \(1\) 为根。考虑一个特殊的局部:如果在操作过程中,出现了某个叶子的值为 \(1\) 且连接它的边为 \(\text{OR}\),那么只需要将其放在最后操作,就一定能够得到 \(1\)。我们用有序对 \((1, \text{OR})\) 来描述这个叶子的状态。
接着我们考虑出现的叶子的其他状态:
-
\((1,\text{AND})\) 和 \((0,\text{OR})\) :显然这个叶子对答案没有影响,因此原树合法的充要条件是上面的部分合法。
-
\((0,\text{AND})\):由于操作之后其父亲的权会变为 \(0\),我们不妨在这样的叶子出现时就立刻对其进行操作,这样能够减小它对答案的负面影响。因此原树合法的充要条件还是上面的部分合法。
也就是说,如果不存在某种操作方式使得能够出现 \((1, \text{OR})\) 这样的叶子,那么依次删掉叶子一定不劣。此时合法的充要条件是:根节点权为 \(1\),且其儿子形成的叶子的状态都是 \((1,\text{AND})\) 或 \((0,\text{OR})\)。这是相对容易计数的。
考虑容斥,改为计算不合法的方案数。如果我们能求出不存在 \((1,\text{OR})\) 的方案数 \(f\) 和其中合法的方案数 \(g\),那么 \(f-g\) 就是不合法的方案数。考虑树形 DP,设 \(f_u\) 表示 \(u\) 的子树中不存在 \((1,\text{OR})\) 的方案数,\(g_u\) 表示 \(u\) 的子树中不存在 \((1,\text{OR})\) 且合法的方案数。根据上面的分析,初始时 \(f_u = 2,g_u = 1\)。
对于 \(f\),有转移 \(f_u \gets f_u \times \prod \limits_{v \in \text{son}_u} (2 \times f_v - g_v )\)。其中 \(2 \times f_v\) 是因为 \((u,v)\) 可以是 \(\text{OR}\) 或 \(\text{AND}\),减去 \(g_v\) 是为了排除 \(v\) 子树合法且 \((u,v)\) 为 \(\text{OR}\),即出现 \((1,\text{OR})\) 的情况。
对于 \(g\),有转移 \(g_u \gets g_u \times \prod \limits_{v \in \text{son}_u} f_v\)。注意无论 \(v\) 子树是否合法,由于其成为叶子时状态只能是 \((1, \text{AND})\) 或 \((0,\text{OR})\),因此 \((u,v)\) 实际上是确定的,因此 \(f_v\) 的系数为 \(1\)。
时间复杂度 \(O(n)\)。
code
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5, mod = 998244353;
int n, f[N], g[N];
vector <int> e[N];
int ksm(int a, int b) {
int ret = 1;
while (b) {
if (b & 1) ret = 1LL * ret * a % mod;
a = 1LL * a * a % mod, b >>= 1;
}
return ret;
}
void dfs(int u, int ff) {
f[u] = 2, g[u] = 1;
for (auto v : e[u]) {
if (v == ff) {
continue;
}
dfs(v, u);
f[u] = 1LL * f[u] * (1LL * f[v] * 2 % mod + mod - g[v]) % mod;
g[u] = 1LL * g[u] * f[v] % mod;
}
}
int main() {
ios :: sync_with_stdio(false);
cin.tie(nullptr);
cin >> n;
for (int i = 1, x, y; i < n; i++) {
cin >> x >> y;
e[x].push_back(y);
e[y].push_back(x);
}
dfs(1, 0);
int ans = (1LL * ksm(2, 2 * n - 1) + mod - f[1] + g[1]) % mod;
cout << ans << "\n";
return 0;
}