题目要求我们求出任意两点间最短路径之和,由于图比较特殊,除树边外只有祖先到其子树内的边,我们首先考虑最短路径有没有什么特殊性质。
注意到两点之间的最短路分为一下三种:
-
节点到其祖先的最短路:直接沿着树边向上走即可,否则一定会走多余的边,不是最优。
-
节点到其子树的最短路:此时最短路一定形如沿着树边走若干条边,再走一条非树边,走若干条树边如此交替进行,当然此处也可以连续走非树边。
-
两个节点没有祖先孩子关系的最短路:如果此时要从 \(u\) 点走到 \(v\) 点,由于图中没有横叉边,\(u\) 点必须要走到 \(\operatorname{lca}(u,v)\),否则一定无法通过非树边到 \(v\) 点,之后转为为从 \(\operatorname{lca}(u,v)\) 到 \(v\) 的第二类最短路。
这三条最短路都会经过 \(\operatorname{lca}(u,v)\),由于其唯一性,考虑枚举最近公共祖先统计答案,第二类是方便统计的,而此时第一类也已经被处理了,我们把第三类分为两部分:从 \(u\) 到 \(\operatorname{lca}(u,v)\) 和 从 \(\operatorname{lca}(u,v)\) 到 \(v\)。
第一段通过预处理即可解决,对于每个枚举到的 \(\operatorname{lca}\),由于树高是 \(\log\) 的,第二段路径所牵扯到的节点数也不会超过 \(N\log N\) 个,因此可以直接建反图暴力跑最短路统计,之后模仿点分治计算一遍即可。
代码中我对于每个 \(\operatorname{lca}\) 的子节点正序倒序分别枚举了一次,枚举到每颗子树时累加前面的子树走前半段,这颗子树走后半段的答案,即可保证不重不漏。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#define ll long long
#define N 700005
using namespace std;
const ll mod=998244353;
const ll inf=0x3f3f3f3f3f3f3f3f;
ll tot,dsum[N],dfn[N],ri[N],val[N],dis[N],vis[N],dep[N];
ll bel[N],siz[N],ans,ttp[N],sta[N],p,ct[N];
struct graph{
ll e,head[N],to[N],nex[N],edg[N];
void add(ll u,ll v,ll w){
to[++e]=v;nex[e]=head[u];head[u]=e;edg[e]=w;
}
void clear(){
for(ll i=1;i<=e;i++)head[i]=nex[i]=edg[i]=to[i]=0;
e=0;
}
}T,G;
struct edg{
ll u,v,w;
};
struct Node{
ll v,val;
bool operator <(const Node &x)const{
return val>x.val;
}
};
void adj(ll &x){
x=(((x%mod)+mod)%mod);
}
priority_queue<Node> q;
vector<edg> E[N];
void dfs(ll x){
dfn[x]=++tot;bel[tot]=x;siz[x]=1;
for(ll i=T.head[x];i;i=T.nex[i]){
ll v=T.to[i],w=T.edg[i];dep[v]=dep[x]+w;adj(dep[v]);dfs(v);
dsum[x]+=w*siz[v]+dsum[v];dsum[x]%=mod;
siz[x]+=siz[v];
}
ri[x]=tot;
}
void dij(ll s){
q.push((Node){s,0});dis[s]=0;
while(q.size()){
ll x=q.top().v;q.pop();if(vis[x]) continue;
vis[x]++;
for(ll i=G.head[x];i;i=G.nex[i]){
ll v=G.to[i],w=G.edg[i];
if(dis[v]>dis[x]+w){
dis[v]=dis[x]+w;
q.push((Node){v,dis[v]});
}
}
}
}
void dfs2(ll x){
G.clear();
for(ll i=dfn[x];i<=ri[x];i++){
dis[i]=inf;vis[i]=0;ans+=dep[bel[i]]-dep[x];adj(ans);
for(ll j=0;j<E[bel[i]].size();j++){
G.add(dfn[E[bel[i]][j].u],dfn[E[bel[i]][j].v],E[bel[i]][j].w);
}
G.add(dfn[bel[i]],dfn[bel[i]/2],val[bel[i]]);
}
ll tmp=x,tmp2=x/2,sumdis=0,sumsiz=1,cnt=0,tp=0;
while(tmp2){
G.add(dfn[tmp],dfn[tmp2],val[tmp]);
vis[dfn[tmp]]=0;dis[dfn[tmp]]=inf;tmp=tmp2;tmp2/=2;
}
vis[1]=0;dis[1]=inf;dij(dfn[x]);tmp=x;p=0;
for(ll i=T.head[x];i;i=T.nex[i]){
sta[++p]=i;ll v=T.to[i];tp=0;cnt=0;
for(ll j=dfn[v];j<=ri[v];j++){
if(dis[j]!=inf){
tp+=dis[j];adj(tp);cnt++;
}
}
ttp[v]=tp;ct[v]=cnt;
ans+=((sumsiz*tp)%mod)+((cnt*sumdis)%mod);adj(ans);
sumsiz+=siz[v];adj(sumsiz);
sumdis+=dsum[v]+((T.edg[i]*siz[v])%mod);adj(sumdis);
}
sumsiz=sumdis=0;
for(ll i=p;i>=1;i--){
ll j=sta[i],v=T.to[j];tp=ttp[v];cnt=ct[v];
ans+=((sumsiz*tp)%mod)+((cnt*sumdis)%mod);adj(ans);
sumsiz+=siz[v];adj(sumsiz);
sumdis+=dsum[v]+((T.edg[j]*siz[v])%mod);adj(sumdis);
}
for(ll i=dfn[x];i<=ri[x];i++)G.head[i]=dis[i]=0;
while(tmp){G.head[dfn[tmp]]=0;dis[dfn[tmp]]=0;tmp/=2;}
for(ll i=T.head[x];i;i=T.nex[i])dfs2(T.to[i]);
}
ll read(){
ll x=0,f=1;char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*f;
}
int main(){
ll n,m,u,v,w;n=read();m=read();n=(1<<n)-1;
for(ll i=2;i<=n;i++){cin>>w;T.add((i/2),i,w);val[i]=w;}
for(ll i=1;i<=m;i++){
u=read();v=read();w=read();
E[v].push_back((edg){u,v,w});
}
dfs(1);dfs2(1);
cout<<ans;
}
标签:siz,题解,ll,P9481,dis,NOI2023,lca,include,mod
From: https://www.cnblogs.com/eastcloud/p/17596401.html