假设现在起点已经确定,我们观察从这个起点开始能走的最长路径长什么样。把这条最长路径中所有的非平地路径拿出来,它们肯定连成一线,因为不允许上坡;而一条路径重复走的情况只可能是在几个连续的平地那里来回走。所以路径的形状是一条主链(链上的边可以是下坡或者平地),上面挂着一些平地组成的环。由于图是树,所以"挂着的环"可以简化为只有一条边,但来回走了多次。
把原树中与平路相连的点称为好点,不与任何平路相连的点称为坏点。假设现在已经确定了主链,那么如果只走这条主链不走环,如果剩下了一些能量就浪费了。而当我们碰到主链中一个好点的时候就可以用与它连着的任意一条平路消耗能量。注意我们只能用这些平路额外消耗偶数个的能量,因为走这些平路总是来回来回地走的。发现只用这条主链上最靠后的一个好点来消耗能量是最优的。如果最后一个好点与起点的距离是奇数,那么在到达这个好点时,能量值是奇数,所以就会浪费1的能量值。令\(d\)为最后一个好点到起点的距离,则答案就是\(最后一个好点前面下坡的个数\cdot 2+最后一个好点后面的主链长度-d\ mod\ 2\)。
思考了这么多,可能会想到直接遍历树,并用两个线段树或者平衡树什么的分别维护最后一个好点到子树根距离为奇数和偶数的,在每个点处合并。但是这样非常难写,其实用一个点分治就解决了。对于每一层分治,仍然是套路地计算当前重心的不同子树互相之间的贡献,以及重心和子树之间的贡献。维护信息仍然是用两棵线段树,分别维护最后一个好点到子树根距离为奇数和偶数的信息。从左到右遍历重心的所有子树,对于每个子树,先把它dfs一遍,把子树中的节点挨个当做起点,用线段树中的信息求最长路径;然后再dfs一遍,把子树内节点作为终点的贡献也计入线段树。从前往后遍历一遍,再从后往前做同样的操作,就能处理所有的贡献关系了。哦,还需要第三棵线段树维护压根没有好点的信息。两次dfs时需要维护巨量的信息,细节很多,写的时候要小心。
时间复杂度\(O(nlog^2n)\)。
点击查看代码
#include <bits/stdc++.h>
#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back
void fileio()
{
#ifdef LGS
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
}
void termin()
{
#ifdef LGS
std::cout<<"\n\nEXECUTION TERMINATED";
#endif
exit(0);
}
using namespace std;
int n,hei[200010],sz[200010],ans[200010];
vector <int> g[200010];
queue <int> q;
bool vis[200010],isGood[200010];
pii mn;
struct SegTree
{
int n2,dat[800010],tag[800010];
//tag=clear ? 1:0
void init()
{
n2=1;while(n2<n) n2*=2;
rep(i,n2+n2+3) dat[i]=-1e9,tag[i]=0;
}
void addTag(){dat[0]=-1e9;tag[0]=1;}
void pushDown(int k)
{
tag[k]=0;
tag[k+k+1]=tag[k+k+2]=1;
dat[k+k+1]=dat[k+k+2]=-1e9;
}
void upd(int k,int lb,int ub,int to,int val)
{
if(lb==ub)
{
if(val>dat[k]) dat[k]=val;
return;
}
if(tag[k]) pushDown(k);
int mid=(lb+ub)>>1;
if(to<=mid) upd(k+k+1,lb,mid,to,val);else upd(k+k+2,mid+1,ub,to,val);
dat[k]=max(dat[k+k+1],dat[k+k+2]);
}
void upd(int k,int val){upd(0,0,n2-1,k,val);}
int qry(int k,int lb,int ub,int tlb,int tub)
{
if(ub<tlb||tub<lb) return -1e9;
if(tlb<=lb&&ub<=tub) return dat[k];
if(tag[k]) pushDown(k);
return max(qry(k+k+1,lb,(lb+ub)/2,tlb,tub),qry(k+k+2,(lb+ub)/2+1,ub,tlb,tub));
}
}t0,t1,tn;
int qry(int cur,int depw,int tot,int cnt1)
{
int v0=t0.qry(0,0,t0.n2-1,0,cur),v1=t1.qry(0,0,t1.n2-1,0,cur);
int ret;
if(depw==0) ret=max(v0,v1-1);else ret=max(v0-1,v1);
ret+=cnt1*2;
ret=max(ret,tot+tn.qry(0,0,tn.n2-1,0,cur));
return ret;
}
int dfsSz(int pos,int par)
{
sz[pos]=1;
rep(i,g[pos].size()) if(g[pos][i]!=par&& !vis[g[pos][i]]) sz[pos]+=dfsSz(g[pos][i],pos);
return sz[pos];
}
void findCen(int pos,int par,int tot)
{
int mx=tot-sz[pos];
rep(i,g[pos].size()) if(g[pos][i]!=par&& !vis[g[pos][i]])
{
findCen(g[pos][i],pos,tot);
mx=max(mx,sz[g[pos][i]]);
}
mn=min(mn,mpr(mx,pos));
}
void dfsAdd(int pos,int par,int gooddep,int tot,int cnt1,int mnv,int curv,int dep)
{
if(isGood[pos]) gooddep=dep,tot=cnt1*2;
//if(pos==7) cout<<mnv<<' '<<tot<<' '<<gooddep<<' '<<dep<<endl;
if(gooddep>=0&&gooddep%2==0) t0.upd(-mnv,tot);
else if(gooddep>=0&&gooddep%2==1) t1.upd(-mnv,tot);
else tn.upd(-mnv,tot);
rep(i,g[pos].size()) if(g[pos][i]!=par&& !vis[g[pos][i]]&&hei[pos]>=hei[g[pos][i]])
{
int ev=(hei[pos]>hei[g[pos][i]] ? 1:-1),ntot=tot+1,ngd=gooddep,nc1=cnt1+(int)(ev==1);
if(isGood[g[pos][i]]) ngd=dep+1,ntot=nc1*2;
dfsAdd(g[pos][i],pos,ngd,ntot,nc1,min(mnv,curv+ev),curv+ev,dep+1);
}
}
void dfsAsk(int pos,int par,int gooddep,int tot,int cnt1,int mnv,int curv,int dep)
{
if(isGood[pos]&&gooddep==-1) gooddep=dep;
//cout<<pos<<'p'<<' '<<par<<' '<<vis[pos]<<endl;
if(mnv>=0) ans[pos]=max(ans[pos],qry(curv,dep%2,tot-(gooddep>-1&&(dep-gooddep)%2==1 ? 1:0),cnt1));
rep(i,g[pos].size()) if(g[pos][i]!=par&& !vis[g[pos][i]]&&hei[pos]<=hei[g[pos][i]])
{
int ev=(hei[pos]<hei[g[pos][i]] ? 1:-1),ntot=(gooddep==-1 ? tot+1:tot+2*(int)(ev==1)),nc1=cnt1+max(ev,0);
dfsAsk(g[pos][i],pos,gooddep,ntot,nc1,min(0,min(ev,ev+mnv)),curv+ev,dep+1);
}
}
void solve(int pos)
{
dfsSz(pos,0);
mn=mpr(1e9,1e9);findCen(pos,0,sz[pos]);
pos=mn.se;
vector <int> son;
rep(i,g[pos].size()) if(!vis[g[pos][i]]) son.pb(g[pos][i]);
//cout<<"curpos "<<pos<<endl;
rep(i,son.size())
{
if(hei[son[i]]>=hei[pos])
dfsAsk(son[i],pos,-1,1,(hei[son[i]]==hei[pos] ? 0:1),(hei[son[i]]==hei[pos] ? -1:0),(hei[son[i]]==hei[pos] ? -1:1),1);
if(hei[son[i]]<=hei[pos])
{
int ev=(hei[son[i]]<hei[pos] ? 1:-1);
if(isGood[pos]) dfsAdd(son[i],pos,0,1,(hei[son[i]]<hei[pos] ? 1:0),min(ev,0),ev,1);
else dfsAdd(son[i],pos,-1,1,(hei[son[i]]<hei[pos] ? 1:0),min(ev,0),ev,1);
}
}
//if(pos==4) cout<<qry(0,0,0,0)<<' '<<t1.qry(0,0,tn.n2-1,0,0)<<endl;
ans[pos]=max(ans[pos],qry(0,0,0,0));
tn.addTag();t0.addTag();t1.addTag();
if(isGood[pos]) t0.upd(0,0);
else tn.upd(0,0);
for(int i=((int)son.size())-1;i>=0;--i)
{
if(hei[son[i]]>=hei[pos])
dfsAsk(son[i],pos,-1,1,(hei[son[i]]==hei[pos] ? 0:1),(hei[son[i]]==hei[pos] ? -1:0),(hei[son[i]]==hei[pos] ? -1:1),1);
if(hei[son[i]]<=hei[pos])
{
int ev=(hei[son[i]]<hei[pos] ? 1:-1);
if(isGood[pos]) dfsAdd(son[i],pos,0,1,(hei[son[i]]<hei[pos] ? 1:0),min(ev,0),ev,1);
else dfsAdd(son[i],pos,-1,1,(hei[son[i]]<hei[pos] ? 1:0),min(ev,0),ev,1);
}
}
tn.addTag();t0.addTag();t1.addTag();
vis[pos]=true;
rep(i,son.size()) solve(son[i]);
}
int main()
{
fileio();
cin>>n;
rep(i,n+3) hei[i]=1e9;
repn(i,n)
{
int x;
scanf("%d",&x);
if(x==1) q.push(i),hei[i]=0;
}
int x,y;
rep(i,n-1)
{
scanf("%d%d",&x,&y);
g[x].pb(y);g[y].pb(x);
}
while(!q.empty())
{
int f=q.front();q.pop();
rep(i,g[f].size()) if(hei[g[f][i]]==1e9)
{
hei[g[f][i]]=hei[f]+1;
q.push(g[f][i]);
}
}
repn(i,n) rep(j,g[i].size()) if(hei[i]==hei[g[i][j]]) isGood[i]=true;
t0.init();t1.init();tn.init();
solve(1);
repn(i,n) printf("%d ",ans[i]);
puts("");
termin();
}