居然差一点场切了。
首先可以将两棵树上对应的点看作一个点的两个不同状态考虑一个类似最短路的东西:设 \(dis_{i,j,0/1,0/1}\) 为树上 \(0/1\) 状态的 \(i\) 点到 \(0/1\) 状态的最短路。考虑怎样维护这个值。
由于是树上路径问题,容易发现设 \(k\) 为树上 \((i,j)\) 路径上任意一个点,都有 \(\min(dis_{i,k,p,0}+dis_{k,j,0,q},dis_{i,k,p,1}+dis_{k,j,1,q})=dis_{i,j,p,q}\)。利用这个性质,容易发现,可以通过树剖加线段树维护 \(dis\)。
具体来说,将每条边对应到点上,维护经过一段连续的 dfn 序所对应的边的最短路径。线段树每段区间拆成两段后,pushup 枚举中间状态,用一个类似 floyd 更新的方式得到这一段的答案。其他都是基本操作。并且注意有些地方要翻链。
除此之外,还有可能最优路径是从起点到路径外的一点,再回到中点。为了应对这种情况,可以预处理的时候更新一下两棵树对应点连的边的边权,通过两次 dp 找到最优值。
code:
点击查看代码
int n,m,q;
ll c[N],w[N][2];
int fa[N],dep[N],siz[N],son[N];
int cur,top[N],dfn[N],rk[N];
int tot,head[N];
struct node{
int to,nxt;
ll cw1,cw2;
}e[N<<1];
struct Node{
ll dis[2][2];
Node(){
mems(dis,0);
}
}tr[N<<2],emp;
inline void add(int u,int v,ll w1,ll w2){
e[++tot]={v,head[u],w1,w2};
head[u]=tot;
}
void dfs1(int u,int f){
fa[u]=f;
dep[u]=dep[f]+1;
siz[u]=1;
go(i,u){
int v=e[i].to;
if(v==f)
continue;
w[v][0]=e[i].cw1,w[v][1]=e[i].cw2;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])
son[u]=v;
}
}
void dfs2(int u,int t){
top[u]=t;
dfn[u]=++cur,rk[cur]=u;
if(!son[u])
return;
dfs2(son[u],t);
go(i,u){
int v=e[i].to;
if(v==fa[u]||v==son[u])
continue;
dfs2(v,v);
}
}
void init1(int u,int f){
go(i,u){
int v=e[i].to;
if(v==f)
continue;
init1(v,u);
c[u]=min(c[u],c[v]+w[v][0]+w[v][1]);
}
}
void init2(int u,int f){
go(i,u){
int v=e[i].to;
if(v==f)
continue;
c[v]=min(c[v],c[u]+w[v][0]+w[v][1]);
init2(v,u);
}
}
inline Node rev(Node x){
swap(x.dis[0][1],x.dis[1][0]);
return x;
}
inline Node pushup(Node ls,Node rs){
Node o;
rep(i,0,1){
rep(j,0,1){
o.dis[i][j]=min(ls.dis[i][0]+rs.dis[0][j],ls.dis[i][1]+rs.dis[1][j]);
}
}
return o;
}
void build(int l,int r,int o){
if(l==r){
ll A=w[rk[l]][0],B=w[rk[l]][1],C=c[rk[l]],D=c[fa[rk[l]]];
tr[o].dis[0][0]=min(A,B+C+D);
tr[o].dis[0][1]=min(A+C,B+D);
tr[o].dis[1][0]=min(A+D,B+C);
tr[o].dis[1][1]=min(A+C+D,B);
return;
}
int mid=(l+r)>>1;
build(l,mid,o<<1);
build(mid+1,r,o<<1|1);
tr[o]=pushup(tr[o<<1],tr[o<<1|1]);
}
inline Node query(int l,int r,int o,int x,int y){
if(r<x||l>y)
return emp;
if(l>=x&&r<=y)
return tr[o];
int mid=(l+r)>>1;
if(x<=mid&&y>mid)
return pushup(query(l,mid,o<<1,x,y),query(mid+1,r,o<<1|1,x,y));
if(x<=mid)
return query(l,mid,o<<1,x,y);
return query(mid+1,r,o<<1|1,x,y);
}
ll ask(int u,int v){
int opu=(u&1)^1,opv=(v&1)^1;
u=(u+1)/2,v=(v+1)/2;
Node x,y;
x.dis[0][1]=x.dis[1][0]=c[u];
y.dis[0][1]=y.dis[1][0]=c[v];
while(top[u]!=top[v]){
if(dep[top[u]]>dep[top[v]]){
x=pushup(query(1,n,1,dfn[top[u]],dfn[u]),x);
u=fa[top[u]];
}else{
y=pushup(query(1,n,1,dfn[top[v]],dfn[v]),y);
v=fa[top[v]];
}
}
if(dfn[u]<dfn[v])
y=pushup(query(1,n,1,dfn[u]+1,dfn[v]),y);
if(dfn[u]>dfn[v])
x=pushup(query(1,n,1,dfn[v]+1,dfn[u]),x);
x=pushup(rev(x),y);
return x.dis[opu][opv];
}
void Yorushika(){
scanf("%d",&n);
emp.dis[0][1]=emp.dis[1][0]=1ll*inf*inf;
rep(i,1,n){
scanf("%lld",&c[i]);
}
rep(i,1,n-1){
int u,v;
ll w1,w2;
scanf("%d%d%lld%lld",&u,&v,&w1,&w2);
add(u,v,w1,w2);
add(v,u,w1,w2);
}
dfs1(1,0);
dfs2(1,1);
init1(1,0);
init2(1,0);
build(1,n,1);
scanf("%d",&q);
rep(i,1,q){
int u,v;
scanf("%d%d",&u,&v);
printf("%lld\n",ask(u,v));
}
}
signed main(){
int t=1;
// scanf("%d",&t);
while(t--)
Yorushika();
}