前言
开博客记录做题笔记的 flag 我立过 \(n\) 遍了,无一例外都倒了。
这次一定要坚持下来,一周至少一题不能咕,养成好习惯。
【LOJ3124】氪金手游
容易发现给定的 \((u_i,v_i)\) 是一棵树,先考虑简化问题:树是以 \(1\) 为根的外向树(每条边都从靠近 \(1\) 的点连向远离 \(1\) 的点),且 \(w_i\) 确定。
令 \(f(i)\) 表示考虑满足以 \(i\) 为根的子树的条件的概率,容易得到转移式:\(f(i)=\dfrac{w_i}{\sum_{v \in sub(i)} w_v} \prod_{x \in son(i)} f(x)\),其中 \(sub(i)\) 表示 \(i\) 的子树内的点集合,\(son(i)\) 表示 \(i\) 的儿子集合。意义也很显然,\(i\) 要在子树内所有点之前被抽到,抽到之后所有子树均独立。
考虑某条边反向的情况,如果有一条边 \(v \rightarrow u\) 满足 \(u\) 是 \(v\) 的父亲,可以做如下转化:观察到某种抽卡顺序,要么是先抽到 \(u\) 再抽到 \(v\),要么是先抽到 \(v\) 再抽到 \(u\),即 \(P(u \rightarrow v)+P(v \rightarrow u)=P((u,v)之间没有边)\),我们也很容易写出 \((u,v)\) 之间没有边时的转移方程:\(f(u)=\dfrac{w_u}{\sum_{x \in \{sub(u)-v\}} w_x} \prod_{y \in son(u)} f(y)\),也就可以得到一条边反向时的情况。
然后考虑加上 \(w_i\) 概率的情况,我们只需要在状态中增加一维:\(f(i,j)\) 表示当 \(\sum_{u \in sub(i)}w_u=j\) 时,满足以 \(i\) 为根的子树中条件的概率。转移的时候类似树上背包的方法合并两棵子树,然后乘上对于 \(w_i\) 的限制即可,时间复杂度可以做到均摊的 \(O(n^2)\)。
Code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fi first
#define se second
#define pii pair<long long,long long>
#define mp make_pair
#define pb push_back
const int mod=998244353;
const int inf=0x3f3f3f3f;
const int INF=1e18;
int fpow(int x,int b){
if(x==0) return 0;
if(b==0) return 1;
int res=1;
while(b>0){
if(b&1) res=res*x%mod;
x=x*x%mod;
b>>=1;
}
return res;
}
int n;
map <pii,bool> dir;
int P[1005][4];
vector <int> g[1005];
int dp[1005][3005],sz[1005],f[1005][3005];
void dfs(int u,int fa)
{
sz[u]=1;
vector <int> ss;
ss.clear();
ss.pb(-1);
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa) continue;
dfs(v,u),ss.pb(v),sz[u]+=sz[v];
}
f[0][0]=1;
for(int i=1;i<=ss.size();i++) memset(f[i],0,sizeof(f[i]));
for(int i=1;i<ss.size();i++)
{
int v=ss[i];
if(dir[mp(u,v)])
{
for(int s1=0;s1<=3*n;s1++) if(f[i-1][s1]) for(int s2=0;s2<=3*sz[v];s2++) if(dp[v][s2])
f[i][s1+s2]=(f[i][s1+s2]+f[i-1][s1]*dp[v][s2])%mod;
}
else
{
int t=0;
for(int s2=0;s2<=3*sz[v];s2++) t=(t+dp[v][s2])%mod;
for(int s1=0;s1<=3*n;s1++)
{
f[i][s1]=(f[i][s1]+f[i-1][s1]*t)%mod;
if(f[i-1][s1]) for(int s2=0;s2<=3*sz[v];s2++) if(dp[v][s2])
f[i][s1+s2]=(f[i][s1+s2]-f[i-1][s1]*dp[v][s2]%mod+mod)%mod;
}
}
}
int t=ss.size()-1;
for(int i=0;i<=3*n;i++) for(int w=1;w<4;w++)
dp[u][i+w]=(dp[u][i+w]+f[t][i]*P[u][w]%mod*w%mod*fpow(i+w,mod-2))%mod;
// cout<<u<<" "<<t<<endl;
// for(int i=1;i<=3*n;i++) cout<<dp[u][i]<<" ";
// puts("");
}
void solve()
{
cin>>n;
for(int i=1;i<=n;i++)
{
int a1,a2,a3;
cin>>a1>>a2>>a3;
P[i][1]=a1*fpow(a1+a2+a3,mod-2)%mod;
P[i][2]=a2*fpow(a1+a2+a3,mod-2)%mod;
P[i][3]=a3*fpow(a1+a2+a3,mod-2)%mod;
// cout<<P[i][1]<<" "<<P[i][2]<<" "<<P[i][3]<<endl;
}
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
g[u].pb(v),g[v].pb(u);
dir[mp(u,v)]=1;
}
dfs(1,-1);
int ans=0;
for(int i=1;i<=3*n;i++) ans=(ans+dp[1][i])%mod;
cout<<ans;
}
signed main()
{
int _=1;
//cin>>_;
while(_--) solve();
return 0;
}