前言
顺着线段树合并的标签找到这道题,感觉可做,但一写写了整整一天/kk。
题意
给出一个以 \(1\) 为根的 \(n\) 个节点的二叉树,每个叶子上有一个权值,非叶子节点的权值有一定概率为其子节点中权较大 \(/\) 较小的那个,问根节点每个不同权值的概率,答案按一定方式输出。
思路
首先把所有叶子的权离散化。朴素的想法,对于每个节点建一棵权值线段树,权值线段树上第 \(i\) 个位置记录了该节点权值为 \(i\) 的概率,然后使用线段树合并,一直并到根得到答案。
关键在于如何合并信息。
记当前节点为 \(x\),如果 \(x\) 只有一个儿子,直接把儿子上的那棵线段树挂到父亲上。
如果有两个儿子,记它们分别为 \(l\) 和 \(r\),并设 \(l_d\) 和 \(r_d\) 分别表示 \(l\) 节点和 \(r\) 节点上权值 \(d\) 出现的概率。设 \(x\) 的权值有 \(k\) 的概率为子结点的权值的最大值。由于题目保证权值互不相同,那么整理可得:
\[x_d = (k * \sum\limits_{i=1}^{d-1}l_i+(1-k) * \sum\limits_{i=d+1}^{Max}l_i) * r_i + (k * \sum\limits_{i=1}^{d-1}r_i+(1-k) * \sum\limits_{i=d+1}^{Max}r_i) * l_i \]由于一个节点上的所有概率和为 \(1\),所以求出前缀和就能求出后缀和,而根据线段树先遍历左儿子后遍历右儿子的顺序,访问到线段树上当前的单个节点时一定已经访问了它左边的所有点,所以前缀和也可以在合并的时候顺便求出。
然后一开始就直接这样写了,然后发现每一次合并的时候要把值域里所有不为零的点全都访问到单点,导致要建巨大多的新点。一开始没有考虑这一点,疯狂\(\text{RE}\) + \(\text{MLE}\)。甚至考虑过把数组回滚掉,然后TLE了
所以我们考虑优化。
当前的瓶颈在于:对于正在合并的两颗线段树,就算当前这个区间只有一个线段树有节点,依然要继续往下递归以维护单点信息。但可以发现除了根节点以外,其余节点的单点信息都是冗余的,如果能在以上这种情况的时候能像正常的线段树合并一样直接返回,就可以保证复杂度。
记正在合并的两颗线段树中,当前这个区间为 \([l, r]\) ,有节点的线段树为 \(\text{A}\) ,无节点的为 \(\text{B}\) 。那么 \(\text{B}\) 所对应的树上的那个节点在此权值区间里出现概率均为 \(0\)。因为我们维护的是区间概率和,那么根据上式,我们可以直接更改 \(\text{A}\) 在 \([l, r]\) 的概率和:
\[A_{[l, r]} = A_{[l, r]} * (k * \sum\limits_{i=1}^{l-1}B_i+(1-k) * \sum\limits_{i=r+1}^{Max}B_i) \]由于后面那部分是定值,所以同时给这个区间打上一个区间乘的 \(\text{tag}\) 。在下传的时候,由于 \(\text{B}\) 在 \([l,r]\) 的值均为 \(0\) ,所以缩小区间范围并不会影响 \(\text{tag}\) 的值。
时间复杂度 \(O(n \log n)\)
pushdown写错了,调半天/fn
Code
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define l(x) x<<1
#define r(x) x<<1|1
const ll SIZE = 300005;
const ll mod = 998244353;
ll n, tot, totv;
ll head[SIZE], ver[SIZE*2], nxt[SIZE*2];
ll aa[SIZE], ds[SIZE], bb[SIZE], n1;
ll rt[SIZE];
ll ans, cc, cnta, cntb;
inline ll rd(){
ll f = 1, x = 0;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
x = (x << 1) + (x << 3) + (ch ^ 48);
ch = getchar();
}
return f*x;
}
struct Tree{
ll l, r;
ll P, tag;
Tree(){
P = l = r = 0;
tag = 1;
}
}t[SIZE*40];
ll power(ll x, ll y){
ll jl = 1;
while(y){
if(y & 1) jl = (jl * x) % mod;
x = (x * x) % mod;
y >>= 1;
}
return jl;
}
void add(ll x, ll y){
ver[++tot] = y, nxt[tot] = head[x];
head[x] = tot;
}
ll get(ll x){
return lower_bound(bb+1, bb+n1+1, x) - bb;
}
void pushup(ll p){
t[p].P = (t[t[p].l].P + t[t[p].r].P) % mod;
}
void pushdown(ll p){
if(t[p].tag != 1){
t[t[p].l].P = (t[t[p].l].P * t[p].tag) % mod;
t[t[p].r].P = (t[t[p].r].P * t[p].tag) % mod;
t[t[p].l].tag = (t[t[p].l].tag * t[p].tag) % mod;
t[t[p].r].tag = (t[t[p].r].tag * t[p].tag) % mod;
t[p].tag = 1;
}
}
void update(ll &p, ll l, ll r, ll x, ll fa){
if(!p) p = ++totv;
if(l == r){
t[p].P = 10000;
return;
}
pushdown(p);
ll mid = (l + r) >> 1;
if(x <= mid) update(t[p].l, l, mid, x, fa);
else update(t[p].r, mid+1, r, x, fa);
pushup(p);
}
ll merge(ll &a, ll &b, ll l, ll r, ll k, ll fa){
if(!a && !b) return 0;
if((!a || !b) && fa != 1){
ll nxa = 10000 - cnta - t[a].P;
while(nxa < 0) nxa += mod;
while(nxa >= mod) nxa -= mod;
ll nxb = 10000 - cntb - t[b].P;
while(nxb < 0) nxb += mod;
while(nxb >= mod) nxb -= mod;
if(a){
t[a].tag = ((t[a].tag * (((cntb * k % mod) + (nxb * (10000-k) % mod)) % mod) % mod) * (cc * cc % mod)) % mod;
cnta = (cnta + t[a].P) % mod;
t[a].P = (t[a].P * t[a].tag) % mod;
}
if(b){
t[b].tag = ((t[b].tag * (((cnta * k % mod) + (nxa * (10000-k) % mod)) % mod) % mod) * (cc * cc % mod)) % mod;
cntb = (cntb + t[b].P) % mod;
t[b].P = (t[b].P * t[b].tag) % mod;
}
return a+b;
}
if(!a) a = ++totv;
if(!b) b = ++totv;
if(l == r){
ll jl = t[a].P, jjl = t[b].P;
ll nxa = 10000 - cnta - jl;
while(nxa < 0) nxa += mod;
while(nxa >= mod) nxa -= mod;
ll nxb = 10000 - cntb - t[b].P;
while(nxb < 0) nxb += mod;
while(nxb >= mod) nxb -= mod;
t[a].P = ((((((nxa * t[b].P % mod) * cc)%mod + ((nxb * t[a].P % mod) *cc)) % mod) * (10000 - k)) % mod) * cc %mod;
t[a].P = (t[a].P + ((((((cnta * t[b].P % mod) * cc)%mod + ((cntb * jl % mod) *cc)) % mod) * k) % mod) * cc %mod) % mod;
if(fa == 1 && t[a].P != 0){
ll id = get(bb[l]);
ll hh = t[a].P * cc % mod;
hh = (hh * hh) % mod;
ans = (ans + ((id * bb[l] % mod) * hh) % mod) % mod;
}
cnta = (cnta + jl) % mod;
cntb = (cntb + jjl) % mod;
return a;
}
ll mid = (l + r) >> 1;
pushdown(t[a].l); pushdown(t[b].l);
pushdown(t[a].r); pushdown(t[b].r);
t[a].l = merge(t[a].l, t[b].l, l, mid, k, fa);
t[a].r = merge(t[a].r, t[b].r, mid+1, r, k, fa);
pushup(a);
return a;
}
void dfs(ll x, ll fa){
bool ff = 1; ll la;
for(ll i = head[x]; i; i = nxt[i]){
ll y = ver[i];
if(y == fa) continue;
dfs(y, x);
cnta = cntb = 0;
if(ff == 1) rt[x] = rt[y], la = y;
else merge(rt[x], rt[y], 1ll, 300000ll, aa[x], x);
ff = 0;
}
if(ff) update(rt[x], 1, 300000, get(aa[x]), x);
}
int main(){
n = rd(); rd();
for(ll i = 2; i <= n; i++){
ll x = rd(), y = i;
add(x, y); add(y, x);
ds[x]++; rt[i] = i;
}
rt[1] = 1; totv = n;
for(ll i = 1; i <= n; i++){
aa[i] = rd();
if(ds[i] == 0){
bb[++n1] = aa[i];
}
}
sort(bb+1, bb+n1+1);
cc = power(10000, mod-2);
dfs(1, 0);
printf("%lld", ans);
return 0;
}